Search
Duplicate

AI/ Linear Probe, Attentive Probe

Linear Probe와 Attentive Probe 모두 pre-trained 모델의 representation 능력을 평가하는 방법으로, 주로 downstream task에 transfer하는 것을 목적으로 pre-trained 된 모델을 대상으로 한다. 두 방법 모두 pre-trained 모델을 freeze 한 후에 테스트 하기 때문에 모델의 파라미터를 업데이트 하는 방법이 아니다.

Linear Probe

Linear Probe는 간단한 선형 분류기를 이용하여 pre-trained 모델을 테스트한다. 일반적으로 fully-connected layer 하나만 추가하여 분류 작업을 테스트한다.
import torch import torch.nn as nn # 사전 학습된 pre-trained 모델 pretrained_model = torch.hub.load() for param in pretrained_model.parameters(): param.requires_grad = False # 가중치 동결 # 512차원 입력을 분류할 선형 분류기 정의 linear_probe = nn.Linear(512, num_classes) # pre-trained 모델에서 표현 추출 features = pretrained_model(images) # 이미지에서 표현 추출 # Linear probe로 분류 logits = linear_probe(features)
Python
복사

Attentive Probe

Attentive Probe는 linear probe 보다 복잡한 방식을 사용하는데, 학습 가능한 query 토큰을 도입하여 cross-attention layer에서 입력 시퀀스와 상호작용하게 한다. pre-trained 모델은 freeze이므로 모델을 업데이트하지 않지만, downstream task에 대해 모델의 성능을 높일 수 있게 한다.
import torch import torch.nn as nn import torch.nn.functional as F class AttentiveProbe(nn.Module): def __init__(self, backbone, input_dim, num_heads, hidden_dim, num_classes): super(AttentiveProbe, self).__init__() self.backbone = backbone # 사전 학습된 백본 (동결된 상태) # 교차 주의 레이어 self.query_token = nn.Parameter(torch.zeros(1, 1, input_dim)) # 학습 가능한 쿼리 토큰 self.cross_attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads) # MLP 헤드 (MLP에서 사용되는 레이어) self.mlp = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, input_dim), nn.LayerNorm(input_dim) ) # 최종 분류기 self.classifier = nn.Linear(input_dim, num_classes) def forward(self, x): # 백본을 사용하여 비디오나 이미지 클립에서 특징 추출 with torch.no_grad(): features = self.backbone(x) # 백본은 동결된 상태에서 사용 # 학습 가능한 쿼리 토큰과 결합 batch_size = features.size(0) query = self.query_token.expand(batch_size, -1, -1) # 쿼리 토큰을 배치 크기로 확장 # 교차 주의 (query와 features의 상호작용) attn_output, _ = self.cross_attention(query, features, features) # 잔차 연결 (Residual connection) attn_output = attn_output + query # MLP 헤드 통과 attn_output = self.mlp(attn_output) # 최종 분류 logits = self.classifier(attn_output.squeeze(1)) # 쿼리의 최종 출력으로 분류 return logits # 사전 학습된 백본 (예: ViT 모델) backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') # Attentive Probe 모델 초기화 model = AttentiveProbe( backbone=backbone, # 동결된 사전 학습 백본 input_dim=768, # 백본의 출력 차원 num_heads=8, # 교차 주의에서 사용할 헤드 개수 hidden_dim=1024, # MLP에서 사용할 은닉 차원 num_classes=1000 # 분류할 클래스 수 ) # 임의의 입력 데이터 (예: 비디오 클립) input_data = torch.randn(8, 16, 3, 224, 224) # 배치 크기 8, 16프레임, 224x224 크기 이미지 # 모델을 통해 예측 logits = model(input_data) print(logits.shape) # [8, 1000] -> 1000개의 클래스에 대한 예측
Python
복사

참고