Search
Duplicate

AI/ Multiplicative Interaction

Multiplicative Interaction

Multiplicative Interaction이란 입력 간의 상호작용(이것을 곱셈으로 사용)을 이용해 모델링 하는 방법을 의미한다. 예컨대 2개의 입력에 대한 가법적(additive) 모델이 다음과 같이 정의된다고 하자.
y=w1x1+w2x2+by = w_1 x_1 + w_2x_2 + b
여기서 w1,w2w_1, w_2는 가중치이고 bb는 편향이다. 이것은 x1x_1x2x_2가 독립적으로 yy에 영향을 미친다는 가정에 기반한 것이다. 위 모델에 대해 multiplicative interaction이 추가된 식은 다음과 같다.
y=w1x1+w2x2+w3(x1x2)+by = w_1 x_1 + w_2x_2 + w_3(x_1 \cdot x_2) + b
이런 식으로 입력 간의 곱셈 상호작용을 이용하여 더 복잡한 패턴을 모델링할 수 있다.
단순 곱을 WX\bold{WX}로 표현 하는 것과 달리 Multiplication은 W(X)\bold{W}(\bold{X})와 같이 표현된다. Attention은 다음과 같이 정의되는 Multiplicative Interaction이다.
Z=φ(XW(X))\bold{Z} = \varphi(\bold{XW}(\bold{X}))
일반적으로 많이 사용되는 query Q\bold{Q}, key K\bold{K}, value V\bold{V}를 이용한 Attention은 다음처럼 정의할 수 있다.
Z=φ(VW(Q,K))\bold{Z} = \varphi(\bold{VW}(\bold{Q}, \bold{K}))

Example

입력 X\bold{X}에 따라 가중치 W\bold{W}가 조정되도록 곱셈 상호작용 W(X)\bold{W}(\bold{X})을 다음과 같이 계산할 수 있다. 우선 입력 X\bold{X}와 query와 key에 대한 가중치 행렬 Wq,Wk\bold{W}_q, \bold{W}_k가 각각 다음과 같다고 하자.
X=[1234],Wq=[1001],Wk=[0110]\bold{X} = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}, \bold{W}_q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}, \bold{W}_k = \begin{bmatrix} 0 & 1 \\ 1 & 0 \end{bmatrix}
이에 대해 query Q\bold{Q}와 key K\bold{K}를 각각 다음과 같이 구할 수 있다.
Q=XWq=[1234][1001]=[1234]K=XWk=[1234][0110]=[2143]\bold{Q} = \bold{XW}_q = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \\ \bold{K} = \bold{XW}_k = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \begin{bmatrix} 0 & 1 \\ 1 & 0 \end{bmatrix} = \begin{bmatrix} 2 & 1 \\ 4 & 3 \end{bmatrix}
Q\bold{Q}K\bold{K}를 이용하여 W(X)\bold{W}(\bold{X})을 계산할 수 있다. 일반적인 attention에서 많이 사용되는 것과 같이 내적을 사용한다면 다음과 같다.
W(X)=QK=[1234][2413]=[4101025]\bold{W}(\bold{X}) = \bold{QK}^\top = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \begin{bmatrix} 2 & 4 \\ 1 & 3 \end{bmatrix} = \begin{bmatrix} 4 & 10 \\ 10 & 25 \end{bmatrix}
이 결과는 각 입력 벡터가 다른 입력 벡터에 얼마나 attention 해야 하는지를 나타내는 가중치 행렬이 된다. 일반적인 attention에서 이 결과를 softmax 함수에 통과시킨 후에 다시 value V=XWv\bold{V} = \bold{XW}_v와 곱해 출력 Z\bold{Z}를 계산한다.

참고