Search
Duplicate

AI/ Denoising Diffusion Implicit Model(DDIM)

Denoising Diffusion Implicit Model(DDIM)

DDIM은 샘플링을 결정론적으로 바꾸고 샘플링 단계 수를 줄여서 시간이 오래 걸리는 DDPM에 비해 속도를 높인 방법이다. 대신 DDPM에 비해 샘플의 다양성이 떨어진다는 단점이 존재한다. DDIM도 DDPM과 마찬가지로 모델이 noise를 예측하도록 학습된다.
이를 위한 간단한 방법 중 하나는 [T/S][T/S] 단계마다 샘플링 업데이트를 수행하여 프로세스를 TT 단계에서 SS 단계로 줄이는 점진적 샘플링 일정을 실행하는 것이다. 생성에 대한 새로운 샘플링 스케줄은 {τ1,,τS}\{\tau_1, \dots, \tau_S\}이며 여기서 τ1<τ2<<τS[1,T]\tau_1 < \tau_2 < \dots <\tau_S \in [1, T]S<TS < T .
원하는 표준편차 σt\sigma_t로 매개변수화되도록 qσ(xt1xt,x0)q_{\sigma}(x_{t-1}|x_t, x_0)를 다음과 같이 다시 작성할 수 있다. xt=αˉtx0+1αˉtϵt\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\boldsymbol{\epsilon}_t으로부터
xt1=αˉt1x0+1αˉt1ϵt1=αˉt1x0+1αˉt1σt2ϵt+σtϵ=αˉt1x0+1αˉt1σt2xtαˉtx01αˉt+σtϵ\begin{aligned}\mathbf{x}_{t-1} &= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1}}\boldsymbol{\epsilon}_{t-1} \\&= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \boldsymbol{\epsilon}_t + \sigma_t\boldsymbol{\epsilon} \\&= \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\mathbf{x}_0}{\sqrt{1 - \bar{\alpha}_t}} + \sigma_t\boldsymbol{\epsilon} \end{aligned}
(첫 번째 줄의 1αˉt1ϵt1 \sqrt{1 - \bar{\alpha}_{t-1}}\boldsymbol{\epsilon}_{t-1}1αˉt1σt2ϵt+σtϵ\sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \boldsymbol{\epsilon}_t + \sigma_t\boldsymbol{\epsilon}로 바뀌는 것은 다소 이해하기 어려울 수 있는데, 두 가우시안 N(0,σ12,I)\mathcal{N}(0, \sigma_1^2, \bold{I})N(0,σ22,I)\mathcal{N}(0, \sigma_2^2, \bold{I})을 병합하면 새로운 분포는 N(0,(σ12+σ22)I)\mathcal{N}(0, (\sigma_1^2 + \sigma_2^2)\bold{I})가 된다는 것을 거꾸로 생각하면 된다. 1αˉt1σt2ϵt+σtϵ\sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \boldsymbol{\epsilon}_t + \sigma_t\boldsymbol{\epsilon}의 두 계수를 각각 제곱한 후 합하면 1αˉt1σt2+σt21 - \bar{\alpha}_{t-1} - \sigma_t^2 + \sigma_t^2이 되고 여기에 제곱근을 씌우면 1αˉt1ϵt1 \sqrt{1 - \bar{\alpha}_{t-1}}\boldsymbol{\epsilon}_{t-1}이 된다.)
따라서
q(xt1xt,x0)=N(αˉt1x0+1αˉt1σ~t2xtαˉtx01αˉt,σ~t2I)q(\bold{x}_{t-1}|\bold{x}_t,\bold{x}_0) = \mathcal{N}(\sqrt{\bar{\alpha}_{t-1}}\bold{x}_0 + \sqrt{1-\bar{\alpha}_{t-1}-\tilde{\sigma}_t^2}{\bold{x}_t -\sqrt{\bar{\alpha}_t}\bold{x}_0 \over \sqrt{1-\bar{\alpha}_t}}, \tilde{\sigma}_t^2\bold{I})
해당하는 reverse 프로세스는 다음과 같이 정의된다.
pθ(xt1xt)=N(αˉt1x^0+1αˉt1σ~t2xtαˉtx^01αˉt,σ~t2I)p_{\boldsymbol{\theta}}(\bold{x}_{t-1}|\bold{x}_t) = \mathcal{N}(\sqrt{\bar{\alpha}_{t-1}}\hat{\bold{x}}_0 + \sqrt{1-\bar{\alpha}_{t-1}-\tilde{\sigma}_t^2}{\bold{x}_t - \sqrt{\bar{\alpha}_t}\hat{\bold{x}}_0 \over \sqrt{1-\bar{\alpha}_t}},\tilde{\sigma}_t^2\bold{I})
여기서 x^0=x^θ(xt,t)\hat{\bold{x}}_0 = \hat{\bold{x}}_{\boldsymbol{\theta}}(\bold{x}_t,t)는 모델에서 예측된 출력이다. σ~t2=0\tilde{\sigma}_t^2 = 0을 설정하여 초기 prior 샘플(분산이 σ~T2\tilde{\sigma}_T^2에 의해 제어됨)이 주어지면 reverse 프로세스는 완전히 결정론적이 된다. 이 경우 모델이 학습을 마치면 원하는 시점 tt에 대한 샘플을 Markov Chain을 따르지 않고 계산을 통해 샘플을 생성할 수 있다는 뜻이다. DDIM은 SDE와 ODE의 장점을 결합하여 빠른 계산 속도와 우수한 샘플 품질을 보장한다. 이 경우 noise를 추가하는 forward process가 의미가 없어지기 때문에 reverse process 만으로 모델을 학습할 수 있다.
DDIM은 DDPM과 동일한 marginal noise 분포를 갖지만 노이즈를 원본 데이터 샘플에 결정론적으로 다시 매핑한다. 이로 인해 높은 샘플 품질과 함께 빠른 계산 속도를 얻을 수 있다. 일반적으로 diffusion step이 작은 경우에는 DDIM의 품질이 더 낫고, diffusion step이 많은 경우에는 DDPM의 품질이 더 낫다.
DDPM과 비교하여 DDIM은 다음이 가능하다.
1.
훨씬 적은 수의 단계로 더 높은 품질의 샘플을 생성할 수 있다.
2.
생성 프로세스가 결정론적이기 때문에 '일관성' 속성을 가지며, 이는 동일한 잠재 변수에 조건화 된 여러 샘플이 유사한 상위 수준의 특징을 가져야 함을 의미한다.
3.
일관성 때문에 DDIM은 잠재 변수에서 semantically 의미 있는 보간을 수행할 수 있다.
이 모델의 weighted negative VLB는 DDPM의 Lsimple\mathcal{L}_\text{simple}과 같음에 유의하라. 따라서 DDIM 샘플러는 학습된 DDPM 모델에 적용될 수 있다.

Sample Code

DDIM은 forward, reverse 프로세스에서 계산에 차이가 있을 뿐 전체적인 흐름은 DDPM과 유사하다.
DDIM의 forward 단계 식을 더 간단하게 변환하여 아래처럼 구현한다. 만일 분산을 0으로 설정하면 forward 단계를 생략하고 reverse 단계만으로 학습을 수행할 수 있다.
def ddim_q_sample(x_0, noise=None): if noise is None: noise = torch.randn_like(x_0) x_t = torch.sqrt(alphas_cumprod[-1]) * x_0 + torch.sqrt(1.0 - alphas_cumprod[-1]) * noise return x_t
Python
복사
DDIM의 reverse 단계 식을 더 간단하게 변환하여 아래처럼 구현한다.
def ddim_p_sample(model, x_t, t, sigma_t): hat_x_0 = model(x_t, t) sqrt_alpha_cumprod_prev = torch.sqrt(alphas_cumprod_prev[t-1]) sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[t]) sqrt_recip_alpha_cumprod = torch.sqrt(1 / alphas_cumprod[t]) mean = (sqrt_alpha_cumprod_prev * hat_x_0 + sqrt_one_minus_alpha_cumprod * (x_t - sqrt_recip_alpha_cumprod * hat_x_0)) noise = torch.randn_like(x_t) if sigma_t > 0 else 0 x_prev = mean + sigma_t * noise return x_prev
Python
복사
DDIM에서 학습은 DDPM과 마찬가지로 모델이 예측한 noise와 실제 noise 사이의 차이를 이용한다.
x_noisy = ddim_q_sample(x_0, noise) predicted_noise = model(x_noisy, timesteps) loss = loss_function(predicted_noise, noise) loss.backward() optimizer.step()
Python
복사

참고