Search
Duplicate

AI/ Score Based Generative Model(SGM)

Score Based Generative Model(SGM)

EBM의 Score Matching 방법은 모델의 score 함수를 데이터의 score 함수에 일치시키는 것이었다. 스칼라 에너지 함수를 맞추고 score를 계산하는 대신 score를 직접 학습할 수 있는데, 이것을 score-based generative model(SGM)이라 한다.
basic score matching, denoising score matching, sliced score matching 등을 사용하여 score 함수 sθ(x)s_{\boldsymbol{\theta}}(\bold{x})를 최적화 할 수 있다.

Multiple Scale Noise

일반적으로 데이터의 low density region이 존재하면 score matching 방법이 어렵다. 이것은 score matching이 Fisher Divergence를 최소화하기 때문으로 여겨진다.
Ep(x)[xlogp(x)sθ(x)22]=p(x)xlogp(x)sθ(x)22dx.\mathbb{E}_{p(\mathbf{x})}[\| \nabla_\mathbf{x} \log p(\mathbf{x}) - \mathbf{s}_\theta(\mathbf{x}) \|_2^2] = \int p(\mathbf{x}) \| \nabla_\mathbf{x} \log p(\mathbf{x}) - \mathbf{s}_\theta(\mathbf{x}) \|_2^2 \mathrm{d}\mathbf{x}.
실제 데이터 score 함수와 score-based 모델간의 L2 차이는 p(x)p(x)로 가중치를 부여하므로, p(x)p(x)가 작은 low density region에서는 대체로 무시된다. 이는 수준 이하의 결과를 초래할 수 있다.
Langevin 다이나믹스로 샘플링할 때 데이터가 고차원에 있으면 초기 샘플은 low density region에 있을 가능성이 높고 부정확한 score-based 모델을 사용하면 프로세스 초기부터 랑주뱅 다이나믹스를 탈선시켜 고품질의 샘플을 생성하는 것을 막는다.
이러한 문제를 해결하기 위해 데이터를 노이즈로 교란시키고 노이즈가 있는 데이터 포인트에서 score-based 모델을 학습하는 방법을 사용할 수 있다. 노이즈의 크기가 충분히 크면 low density region을 채워서 score 추정의 정확도를 높일 수 있다.
그러나 큰 노이즈는 데이터를 과도하게 손상시켜서 원래의 분포를 크게 바꿀 수 있다. 반면 작은 노이즈 분포는 원본 분포를 크게 바꾸지 않지만 low density region을 충분히 커버할 수 없다.
이 문제를 해결하기 위해 Song과 Ermon은 학습 데이터를 다양한 스케일의 노이즈로 교란하여 문제를 해결했다. 구체적으로 다음을 사용한다.
qσ(x~x)=N(x~x,σ2I)qσ(x~)=pD(x)qσ(x~x)dx\begin{aligned} q_\sigma(\tilde{\bold{x}}|\bold{x}) &= \mathcal{N}(\tilde{\bold{x}}|\bold{x},\sigma^2\bold{I}) \\ q_\sigma(\tilde{\bold{x}}) &= \int p_\mathcal{D}(\bold{x})q_\sigma(\tilde{\bold{x}}|\bold{x})d\bold{x} \end{aligned}
Annealed Lagevin Dynamics나 Diffusion Sampling 등을 사용하면 우선 가장 노이즈 교란된 분포에서 샘플링한 다음 가장 작은 분포에 도달할 때까지 노이즈 스케일의 규모를 완만하게 감소시킬 수 있다.
실제에서 모든 score 모델은 가중치를 공유하며 노이즈 스케일에 조건화된 단일 신경망을 사용하여 구현되는데 이것을 noise conditional score network라고 하고 sθ(x,σ)\bold{s}_{\boldsymbol{\theta}}(\bold{x},\sigma) 형식을 갖는다. 서로 다른 스케일의 score는 노이즈 스케일 당 하나씩 score matching 목적을 혼합하여 학습함으로써 추정된다. denoising score matching 목적을 사용하면 다음을 얻는다.
L(θ;σ)=Eq(x,x~)[12xlogpθ(x~,σ)xlogqσ(x~x)22]=12EpD(x)Ex~N(x,σ2I){sθ(x~,σ)+(x~x)σ222}\begin{aligned} \mathcal{L}(\boldsymbol{\theta};\sigma) &= \mathbb{E}_{q(\bold{x},\tilde{\bold{x}})}\left[{1\over2} \|\nabla_{\bold{x}} \log p_{\boldsymbol{\theta}}(\tilde{\bold{x}},\sigma) - \nabla_\bold{x}\log q_\sigma(\tilde{\bold{x}}|\bold{x}) \|_2^2 \right] \\ &= {1\over2}\mathbb{E}_{p_\mathcal{D}(\bold{x})}\mathbb{E}_{\tilde{\bold{x}}\sim\mathcal{N}(\bold{x},\sigma^2\bold{I})}\left\{\left\|\bold{s}_{\boldsymbol{\theta}}(\tilde{\bold{x}},\sigma) + {(\tilde{\bold{x}}-\bold{x}) \over \sigma^2} \right\|_2^2 \right\} \end{aligned}
여기서 가우시안에 대해 score가 다음과 같이 주어진다는 사실을 이용했다.
xlogN(x~x,σ2I)=x12σ2(xx~)(xx~)=xx~σ2\nabla_\bold{x} \log \mathcal{N}(\tilde{\bold{x}}|\bold{x},\sigma^2\bold{I}) = -\nabla_\bold{x}{1\over 2\sigma^2}(\bold{x}-\tilde{\bold{x}})^\top(\bold{x}-\tilde{\bold{x}}) = {\bold{x} - \tilde{\bold{x}} \over \sigma^2}
TT개 다른 노이즈 스케일을 가지면 다음을 사용하여 가중치 방식으로 손실을 결합할 수 있다.
L(θ;σ1:T)=t=1TλtL(θ;σt)\mathcal{L}(\boldsymbol{\theta};\sigma_{1:T}) = \sum_{t=1}^T \lambda_t \mathcal{L}(\boldsymbol{\theta};\sigma_t)
여기서 σ1>σ2>...>σT\sigma_1 > \sigma_2 > ... > \sigma_T를 선택하고 가중치 항은 λt>0\lambda_t > 0을 만족한다.

DDPM

앞선 score 기반 생성 모델이 DDPM과 동등함을 보일 수 있다. 이것을 보기 위해 우선 pD(x)p_\mathcal{D}(\bold{x})q0(x0)q_0(\bold{x}_0)x~\tilde{\bold{x}}xt\bold{x}_t로, sθ(x~,σ)\bold{s}_{\boldsymbol{\theta}}(\tilde{\bold{x}},\sigma)sθ(xt,σ)\bold{s}_{\boldsymbol{\theta}}(\bold{x}_t,\sigma)로 교체한다. 또한 랜덤으로 균등하게 시간 단계를 샘플링하여 L(θ;σ1:T)=t=1TλtL(θ;σt)\mathcal{L}(\boldsymbol{\theta};\sigma_{1:T}) = \sum_{t=1}^T \lambda_t \mathcal{L}(\boldsymbol{\theta};\sigma_t) 에 대한 확률적 근사를 계산한다. 그러면 해당 방정식 다음이 된다.
L=Ex0q0(x0),xtq(xtx0),tUnif(1,T)[λtsθ(xt,t)+(xtx0)σt222]\mathcal{L} = \mathbb{E}_{\bold{x}_0 \sim q_0(\bold{x}_0),\bold{x}_t\sim q(\bold{x}_t|\bold{x}_0),t \sim \text{Unif}(1,T)}\left[\lambda_t \left\|\bold{s}_{\boldsymbol{\theta}}(\bold{x}_t,t) + {(\bold{x}_t - \bold{x}_0) \over \sigma_t^2} \right\|_2^2 \right]
xt=x0+σtϵ\bold{x}_t = \bold{x}_0 + \sigma_t\boldsymbol{\epsilon}이라는 사실을 이용하고 sθ(xt,t)=ϵθ(xt,t)σt\bold{s}_{\boldsymbol{\theta}}(\bold{x}_t,t) = -{\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\bold{x}_t,t) \over \sigma_t}를 정의하면 이것을 다음처럼 작성할 수 있다.
L=Ex0q0(x0),ϵN(0,I),tUnif(1,T)[λtσt2ϵϵθ(xt,t)22]\mathcal{L} = \mathbb{E}_{\bold{x}_0 \sim q_0(\bold{x}_0),\boldsymbol{\epsilon} \sim \mathcal{N}(\bold{0},\bold{I}),t \sim \text{Unif}(1,T)}\left[{\lambda_t\over \sigma_t^2} \left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\bold{x}_t,t)\right\|_2^2 \right]
λT=σt2\lambda_T = \sigma_t^2를 설정하면 DDPM의 Lsimple\mathcal{L}_\text{simple} 손실을 복구한다.

참고