Linear Attention
다음과 같이 정의되는 일반적인 Attention 메커니즘은 입력 시퀀스의 길이 에 대해 의 계산 시간을 갖는다는 단점을 갖는다.
이를 해결하기 위한 많은 방법들이 존재했는데, 그 중에서 Linear Attention은 일반적인 Attention 함수에서 softmax 연산을 제거하고 Attention 가중치 계산을 선형 형태로 변환하여 계산 속도를 으로 줄이는 방법이다.
여기서 는 비선형 함수로 ReLU나 ELU 등을 사용할 수 있다.
속도를 높일 수 있다는 장점이 있지만 정확도 손실이 발생하고 다양한 종류의 데이터에 대해 동일한 성능을 보장하지 않는다.
참고로 Flash Attention은 기존 Attention 구조를 변경하지 않고 하드웨어의 구조를 고려하여 소프트웨어적으로 계산 속도를 높인 방법이다.
Example
Linear Attention의 구현 예. 여기서 비선형 함수를 ReLU로 사용하였음
class LinearAttention(nn.Module):
def __init__(self, input_dim):
super(LinearAttention, self).__init__()
self.input_dim = input_dim
self.query = nn.Linear(input_dim, input_dim)
self.key = nn.Linear(input_dim, input_dim)
self.value = nn.Linear(input_dim, input_dim)
self.phi = nn.ReLU() # 비선형 함수로 ReLU 사용
def forward(self, x):
Q = self.phi(self.query(x))
K = self.phi(self.key(x))
V = self.value(x)
KV = torch.matmul(K.transpose(-2, -1), V)
attention = torch.matmul(Q, KV)
return attention
Python
복사