Search
Duplicate

AI/ Ranking Loss - Pairwise, Triplet, N-Pair, InfoNCE

Ranking Loss는 유사한 예제들이 유사하지 않은 예제들보다 가깝도록 하는 방법이다. 이러한 방법들은 대부분 명시적인 클래스 라벨이 필요하지 않다.

Pairwise (contrastive) loss and Siamese networks

유사/비유사 쌍에서 representation을 배우는 초기 접근은 다음의 contrastive loss를 최소화하는 것에 기반한다.
L(θ;xi,xj)=I(yi=yj)d(xi,xj)2+I(yiyj)[md(xi,xj)]+2\mathcal{L}(\boldsymbol{\theta}; \bold{x}_i, \bold{x}_j) = \mathbb{I}(y_i = y_j) d(\bold{x}_i, \bold{x}_j)^2 + \mathbb{I}(y_i \ne y_j)[m - d(\bold{x}_i, \bold{x}_j)]_+^2
여기서 [z]+=max(0,z)[z]_+ = \max(0, z)는 hinge 손실이고 m>0m > 0은 margin 파라미터이다. 이것은 직관적으로 같은 라벨인 positive 쌍을 가깝게 하고 다른 라벨인 negative 쌍을 어떤 margin 보다 멀리 떨어뜨리기는 방법이다. 데이터의 모든 쌍에 대한 이 손실을 최소화하여 이것은 O(N2)O(N^2) 시간을 갖는다.
두 입력 사이의 거리를 계산할 때 f(;θ)\bold{f}(\cdot;\boldsymbol{\theta}) 같은 feature extractor를 사용한 것을 Siamese Network(이후에 Siamese twins가 됨)이 된다.

Triplet loss

Pairwise loss의 단점은 positive 쌍의 최적화가 negative 쌍에 독립이라 크기를 비교할 수 없다는 것이다. 이에 대한 해결책은 다음처럼 정의되는 triplet loss를 사용하는 것이다.
L(θ;xi,xi+,xi)=[dθ(xi,xi+)2dθ(xi,xi)2+m]+\mathcal{L}(\boldsymbol{\theta}; \bold{x}_i, \bold{x}_i^+, \bold{x}_i^-) = [d_{\boldsymbol{\theta}}(\bold{x}_i, \bold{x}_i^+)^2 - d_{\boldsymbol{\theta}}(\bold{x}_i, \bold{x}_i^-)^2 + m]_+
여기서 anchor라 불리는 각 예제 ii에 대해 positive 예제 xi+\bold{x}_i^+와 negative 예제 xi\bold{x}_i^-를 찾고 전체 triple에 대해 평균을 구하여 다음의 손실을 최소화 한다.
직관적으로 이것은 anchor-positive 사이의 거리가 anchor-negative 사이의 거리보다 어떤 margin 보다 작도록 하는 방법이다. 단순하게 triplet 손실을 최소화하는데는 O(N3)O(N^3)의 시간이 걸린다.

N-pairs loss

Triplet loss의 단점은 각 anchor가 한 번에 하나의 negative와만 비교가 된다는 것이다. 이것은 충분히 강한 학습 신호를 줄 수 없다. 이를 해결하기 위해 각 anchor 별로 1개의 positive와 N-1개의 negative를 갖도록 하여 multi-class 분류 문제를 만들 수 있다. 이것을 N-pair loss라고 하며 다음과 같이 정의된다. 여기서 e^θ()\hat{\bold{e}}_{\boldsymbol{\theta}}()는 입력을 임베딩 공간에 매핑하는 함수이다.
L(θ;x,x+,{xk}k=1N1)=log(1+[k=1N1exp(e^θ(x)e^θ(xk))]exp(e^θ(x)e^θ(x+)))=logexp(e^θ(x)e^θ(x+))exp(e^θ(x)e^θ(x+))+k=1N1exp(e^θ(x)e^θ(xk))\begin{aligned} &\mathcal{L}(\boldsymbol{\theta};\bold{x}, \bold{x}^+, \{\bold{x}_k^-\}_{k=1}^{N-1}) \\&= \log \left( 1 + \left[ \sum_{k=1}^{N-1} \exp (\hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x})^\top \hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x}_k^-))\right] - \exp(\hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x})^\top \hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x}^+)) \right) \\ &= - \log{\exp(\hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x})^\top \hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x}^+)) \over \exp(\hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x})^\top\hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x}^+)) + \sum_{k=1}^{N-1} \exp(\hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x})^\top \hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x}_k^-))} \end{aligned}
이것은 InfoNCE loss와 유사하다. 아래 설명.

InfoNCE

InfoNCE는 이항분포를 이용한 NCE를 multi-class 분류로 확장한 버전이다. 다만 multi-class로 확장하여 softmax 함수를 사용하므로 정규화 상수를 계산하지 않고 때문에 InfoNCE는 NCE와 달리 Energy-Based Models(EBMs)에는 사용할 수 없다.
InfoNCE의 loss는 기본적으로 N-pair loss와 동일한데, 여기에 temperature 항을 추가하여 확장한 버전을 NT-Xent라고 한다. 여기서 온도 파라미터는 데이터가 존재하는 hypersphere의 반경을 scaling하는 것으로 볼 수 있다.
L(θ;x,x+,{xk}k=1N1)=logexp(e^θ(x)e^θ(x+)/τ)exp(e^θ(x)e^θ(x+)/τ)+k=1N1exp(e^θ(x)e^θ(xk)/τ)\begin{aligned} &\mathcal{L}(\boldsymbol{\theta};\bold{x}, \bold{x}^+, \{\bold{x}_k^-\}_{k=1}^{N-1}) \\&= - \log{\exp(\hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x})^\top \hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x}^+)/\tau) \over \exp(\hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x})^\top\hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x}^+)/\tau) + \sum_{k=1}^{N-1} \exp(\hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x})^\top \hat{\bold{e}}_{\boldsymbol{\theta}}(\bold{x}_k^-)/\tau)} \end{aligned}
InfoNCE는 상호정보량(Mutual Information) 측면에서도 다음과 같은 lower bound를 작성할 수 있다. 이 lower bound를 최대화하면 InfoNCE loss 함수를 달성하는 것이 된다.
INCE=E[1Ki=1Klogexp(f(xi,xi+))1Kj=1Kexp(f(xi,xj))]=logKE[1Ki=1Klog(1+jiKexp(f(xi,xj)f(xi,xi+)))]\begin{aligned} \mathbb{I}_\text{NCE} &= \mathbb{E} \left[{1\over K} \sum_{i=1}^K \log {\exp\Big({f(\bold{x}_i,\bold{x}_i^+)\Big)} \over {1\over K}\sum_{j=1}^K \exp\Big({f(\bold{x}_i,\bold{x}_j^-)}\Big)} \right] \\ &= \log K - \mathbb{E} \left[{1\over K} \sum_{i=1}^K \log \left( 1 + \sum_{j\ne i}^K \exp \Big({f(\bold{x}_i,\bold{x}_j^-) - f(\bold{x}_i,\bold{x}_i^+)\Big)} \right) \right] \end{aligned}
여기서 함수 f(,)f(\cdot, \cdot)은 입력과 positive, 입력과 negative 사이의 유사도 함수로 입력과 positive 사이의 유사도를 최대화하고, 입력과 negative 사이의 유사도를 최소화한다.

Sample Code

Model

입력과 positive, negative를 embedding하는 경우 모델은 다음과 같이 정의할 수 있다.
# 임베딩 모델 정의 class EmbeddingModel(nn.Module): def __init__(self, input_dim, embedding_dim): super(EmbeddingModel, self).__init__() self.fc1 = nn.Linear(input_dim, 256) self.fc2 = nn.Linear(256, embedding_dim) def forward(self, x): x = F.relu(self.fc1(x)) x = self.fc2(x) return x
Python
복사

Loss

InfoNCE의 식을 따라 loss는 다음과 같이 구현한다.
# InfoNCE 손실 함수 정의 def info_nce_loss(anchor_emb, positive_emb, negatives_emb, temperature=0.1): # 양성 유사도 계산 pos_sim = torch.exp(torch.matmul(anchor_emb, positive_emb.T) / temperature) # 부정 유사도 계산 neg_sims = torch.exp(torch.matmul(anchor_emb, negatives_emb.T) / temperature) # 유사도 결합 denominator = pos_sim + neg_sims.sum(dim=1, keepdim=True) # 손실 계산 loss = -torch.log(pos_sim / denominator).mean() return loss
Python
복사

Train

학습은 데이터와 positives, negatives를 받아 loss를 계산하여 역전파하는 식으로 구성된다. 유사도 계산은 단순 행렬 곱이므로 embedding 모델이 학습 되게 된다.
# 설정 batch_size = 5 input_dim = 128 embedding_dim = 64 num_negatives = 10 temperature = 0.1 num_epochs = 10 learning_rate = 0.001 # 임베딩 모델 생성 model = EmbeddingModel(input_dim, embedding_dim) optimizer = optim.Adam(model.parameters(), lr=learning_rate) # 임의의 데이터 생성 (원래 차원으로) torch.manual_seed(0) anchors = torch.randn(batch_size, input_dim) positives = torch.randn(batch_size, input_dim) negatives = torch.randn(num_negatives, input_dim) # 학습 for epoch in range(num_epochs): model.train() optimizer.zero_grad() # anchors, positives, negatives에 대해 embedding 수행 anchor_emb = model(anchors) positive_emb = model(positives) negatives_emb = model(negatives) # 손실 함수 계산 loss = info_nce_loss(model, anchors, positives, negatives, temperature) # 역전파 및 최적화 loss.backward() optimizer.step() # 결과 출력 print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}") print("Training completed.")
Python
복사

참고