Search
Duplicate

AI/ Progressive Distillation

Progressive Distillation

DDPM과 DDIM을 함께 사용하여 Distaillation 할 수 있다. 이것은 작은 단계로 높은 품질의 샘플을 생성할 수 있고 progressive distillation 이라 한다. 기본 아이디어는 다음과 같다.
우선 일반적인 방법으로 DDPM 모델을 학습하고 DDIM 방법을 사용하여 샘플한다. 이것을 ‘teacher’ 모델로 취급한다. 이것을 사용하여 중간 잠재 장태를 생성하고 아래 그림에 나오는대로 ‘student’ 모델을 매 두 번째 단계에서 교사의 출력을 예측하도록 학습한다.
학생이 학습된 후에 교사만큼 좋은 결과를 생성할 수 있지만 단계는 절반으로 줄어든다. 이 학생은 다시 교사가 되어 더 빠른 세대의 학생을 생성할 수 있다. pseudocode에 대해 아래 알고리즘 참조.
교사가 더 작아짐에 따라 teaching의 각 단계는 더 빨라짐에 유의하라. 따라서 distillation을 수행하는 전체 시간은 상대적으로 작아진다. 결과 모델은 4단계만으로 높은 품질의 샘플을 생성할 수 있다.
Algorithm: Progressive distillation
1.
입력: 학습된 teacher 모델 x^η(zt)\hat{\bold{x}}_{\boldsymbol{\eta}}(\bold{z}_t)
2.
입력: 데이터셋 D\mathcal{D}
3.
입력: 손실 가중치 함수 ww
4.
입력: student 샘플링 단계 NN
5.
foreach KK 반복 do
a.
θ:=η\boldsymbol{\theta} := \boldsymbol{\eta} (student 할당)
b.
while 수렴하지 않으면 do
i.
xD\bold{x} \sim \mathcal{D}
ii.
t=i/N,iCat(1,2,..,N)t = i/N,i \sim \text{Cat}(1,2,..,N)
iii.
ϵN(0,I)\boldsymbol{\epsilon} \sim \mathcal{N}(\bold{0},\bold{I})
iv.
zt=αtx+σtϵ\bold{z}_t = \alpha_t\bold{x} + \sigma_t \boldsymbol{\epsilon}
v.
t=t0.5/N,t=t1/Nt' = t-0.5/N, t'' = t-1/N
vi.
zt=αtx^η(zt)+σtσt(ztαtx^η(zt))\bold{z}_{t'} = \alpha_{t'}\hat{\bold{x}}_{\boldsymbol{\eta}}(\bold{z}_t) + {\sigma_{t'} \over \sigma_t}(\bold{z}_t - \alpha_t \hat{\bold{x}}_{\boldsymbol{\eta}}(\bold{z}_t))
vii.
zt=αtx^η(zt)+σtσt(ztαtx^η(zt))\bold{z}_{t''} = \alpha_{t''}\hat{\bold{x}}_{\boldsymbol{\eta}}(\bold{z}_{t'}) + {\sigma_{t''} \over \sigma_{t'}}(\bold{z}_{t'}-\alpha_{t'}\hat{\bold{x}}_{\boldsymbol{\eta}}(\bold{z}_{t'}))
viii.
x~=zt(σt/σt)ztαt(σt/σt)αt\tilde{\bold{x}} = {\bold{z}_{t''} - (\sigma_{t''}/\sigma_t)\bold{z}_t \over \alpha_{t''} - (\sigma_{t''}/\sigma_t)\alpha_t} (teacher가 타겟)
ix.
λt=log(αt2/σt2)\lambda_t = \log (\alpha_t^2/\sigma_t^2)
x.
Lθ=w(λt)x~x^θ(zt)22\mathcal{L}_{\boldsymbol{\theta}} = w(\lambda_t) \|\tilde{\bold{x}} - \hat{\bold{x}}_{\boldsymbol{\theta}}(\bold{z}_t)\|_2^2
xi.
θ:=θγθLθ\boldsymbol{\theta} := \boldsymbol{\theta}-\gamma\nabla_{\boldsymbol{\theta}}\mathcal{L}_{\boldsymbol{\theta}}
c.
η:=θ\boldsymbol{\eta} : -=\boldsymbol{\theta} (student는 새로운 teacher가 된다.)
d.
N:=N/2N := N/2 (샘플링 단계를 절반으로 줄인다.)

참고