Search
Duplicate

AI/ Masked Auto-Encoder(MAE)

Masked Auto-Encoder(MAE)

Masked Auto-Encoder는 입력을 encoder를 통해 저차원 latent space로 투영했다가 decoder를 통해 해당 latent space에서 입력을 복원하는 구조를 갖는 전통적인 Autoencdoer의 변종으로 입력에 Mask를 씌운 후에 Mask된 부분을 픽셀 레벨로 복원하는 모델이다.
개념 자체는 Autoencoder를 따르기 때문에 간단한데, encoder에 ViT를 사용하고, decoder에는 Lightweight Transformer decoder를 사용하여 전통적인 Autoencoder와 달리 비대칭이라는 특징이 있다. 전체적인 아키텍쳐가 간단해서 모델을 확장하기 용이하다는 장점이 있다.

Masking

입력을 Masking 할 때는 ViT처럼 겹치지 않는 Grid로 Patch를 분할한 후, 무작위로 전체 Patch의 (무려) 75%를 Masking하여 사용한다. Encoder에서는 Masking 되지 않은 나머지에 대해서만 positional embedding을 추가한 다음 ViT를 통해 latent space로 embedding한다.
Decoder에서는 encoder의 embedding 데이터와 Masking 된 patch를 모두 입력 받고 전체에 대해 positional embedding을 추가한 다음 원본 이미지를 복원한다. Transformer의 decoder를 사용하지만 auto-regressive하게 결과를 출력하지 않고 한 번에 결과를 출력한다.
마지막으로 복원된 결과 픽셀과 원본 입력의 픽셀 사이의 MSE(mean squared error)를 loss로 사용하여 파라미터를 학습한다.
더 자세한 내용은 참고의 논문 페이지 참조.

Sample Code

아래는 encoder에 ViT를 사용하고 decoder에 Lightweight Transformer decoder를 사용하는 예시 코드이다.
encoder에서 입력을 patch로 만든 후에 visible_patch만 encoding을 수행한다. decoder에서는 encoder의 출력과 mask 된 입력을 함께 사용하여 결과를 출력한다.
import torch import torch.nn as nn from timm.models.vision_transformer import VisionTransformer class MAE(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, decoder_dim=512, mask_ratio=0.75): super(MAE, self).__init__() # ViT Encoder self.encoder = VisionTransformer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) # Decoder: Lightweight Transformer Decoder self.decoder = nn.Transformer(d_model=decoder_dim, nhead=8, num_encoder_layers=6, num_decoder_layers=6) # Masking Ratio self.mask_ratio = mask_ratio def forward(self, x): # Split the image into patches B, C, H, W = x.shape patches = self.encoder.patch_embed(x) # Convert image into patches # Mask some of the patches randomly num_patches = patches.shape[1] num_mask = int(self.mask_ratio * num_patches) mask = torch.rand(B, num_patches).argsort(dim=-1)[:, :num_mask] # Encode only visible patches visible_patches = patches.clone() visible_patches[:, mask] = 0 # Set masked patches to zero latent = self.encoder(visible_patches) # Prepare decoder input (concatenate mask tokens) mask_tokens = torch.zeros_like(latent) mask_tokens[:, mask] = self.encoder.pos_embed[:, mask] # Add positional embedding for masked tokens decoder_input = torch.cat([latent, mask_tokens], dim=1) # Decode to reconstruct the image reconstructed_patches = self.decoder(decoder_input, decoder_input) return reconstructed_patches # Example instantiation model = MAE(img_size=224, patch_size=16, embed_dim=768, decoder_dim=512, mask_ratio=0.75)
Python
복사
학습 예시 코드. 모델의 결과를 MSE loss를 이용해서 loss를 계산하고 파라미터를 업데이트한다.
import torch.optim as optim # Loss function (MSE) loss_fn = nn.MSELoss() # Optimizer optimizer = optim.AdamW(model.parameters(), lr=1e-4) # Dummy input (batch size 8, 3 channels, 224x224 image) dummy_input = torch.randn(8, 3, 224, 224) # Training loop for epoch in range(10): # Train for 10 epochs model.train() # Forward pass reconstructed_patches = model(dummy_input) # Compute loss (between reconstructed patches and original input patches) original_patches = model.encoder.patch_embed(dummy_input) # Original patches for comparison loss = loss_fn(reconstructed_patches, original_patches) # Backward pass and optimization optimizer.zero_grad() loss.backward() optimizer.step() print(f'Epoch {epoch+1}, Loss: {loss.item()}')
Python
복사

참고