Search
Duplicate

AI/ Score Matching(SM)

Score Function

log pdf에 대해 입력 x\bold{x}에 관한 1차 gradient 함수 xlogpθ(x)\nabla_{\bold{x}}\log p_{\boldsymbol{\theta}}(\bold{x})을 (Stein) score function이라고 부르며 다음과 같이 정의한다. (pmf의 경우 미분이 직접적으로 적용되지 않으므로 일반적으로 pdf에 대해 사용된다.)
sθ(x)xlogpθ(x)\bold{s}_{\boldsymbol{\theta}}(\bold{x}) \triangleq \nabla_{\bold{x}} \log p_{\boldsymbol{\theta}}(\bold{x})
EBM의 식에 대해 입력에 관한 1차 gradient를 취하면 다음과 같이 정리된다.
xlogpθ(x)=xEθ(x)xlogZθ=0=xEθ(x)\nabla_{\bold{x}} \log p_{\boldsymbol{\theta}}(\bold{x}) = -\nabla_{\bold{x}} \mathcal{E}_{\boldsymbol{\theta}}(\bold{x}) - \underbrace{\nabla_{\bold{x}} \log Z_{\boldsymbol{\theta}}}_{=0} = -\nabla_{\bold{x}} \mathcal{E}_{\boldsymbol{\theta}}(\bold{x})
결국 EBM에 대해 Score 함수를 이용하면 계산이 까다로운 정규화 상수를 계산하지 않고도 모델을 학습할 수 있다.

Score Matching(SM)

Score Matching은 데이터의 score 함수와 모델의 score 함수를 구하고 그 둘 사이의 차이를 최소화하는 방법을 말한다. Score Matching 목적은 Fisher Divergence라 부르는 두 분포의 score 함수 사이의 차이를 최소화한다.
DF(pD(x)pθ(x))=EpD(x)[12xlogpD(x)xlogpθ(x)2]D_F(p_\mathcal{D}(\bold{x})\|p_{\boldsymbol{\theta}}(\bold{x})) = \mathbb{E}_{p_\mathcal{D}(\bold{x})}\left[ {1\over2}\|\nabla_\bold{x} \log p_\mathcal{D}(\bold{x}) - \nabla_\bold{x} \log p_{\boldsymbol{\theta}}(\bold{x})\|^2 \right]
참고로 Score Matching과 Contrastive Divergence는 매우 다른 접근으로 보이지만 서로 밀접하게 연결되어 있다. 실제로 Score Matching은 특정한 MCMC 샘플러의 극한에서 Contrastive Divergence의 special case로 볼 수 있다.

Basic Score Matching

특정한 regularity 조건 아래 Fisher divergence에서 pD(x)p_\mathcal{D}(\bold{x})의 알려지지 않은 1차 도함수를 Eθ(x)\mathcal{E}_{\boldsymbol{\theta}}(\bold{x})의 2차 도함수로 교체하여 부분별 적분을 사용할 수 있다.
DF(pD(x)pθ(x))=EpD(x)[12i=1d(Eθ(x)xi2Eθ(x)xi2)2]+const=EpD(x)[12sθ(x)2+tr(Jxsθ(x))]+const\begin{aligned} D_F(p_\mathcal{D}(\bold{x})\|p_{\boldsymbol{\theta}}(\bold{x})) &= \mathbb{E}_{p_\mathcal{D}(\bold{x})}\left[{1\over2}\sum_{i=1}^d\left({\partial \mathcal{E}_{\boldsymbol{\theta}}(\bold{x}) \over \partial x_i} - {\partial^2 \mathcal{E}_{\boldsymbol{\theta}}(\bold{x})\over \partial x_i^2} \right)^2 \right] + \text{const} \\ &= \mathbb{E}_{p_\mathcal{D}(\bold{x})}\left[{1\over2}\|\bold{s}_{\boldsymbol{\theta}}(\bold{x})\|^2 + \text{tr}(\bold{J}_\bold{x}\bold{s}_{\boldsymbol{\theta}}(\bold{x})) \right] + \text{const} \end{aligned}
여기서 ddx\bold{x}의 차원이고 Jxsθ(x)\bold{J}_\bold{x}\bold{s}_{\boldsymbol{\theta}}(\bold{x})는 score 함수의 야코비안이다. 상수는 최적화에 영향이 없으므로 학습하는 동안 제거할 수 있다.
위 목적의 단점은 야코비안의 trace를 계산하는데 O(d2)O(d^2) 시간이 걸린다는 것이다. 이러한 이유로 위 방정식의 암시적 SM 공식은 2차 도함수 계산이 다루기 용이한 상대적으로 단순한 에너지 함수에만 적용되었다.

Denoising Score Matching(DSM)

기본 Score Matching의 목적은 어디에서나 연속이고 미분가능하고 유한해야 한다는 조건을 요구하는데, 이미지 데이터는 픽셀 값이 0-255로 이산이기 때문에 이것을 직접 사용할 수 없다. 이를 완화하기 위한 방법으로 데이터 포인트에 노이즈를 추가할 수 있는데, 이 방법을 Denoising Score Matching(DSM)이라 한다.
이를 위해 데이터 포인트에 다음과 같이 노이즈를 추가한다.
x~=x+ϵ\tilde{\bold{x}} = \bold{x} + \boldsymbol{\epsilon}
노이즈 분포 p(ϵ)p(\boldsymbol{\epsilon})가 smooth인 한 다음의 노이즈 데이터 분포 또한 smooth이고,
q(x~)=q(x~x)pD(x)dxq(\tilde{\bold{x}}) = \int q(\tilde{\bold{x}}|\bold{x})p_\mathcal{D}(\bold{x})d\bold{x}
따라서 Fisher divergnece DF(q(x~)pθ(x~))D_F(q(\tilde{\bold{x}})\| p_{\boldsymbol{\theta}}(\tilde{\bold{x}}))는 적합한 목적이다
DF(q(x~)pθ(x~))=Eq(x~)[12x~logq(x~)x~logpθ(x~)2]D_F(q(\tilde{\bold{x}})\| p_{\boldsymbol{\theta}}(\tilde{\bold{x}})) = \mathbb{E}_{q(\tilde{\bold{x}})}\left[ {1\over2}\|\nabla_{\tilde{\bold{x}}} \log q(\tilde{\bold{x}}) - \nabla_{\tilde{\bold{x}}} \log p_{\boldsymbol{\theta}}(\tilde{\bold{x}})\|^2 \right]
이것은 노이즈가 없는 score matching에 규제항을 더한 것으로 근사할 수 있다. 이 규제는 score matching을 더 넓은 범위의 데이터 분포에 적용할 수 있게 하지만 2차 도함수가 비싸다. 이에 대한 우하하고 확장 가능한 해가 다음과 같이 제안 되었다.
DF(q(x~)pθ(x~))=Eq(x~)[12x~logpθ(x~)x~logq(x~)22]=Eq(x,x~)[12x~logpθ(x~)x~logq(x~x)22]+const=12Eq(x,x~)[sθ(x~)(xx~)σ222]\begin{aligned} D_F(q(\tilde{\bold{x}})\|p_{\boldsymbol{\theta}}(\tilde{\bold{x}})) &= \mathbb{E}_{q(\tilde{\bold{x}})}\left[{1\over2}\|\nabla_{\tilde{\bold{x}}} \log p_{\boldsymbol{\theta}}(\tilde{\bold{x}}) - \nabla_{\tilde{\bold{x}}} \log q(\tilde{\bold{x}})\|_2^2 \right] \\ &= \mathbb{E}_{q(\bold{x},\tilde{\bold{x}})}\left[{1\over2}\|\nabla_{\tilde{\bold{x}}}\log p_{\boldsymbol{\theta}}(\tilde{\bold{x}}) - \nabla_{\tilde{\bold{x}}} \log q(\tilde{\bold{x}}|\bold{x})\|_2^2 \right] + \text{const} \\ &={1\over2}\mathbb{E}_{q(\bold{x},\tilde{\bold{x}})}\left[\left\|s_{\boldsymbol{\theta}}(\tilde{\bold{x}}) - {(\bold{x}-\tilde{\bold{x}}) \over \sigma^2} \right\|_2^2 \right] \end{aligned}
여기서 sθ(x~)=x~logpθ(x~)s_{\boldsymbol{\theta}}(\tilde{\bold{x}}) = \nabla_{\tilde{\bold{x}}} \log p_{\boldsymbol{\theta}}(\tilde{\bold{x}})는 추정된 score 함수이고
xlogq(x~x)=xlogN(x~x,σ2I)=(x~x)σ2+const\nabla_\bold{x} \log q(\tilde{\bold{x}}|\bold{x}) = \nabla_\bold{x} \log \mathcal{N}(\tilde{\bold{x}}|\bold{x},\sigma^2\bold{I}) = {-(\tilde{\bold{x}}-\bold{x}) \over \sigma^2} + \text{const}
directional 항 xx~\bold{x} -\tilde{\bold{x}}는 노이즈 입력에서 깨끗한 입력으로 이동하는 것에 해당하고 score 함수가 denoising 연산을 근사하기를 원한다. 이 아이디어는 diffusion model로 이어진다.
12Eq(x,x~)[sθ(x~)(xx~)σ222]{1\over2}\mathbb{E}_{q(\bold{x},\tilde{\bold{x}})}\left[\left\|s_{\boldsymbol{\theta}}(\tilde{\bold{x}}) - {(\bold{x}-\tilde{\bold{x}}) \over \sigma^2} \right\|_2^2 \right] 를 계산하기 위해 pD(x)p_\mathcal{D}(\bold{x})에서 샘플하고 노이즈 항 x~\tilde{\bold{x}}를 샘플할 수 있다. (상수 항은 최적화에 영향을 주지 않으며 최적화 해를 변경하지 않고 무시할 수 있다.) 이 추정 방법을 denoising score matching(DSM)이라고 한다.

Sliced Score Matching(SSM)

데이터에 노이즈를 추가하여 DSM은 2차 도함수의 값비싼 계산을 피할 수 있지만 DSM 목적을 최소화하는 최적 EBM은 원래의 노이즈 없는 데이터 분포 pD(x)p_\mathcal{D}(\bold{x})가 아니라 노이즈-섭동된 데이터 q(x~)q(\tilde{\bold{x}})의 분포에 해당한다. 다시 말해 DSM은 데이터 분포의 일관성 있는 추정기를 제공하지 않는다. 즉 무한한 데이터를 가져도 데이터 분포와 정확하게 일치하는 EBM을 직접적으로 얻을 수 없다.
Sliced score matching(SSM)은 Denoising score matching의 대안 중 하나로 일관성 있고 계산적으로 효율적이다. 두 벡터값 점수 사이의 Fisher 다이버전스를 최소화하는 대신 SSM은 랜덤으로 투영 벡터 v\bold{v}를 샘플한 다음 v\bold{v}와 두 score 사이의 내적을 취하고 두 스칼라 결과를 비교한다. 더 구체적으로 sliced score matching은 sliced Fisher divergence라고 부르는 다음의 다이버전스를 최소화한다.
DSF(pD(x)pθ(x))=EpD(x)Ep(v)[12(vxlogpD(x)vxlogpθ(x))2]D_\text{SF}(p_\mathcal{D}(\bold{x})\|p_{\boldsymbol{\theta}}(\bold{x})) = \mathbb{E}_{p_\mathcal{D}(\bold{x})}\mathbb{E}_{p(\bold{v})}\left[{1\over2}(\bold{v}^\top \nabla_\bold{x} \log p_\mathcal{D}(\bold{x}) - \bold{v}^\top \nabla_\bold{x} \log p_{\boldsymbol{\theta}}(\bold{x}))^2 \right]
여기서 p(v)p(\bold{v})Ep(v)[vv]\mathbb{E}_{p(\bold{v})}[\bold{vv}^\top]가 양의 정부호인 투영 분포를 표기한다.
Fisher 다이버전스와 유사하게 sliced Fisher 다이버전스는 알려지지 않은 xlogpD(x)\nabla_\bold{x} \log p_\mathcal{D}(\bold{x})을 포함하지 않는 암시적 형식을 갖는다.
DSF(pD(x)pθ(x))=EpD(x)Ep(v)[12i=1d(Eθ(x)xivi)2+i=1dj=1d2Eθ(x)xixjvivj]+CD_\text{SF}(p_\mathcal{D}(\bold{x})\|p_{\boldsymbol{\theta}}(\bold{x})) = \mathbb{E}_{p_\mathcal{D}(\bold{x})}\mathbb{E}_{p(\bold{v})} \left[{1\over2} \sum_{i=1}^d \left( {\partial \mathcal{E}_{\boldsymbol{\theta}}(\bold{x}) \over \partial x_i} v_i \right)^2 + \sum_{i=1}^d \sum_{j=1}^d {\partial^2 \mathcal{E}_{\boldsymbol{\theta}}(\bold{x}) \over \partial x_i \partial x_j} v_iv_j \right] + C
위의 목적에서 모든 기대는 경험적 평균을 사용하여 추정될 수 있고 상수 항 CC는 학습에 영향을 주지 않으므로 제거할 수 있다. 두 번째 항은 Eθ(x)\mathcal{E}_{\boldsymbol{\theta}}(\bold{x})의 2차 도함수를 포함하지만 SM에서와 달리 다음과 같이 계산을 묶어서 할 수 있기 때문에 차원 dd에서 선형 비용으로 효율적으로 계산될 수 있다.
i=1dj=1d2Eθ(x)xixjvivj=i=1dxi(j=1dEθ(x)xjvj):=f(x)vi\sum_{i=1}^d \sum_{j=1}^d {\partial^2 \mathcal{E}_{\boldsymbol{\theta}}(\bold{x}) \over \partial x_i \partial x_j} v_iv_j = \sum_{i=1}^d {\partial \over \partial x_i} \underbrace{\left(\sum_{j=1}^d{\partial \mathcal{E}_{\boldsymbol{\theta}}(\bold{x}) \over \partial x_j}v_j \right)}_{:=f(\bold{x})}v_i
여기서 f(x)f(\bold{x})ii의 서로 다른 값에 대해 동일하다. 그러므로 위 방정식을 평가하기 위해 O(d)O(d) 계산에 더해 바깥 합에 대한 또 다른 O(d)O(d) 계산을 한 번만 계산하면 된다. 반면에 원래의 SM 목적은 O(d2)O(d^2) 계산이 필요하다.

Sample Code

Model

Model은 간단히 구현한다. 신경망의 모델은 하나의 함수로 볼 수 있으며 이 경우 에너지 함수에 해당한다.
# EBM 모델 정의. 이것을 energy 함수라고 할 수 있다. class EBM(nn.Module): def __init__(self): super(EBM, self).__init__() self.model = nn.Sequential( nn.Linear(28*28, 128), nn.ReLU(), nn.Linear(128, 1), ) def forward(self, x): return self.model(x.view(x.size(0), -1))
Python
복사

Objective

SM은 에너지 함수(모델)의 결과에 대한 log gradient와 실제 데이터 분포에 대한 log gradient의 차이를 최소화하는 방식으로 구현된다. 다만 실제 데이터의 log gradient를 직접 구할 수 없기 때문에 아래와 같은 방법으로 근사로 구현한다.
def score_matching_loss(ebm, x): x.requires_grad = True # 에너지 값을 계산 energy = ebm(x) # 에너지 값에 대한 그라디언트 (score 함수) energy_grad = grad(energy.sum(), x, create_graph=True)[0] # 에너지 그라디언트의 각 컴포넌트에 대해 다시 그라디언트를 계산하여 다이버전스 계산 jacobians = [] for i in range(x.shape[1]): # x의 각 feature에 대하여 grad_component = grad(energy_grad[:, i].sum(), x, create_graph=True)[0] jacobians.append(grad_component[:, i]) # 자코비안의 대각성분을 추출 divergence = torch.stack(jacobians).sum(dim=0) # 자코비안 대각성분의 합 (트레이스) # Score matching loss 계산 loss = (0.5 * energy_grad.pow(2).sum(dim=1) + divergence).mean() return loss
Python
복사

Train

모델 학습은 다음과 같이 수행한다.
import torch import torch.nn as nn from torch.autograd import grad from torch.optim import Adam from torchvision.datasets import MNIST from torchvision import transforms from torch.utils.data import DataLoader import numpy as np import matplotlib.pyplot as plt # gpu를 사용하는 경우 gpu에서 처리 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # 데이터 로더 설정 def get_data_loader(batch_size=128): transform = transforms.Compose([ transforms.ToTensor(), # 이미지를 PyTorch 텐서로 변환 transforms.Normalize((0.5,), (0.5,)) # 정규화: [0,1] -> [-1,1] ]) train_dataset = MNIST(root='./data', train=True, transform=transform, download=True) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) return train_loader data_loader = get_data_loader() # 모델 및 최적화 설정 ebm_sm = EBM().to(device) optim_sm = torch.optim.Adam(ebm_sm.parameters(), lr=0.001) data_loader = get_data_loader() # 학습 루프 epochs = 10 for epoch in range(epochs): for data, _ in data_loader: data = data.view(data.size(0), -1).to(device) # score matching을 통해 손실을 구하고 역전파한다. loss = score_matching_loss(ebm_sm, data) optim_sm.zero_grad() loss.backward() optim_sm.step() print(f"Epoch {epoch+1}, Loss: {loss.item()}")
Python
복사

참고