Linear Probe와 Attentive Probe 모두 pre-trained 모델의 representation 능력을 평가하는 방법으로, 주로 downstream task에 transfer하는 것을 목적으로 pre-trained 된 모델을 대상으로 한다. 두 방법 모두 pre-trained 모델을 freeze 한 후에 테스트 하기 때문에 모델의 파라미터를 업데이트 하는 방법이 아니다.
Linear Probe
Linear Probe는 간단한 선형 분류기를 이용하여 pre-trained 모델을 테스트한다. 일반적으로 fully-connected layer 하나만 추가하여 분류 작업을 테스트한다.
import torch
import torch.nn as nn
# 사전 학습된 pre-trained 모델
pretrained_model = torch.hub.load()
for param in pretrained_model.parameters():
param.requires_grad = False # 가중치 동결
# 512차원 입력을 분류할 선형 분류기 정의
linear_probe = nn.Linear(512, num_classes)
# pre-trained 모델에서 표현 추출
features = pretrained_model(images) # 이미지에서 표현 추출
# Linear probe로 분류
logits = linear_probe(features)
Python
복사
Attentive Probe
Attentive Probe는 linear probe 보다 복잡한 방식을 사용하는데, 학습 가능한 query 토큰을 도입하여 cross-attention layer에서 입력 시퀀스와 상호작용하게 한다. pre-trained 모델은 freeze이므로 모델을 업데이트하지 않지만, downstream task에 대해 모델의 성능을 높일 수 있게 한다.
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttentiveProbe(nn.Module):
def __init__(self, backbone, input_dim, num_heads, hidden_dim, num_classes):
super(AttentiveProbe, self).__init__()
self.backbone = backbone # 사전 학습된 백본 (동결된 상태)
# 교차 주의 레이어
self.query_token = nn.Parameter(torch.zeros(1, 1, input_dim)) # 학습 가능한 쿼리 토큰
self.cross_attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads)
# MLP 헤드 (MLP에서 사용되는 레이어)
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, input_dim),
nn.LayerNorm(input_dim)
)
# 최종 분류기
self.classifier = nn.Linear(input_dim, num_classes)
def forward(self, x):
# 백본을 사용하여 비디오나 이미지 클립에서 특징 추출
with torch.no_grad():
features = self.backbone(x) # 백본은 동결된 상태에서 사용
# 학습 가능한 쿼리 토큰과 결합
batch_size = features.size(0)
query = self.query_token.expand(batch_size, -1, -1) # 쿼리 토큰을 배치 크기로 확장
# 교차 주의 (query와 features의 상호작용)
attn_output, _ = self.cross_attention(query, features, features)
# 잔차 연결 (Residual connection)
attn_output = attn_output + query
# MLP 헤드 통과
attn_output = self.mlp(attn_output)
# 최종 분류
logits = self.classifier(attn_output.squeeze(1)) # 쿼리의 최종 출력으로 분류
return logits
# 사전 학습된 백본 (예: ViT 모델)
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
# Attentive Probe 모델 초기화
model = AttentiveProbe(
backbone=backbone, # 동결된 사전 학습 백본
input_dim=768, # 백본의 출력 차원
num_heads=8, # 교차 주의에서 사용할 헤드 개수
hidden_dim=1024, # MLP에서 사용할 은닉 차원
num_classes=1000 # 분류할 클래스 수
)
# 임의의 입력 데이터 (예: 비디오 클립)
input_data = torch.randn(8, 16, 3, 224, 224) # 배치 크기 8, 16프레임, 224x224 크기 이미지
# 모델을 통해 예측
logits = model(input_data)
print(logits.shape) # [8, 1000] -> 1000개의 클래스에 대한 예측
Python
복사