Search
Duplicate

AI/ Paper/ Flash Attention/ Appendix

B Algorithm Details

우리는 우선 attention의 forward와 backward pass를 유도하고 이들이 메모리 효율적인 방법(시퀀스 길이에 따라 2차 대신 선형적인 추가 메모리 요구)에서 계산될 수 있음을 보인다. 비록 추가 메모리 요구의 양이 줄지만 단순히 구현하면 여전히 2차적 HBM 접근이 발생하여 결과적으로 느린 실행 속도를 갖는다. 우리는 GPU에서 HBM 접근을 줄이기 위해 forward와 backward pass 모두를 구현하는 FlashAttention 알고리즘을 설명한다. 이것은 더 따른 실행 속도와 더 적은 메모리 사용량을 이끈다.

B.1 Memory-efficient forward pass

attention을 메모리 효율적으로 만드는 주요 도전은 K\bold{K}의 컬럼(과 V\bold{V}의 컬럼)을 커플링하는 softmax이다. 우리의 접근은 softmax normalization 상수를 별도로 계산하여 컬럼을 decouple 한다. 이 기법은 몇 문헌에서 attention 계산에 2차적 추가 메모리가 필요하지 않음을 보이기 위해 사용되었다(그러나 HBM 접근 수가 여전히 2차적이어서 실행은 느리다.)
단순성을 위해 softmax 중 max-shifting 단계를 생략한다. 전체 알고리즘은 모든 단계를 포함하는 부록 B.3 참조.
입력 시퀀스 Q,K,VRN×d\bold{Q}, \bold{K}, \bold{V} \in \mathbb{R}^{N \times d}가 주어질 때, attention 출력 ORN×d\bold{O} \in \mathbb{R}^{N \times d}를 계산하기를 원한다는 것을 떠올려라.
S=QKRN×NP=softmax(S)RN×NO=PVRN×d\begin{aligned} \bold{S} &= \bold{QK}^\top \in \mathbb{R}^{N \times N} \\ \bold{P} &= \text{softmax}(\bold{S}) \in \mathbb{R}^{N \times N} \\ \bold{O} &= \bold{PV} \in \mathbb{R}^{N \times d} \end{aligned}
Sij=qikjS_{ij} = q_i^\top k_j를 갖는다. 여기서 qiq_ikjk_j는 각각 Q\bold{Q}K\bold{K}ii-번째 행과 jj-번째 열이다. softmax의 정규화 상수를 다음과 같이 정의한다.
Li=jeqikj(1)L_i = \sum_j e^{q_i^\top k_j} \tag{1}
vjv_jV\bold{V}jj-번째 열이라고 하자. 그러면 출력의 ii-번째 열은 다음과 같다.
oi=Pi:V=jPijvj=jeqikjLivj(2)o_i = P_{i:}\bold{V} = \sum_j P_{ij}v_j = \sum_j {e^{q_i^\top k_j} \over L_i}v_j \tag{2}
LiL_i가 한 번 계산되면 추가 메모리 없이 eqikjLivj{e^{q_i^\top k_j} \over L_i}v_j를 반복적으로 합산하여 oio_i를 계산할 수 있음을 볼 수 있다. 그러므로 forward 패스는 O(n)O(n) 추가 메모리를 사용하여 계산될 수 있다.
1.
방정식 (1)을 따라 모든 ii에 대해 LiL_i을 계산한다. 이것은 O(n)O(n)의 추가 메모리를 취한다.
2.
방정식 (2)를 따라 모든 ii에 대해 oio_i를 계산한다. 이것은 O(d)O(d)의 추가 메모리를 취한다.

B.2 Memory-efficient backward pass

우리는 attention의 backward pass를 유도하고 이것 또한 선형 메모리로 계산될 수 있음을 보인다. Rabe와 Staats는 메모리 효율적인 forward pass에 gradient checkpointing을 적용하여 backward pass가 2차적 추가 메모리 없이 계산될 수 있다고 제안했다. 우리는 대신 backward pass를 명시적으로 유도하고 메모리 효율적인 방법으로 계산할 수 있는 방법을 보인다.
스칼라 손실 함수 ϕ\phi가 있다고 하고 출력 gradient가 dORn×d\bold{dO} \in \mathbb{R}^{n \times d}라고 하자. (여기서 dO\bold{dO}ϕO{\partial \phi \over \partial \bold{O}}를 표기한다.) 우리는 입력 gradient dQ,dK,dVRn×d\bold{dQ}, \bold{dK}, \bold{dV} \in \mathbb{R}^{n \times d}를 계산하기 원한다. (여기서 dQ,dK,dV\bold{dQ}, \bold{dK}, \bold{dV}는 각각 ϕQ,ϕK,ϕV{\partial \phi \over \partial \bold{Q}}, {\partial \phi \over \partial \bold{K}}, {\partial \phi \over \partial \bold{V}}를 표기한다.)
gradient dV\bold{dV}는 쉽게 볼 수 있다. 역방향 자동미분을 수동으로(chain rule 이라고도 함) 적용하면 (행렬표기에서) dV=PdO\bold{dV} = \bold{P}^\top \bold{dO}를 얻는다. 따라서
임의의 행렬곱 AB=C\bold{AB} = \bold{C}에 대해 gradient는 각각 dA=LA=LCCA=dCB\bold{dA} = {\partial L \over \partial \bold{A}} = {\partial L \over \partial \bold{C}}{\partial \bold{C} \over \partial \bold{A}} = \bold{dC}\bold{B}^\topdB=LB=LCCB=AdC\bold{dB} = {\partial L \over \partial \bold{B}} = {\partial L \over \partial \bold{C}}{\partial \bold{C} \over \partial \bold{B}} = \bold{A}^\top \bold{dC}으로 주어진다. 즉, 역행렬이 아니라 전치행렬로 gradient를 구할 수 있다.
dvj=iPijdoi=ieqikjLidoi(3)dv_j = \sum_i P_{ij}do_i = \sum_i {e^{q_i^\top k_j} \over L_i} do_i \tag{3}
LiL_i가 이미 계산되었기 때문에 dvjdv_j를 합산을 반복하여 추가 메모리 없이 계산할 수 있다.
gradient dQ\bold{dQ}dK\bold{dK}는 약간 더 복잡하다. 우선 gradient dP\bold{dP}dS\bold{dS} 를 살펴보자. 방정식 (2)에서 dP=dOV\bold{dP} = \bold{dOV}^\top이 성립한다. 따라서
dPij=doivjdP_{ij} = do_i^\top v_j
Pi:=softmax(Si:)P_{i:} = \text{softmax}(S_{i:})임을 떠올려라. y=softmax(x)y = \text{softmax}(x)의 야코비안이 diag(y)yy\text{diag}(y) - yy^\top라는 점을 사용하여 다음을 갖는다.
dSi:=(diag(Pi:)Pi:Pi:)dPi:=Pi:dPi:(Pi:dPi:)Pi:dS_{i:} = (\text{diag}(P_{i:})-P_{i:}P_{i:}^\top)dP_{i:} = P_{i:}\circ dP_{i:} - (P_{i:}^\top dP_{i:})P_{i:}
여기서 \circ는 pointwise 곱을 표기한다.
P=softmax(S)\bold{P} = \text{softmax}(\bold{S})이므로 S\bold{S}의 gradient를 구하려면 softmax의 야코비안 diag(P)PP\text{diag}(\bold{P})-\bold{P}\bold{P}^\topP\bold{P}의 gradient dP=dOV\bold{dP} = \bold{dOV}^\top를 모두 구해서 곱해야 한다.
다음을 정의한다.
Di=Pi:dPi:=jeqikjLidoivj=doijeqikjLivj=doioi(4)D_i = P_{i:}^\top dP_{i:} = \sum_j {e^{q_i^\top k_j} \over L_i}do_i^\top v_j = do_i^\top \sum_j {e^{q_i^\top k_j} \over L_i}v_j = do_i^\top o_i \tag{4}
그러면
dSi:=Pi:dPi:DiPi:dS_{i:} = P_{i:} \circ dP_{i:} -D_iP_{i:}
그러므로
dSij=PijdPijDiPij=Pij(dPijDi)dS_{ij} = P_{ij}dP_{ij} - D_iP_{ij} = P_{ij}(dP_{ij} - D_i)
이제 dQ\bold{dQ}dK\bold{dK}의 gradient를 얻을 수 있다. Sij=qikjS_{ij} = q_i^\top k_j를 떠올려라. 따라서
QK=S\bold{QK}^\top = \bold{S}에 대해 gradient는 각각 dQ=dSK\bold{dQ} = \bold{dSK}dK=QdS\bold{dK} = \bold{Q}^\top\bold{dS}으로 주어진다.
dqi=jdSijkj=jPij(dPijDi)kj=jeqikjLi(doivjDi)kj(5)dq_i = \sum_j dS_{ij}k_j = \sum_j P_{ij}(dP_{ij}-D_i)k_j = \sum_j {e^{q_i^\top k_j} \over L_i}(do_i^\top v_j - D_i)k_j \tag{5}
유사하게
dki=jdSijqi=jPij(dPijDi)qi=ieqikjLi(doivjDi)qi(6)dk_i = \sum_j dS_{ij}q_i = \sum_j P_{ij}(dP_{ij}-D_i)q_i = \sum_i {e^{q_i^\top k_j} \over L_i}(do_i^\top v_j - D_i)q_i \tag{6}
그러므로 backward pass도 O(n)O(n) 추가 메모리로 계산될 수 있다.
1.
방정식 (3)을 따라 모든 jj에 대해 dvjdv_j를 계산한다. 이것은 O(d)O(d)의 추가 메모리를 취한다.
2.
방정식 (4)을 따라 모든 ii에 대해 DiD_i를 계산한다. 이것은 O(n)O(n)의 추가 메모리를 취한다.
3.
방정식 (5)을 따라 모든 ii에 대해 dqidq_i를 계산한다. 이것은 O(d)O(d)의 추가 메모리를 취한다.
4.
방정식 (6)을 따라 모든 jj에 대해 dkjdk_j를 계산한다. 이것은 O(d)O(d)의 추가 메모리를 취한다.

B.3 FlashAttention: Forward Pass

FlashAttention forward pass의 전체 세부사항을 설명한다. 입력 시퀀스 Q,K,VRN×d\bold{Q}, \bold{K}, \bold{V} \in \mathbb{R}^{N \times d}가 주어지면 attention 출력 ORN×d\bold{O} \in \mathbb{R}^{N \times d}을 계산하기를 원한다.
S=τQKRN×NSmasked=MASK(S)RN×NP=softmax(Smasked)RN×NPdropped=dropout(P,pdrop)O=PdroppedVRN×d\begin{aligned} \bold{S} &= \tau\bold{QK}^\top \in \mathbb{R}^{N\times N} \\ \bold{S}^\text{masked} &= \text{MASK}(S) \in \mathbb{R}^{N \times N} \\ \bold{P} &= \text{softmax}(\bold{S}^\text{masked}) \in \mathbb{R}^{N \times N} \\ \bold{P}^\text{dropped} &= \text{dropout}(\bold{P}, p_\text{drop})\\ \bold{O} &= \bold{P}^\text{dropped}\bold{V} \in \mathbb{R}^{N \times d} \end{aligned}
여기서 τR\tau \in \mathbb{R}은 어떤 softmax scaling(일반적으로 1d{1\over \sqrt{d}})이고 MASK는 입력의 일부 항목을 -\infty로 설정하고 나머지는 동일하게 두는 어떤 masking 함수이다(예: batch의 시퀀스 길이가 동일하지 않을 때 key padding mask). dropout(x,p)\text{dropout}(x,p)xx에 요소별로 dropout을 적용한다(예: 각 요소 xx에 대해 확률 1p1-p의 출력 x1p{x \over 1-p}와 확률 pp의 출력 00)
전체 알고리즘은 알고리즘 2 참조. 우리는 출력 O\bold{O}, softmax 통계량 \ellmm과 backward pass를 위한 pseudo-random number 생성기 상태 R\mathcal{R}을 저장한다.
Algorithm 2. FlashAttention Forward Pass
Require: HBM에서 행렬 Q,K,VRN×d\bold{Q}, \bold{K}, \bold{V} \in \mathbb{R}^{N\times d}, 크기 MM의 on-chip SRAM, softmax scaling 상수 τR\tau \in \mathbb{R}, masking 함수 MASK\text{MASK}, dropout 확률 pdropp_\text{drop}
1.
pseudo-random number 생성기 상태 R\mathcal{R}을 초기화하고 HBM에 저장
2.
block 크기 Bc=M4d,Br=min(M4d,d)B_c = \lceil{M \over 4d} \rceil, B_r = \min(\lceil{M \over 4d}\rceil, d) 설정
3.
HBM에 O=(0)N×dRN×d,=(0)NRN,m=()NRN\bold{O} = (0)_{N \times d} \in \mathbb{R}^{N \times d}, \ell = (0)_N \in \mathbb{R}^N, m = (-\infty)_N \in \mathbb{R}^N 초기화
4.
Q\bold{Q}를 각각 Br×dB_r \times d 크기의 Tr=NBrT_r = \lceil {N \over B_r} \rceil개 블록 Q1,...,QT\bold{Q}_1, ..., \bold{Q}_T으로 분할하고, K,V\bold{K}, \bold{V}를 각각 Bc×dB_c \times d 크기의 Tc=NBcT_c = \lceil {N \over B_c} \rceil개 블록 K1,...,KTc\bold{K}_1, ..., \bold{K}_{T_c}V1,...,VTc\bold{V}_1, ..., \bold{V}_{T_c}으로 분할
5.
O\bold{O}를 각각 Br×dB_r \times d 크기의 TrT_r개 블록 Oi,...,OTr\bold{O}_i,..., \bold{O}_{T_r}로 분할, \ell을 각각 BrB_r 크기의 TrT_r개 블록 i,...,Tr\ell_i,...,\ell_{T_r}로 분할 mm을 각각 BrB_r 크기의 TrT_r개 블록 m1,...,mTrm_1,..., m_{T_r}로 분할
6.
for 1jTc1 \le j \le T_c do
a.
Kj,Vj\bold{K}_j, \bold{V}_j를 HBM에서 on-chip SRAM으로 로드
b.
for 1iTr1 \le i \le T_r do
i.
Qi,Oi,i,mi\bold{Q}_i, \bold{O}_i, \ell_i, m_i를 HBM에서 on-chip SRAM으로 로드
ii.
On chip에서 Sij=τQiKjRBr×Bc\bold{S}_{ij} = \tau \bold{Q}_i \bold{K}_j^\top \in \mathbb{R}^{B_r \times B_c} 계산
iii.
On chip에서 Sijmasked=MASK(Sij)\bold{S}_{ij}^\text{masked} = \text{MASK}(\bold{S}_{ij}) 계산
iv.
On chip에서 다음을 계산
m~ij=rowmax(Sijmasked)RBr\tilde{m}_{ij} = \text{rowmax}(\bold{S}_{ij}^\text{masked}) \in \mathbb{R}^{B_r}
P~ij=exp(Sijmaskedm~ij)RBr×Bc\tilde{\bold{P}}_{ij} = \exp(\bold{S}_{ij}^\text{masked} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r \times B_c} (pointwise)
~ij=rowsum(P~ij)RBr\tilde{\ell}_{ij} = \text{rowsum}(\tilde{\bold{P}}_{ij}) \in \mathbb{R}^{B_r}
v.
On chip에서 다음을 계산
minew=max(mi,m~ij)RBrm_i^\text{new} = \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_r}
inew=emiminewi+em~ijminew~ijRBr\ell_i^\text{new} = e^{m_i - m_i^\text{new}}\ell_i + e^{\tilde{m}_{ij} - m_i^\text{new}}\tilde{\ell}_{ij} \in \mathbb{R}^{B_r}
vi.
On chip에서 P~ijdropped=dropout(P~ij,pdrop)\tilde{\bold{P}}_{ij}^\text{dropped} = \text{dropout}(\tilde{\bold{P}}_{ij},p_\text{drop})
vii.
HBM에 Oidiag(inew)1(diag(i)emiminewOi+em~ijminewP~ijdroppedVj)\bold{O}_i \leftarrow \text{diag}(\ell_i^\text{new})^{-1}(\text{diag}(\ell_i)e^{m_i - m_i^\text{new}}\bold{O}_i + e^{\tilde{m}_{ij} - m_i^\text{new}}\tilde{\bold{P}}_{ij}^\text{dropped}\bold{V}_j) 쓰기
viii.
HBM에 iinew,miminew\ell_i \leftarrow \ell_i^\text{new}, m_i \leftarrow m_i^\text{new} 쓰기
c.
end for
7.
end for
8.
O,,m,R\bold{O}, \ell, m, \mathcal{R} 반환

B.4 FlashAttention: Backward Pass

FlashAttention의 Backward pass의 전체 세부사항을 설명한다. 입력 시퀀스 Q,K,VRN×d\bold{Q}, \bold{K}, \bold{V} \in \mathbb{R}^{N \times d}, 출력 ORN×d\bold{O} \in \mathbb{R}^{N \times d}과 출력 gradient dO\bold{dO}가 주어지면 입력 gradient dQ,dK,dVRN×d\bold{dQ}, \bold{dK}, \bold{dV} \in \mathbb{R}^{N \times d}를 계산하기를 원한다.
완결성을 위해 우선 Algorithm 3에서 표준 attention backward pass를 설명한다.
Algorithm 3. 표준 Attention Backward pass
Require: HBM에서 행렬 Q,K,V,dORN×d,PRN×N\bold{Q}, \bold{K}, \bold{V}, \bold{dO} \in \mathbb{R}^{N \times d}, \bold{P} \in \mathbb{R}^{N \times N}
1.
HBM에서 블록 별로 P,dO\bold{P}, \bold{dO}를 로드하고 dV=PdORN×d\bold{dV} = \bold{P}^\top \bold{dO} \in \mathbb{R}^{N \times d}를 계산하고 dV\bold{dV}를 HBM에 쓴다.
2.
HBM에서 블록 별로 dO,V\bold{dO}, \bold{V}를 로드하고 dP=dOVRN×N\bold{dP} = \bold{dOV}^\top \in \mathbb{R}^{N \times N}를 계산하고 dP\bold{dP}를 HBM에 쓴다.
3.
HBM에서 P,dP\bold{P}, \bold{dP}를 읽고 dSRN×N\bold{dS} \in \mathbb{R}^{N \times N}을 계산하고 (여기서 dSij=Pij(dPijlPildPil)dS_{ij} = P_{ij}(dP_{ij} - \sum_l P_{il}dP_{il}), dS\bold{dS}를 HBM에 쓴다.
4.
HBM에서 블록별로 dS\bold{dS}K\bold{K}를 로드하고, dQ=dSK\bold{dQ} = \bold{dSK}를 계산하고, dQ\bold{dQ}를 HBM에 쓴다.
5.
HBM에서 블록별로 dS\bold{dS}Q\bold{Q}를 로드하고, dK=dSQ\bold{dK} = \bold{dS}^\top\bold{Q}를 계산하고, dK\bold{dK}를 HBM에 쓴다.
6.
dQ,dK,dV\bold{dQ}, \bold{dK}, \bold{dV} 반환
이제 FlashAttention backward pass에 관한 2가지 관찰을 한다.
1.
forward pass에서 O(N2)O(N^2) 크기의 dropout mask를 저장할 필요가 없다. 대신 forward pass에서 pseudo-random number 생성기 상태를 저장하고 backward pass에서 dropout mask를 re-generate 할 수 있다. 이것은 오직 O(N)O(N)의 추가 메모리만 사용한다.
2.
softmax gradient를 계산할 때, 방정식 (4)를 사용하여 크기 NNPi:P_{i:}dPi:dP_{i:}에 대한 축소하지 않고 Di=Pi:dPi:D_i = P_{i:}^\top dP_{i:}를 계산한다(그것들은 SRAM으로 맞춰지지 않는다). 대신 Di=doioiD_i = do_i^\top o_i로 재작성하고 크기 dd의 벡터 사이의 점곱을 계산할 수 있다.
FlashAttention backward pass 알고리즘은 Algorhtm 4에 있다. 개념적으로 이것은 부록 B.2의 파생의 block 버전이다.
Algorithm 4. FlashAttention Backward pass
Require: HBM에서 행렬 Q,K,V,dORN×d\bold{Q}, \bold{K}, \bold{V}, \bold{dO} \in \mathbb{R}^{N \times d}, HBM에서 벡터 ,mRN\ell, m \in \mathbb{R}^N, 크기 MM의 on-chip SRAM, softmax scaling 상수 τR\tau \in \mathbb{R}, masking 함수 MASK\text{MASK}, dropout 확률 pdropp_\text{drop}, forward pass에서 pseudo-random number 생성기 상태 R\mathcal{R}
1.
pseudo-random number 생성기 상태를 R\mathcal{R}로 설정
2.
block 크기 Bc=M4d,Br=min(M4d,d)B_c = \lceil{M \over 4d} \rceil, B_r = \min(\lceil{M \over 4d}\rceil, d) 설정
3.
Q\bold{Q}를 각각 Br×dB_r \times d 크기의 Tr=NBrT_r = \lceil {N \over B_r} \rceil개 블록 Q1,...,QT\bold{Q}_1, ..., \bold{Q}_T으로 분할하고, K,V\bold{K}, \bold{V}를 각각 Bc×dB_c \times d 크기의 Tc=NBcT_c = \lceil {N \over B_c} \rceil개 블록 K1,...,KTc\bold{K}_1, ..., \bold{K}_{T_c}V1,...,VTc\bold{V}_1, ..., \bold{V}_{T_c}으로 분할
4.
O\bold{O}를 각각 Br×dB_r \times d 크기의 TrT_r개 블록 Oi,...,OTr\bold{O}_i,..., \bold{O}_{T_r}로 분할, dO\bold{dO}를 각각 Br×dB_r \times d 크기의 TrT_r개 블록 dOi,...,dOTr\bold{dO}_i, ..., \bold{dO}_{T_r}로 분할, \ell을 각각 BrB_r 크기의 TrT_r개 블록 i,...,Tr\ell_i,...,\ell_{T_r}로 분할 mm을 각각 BrB_r 크기의 TrT_r개 블록 m1,...,mTrm_1,..., m_{T_r}로 분할
5.
HBM에서 dQ=(0)N×d\bold{dQ} = (0)_{N\times d}를 초기화하고 각각 Br×dB_r \times d 크기의 TrT_r 블록 dQ1,...,dQTr\bold{dQ}_1,..., \bold{dQ}_{T_r}로 분할. HBM에서 dK=(0)N×d,dV=(0)N×d\bold{dK} = (0)_{N\times d}, \bold{dV} = (0)_{N \times d}를 초기화하고 각각 Bc×dB_c \times d 크기의 TcT_c개 블록 dK1,...,dKTc\bold{dK}_1,..., \bold{dK}_{T_c}dV1,...,dVTc\bold{dV}_1,..., \bold{dV}_{T_c}로 분할.
6.
for ijTci \le j \le T_c do
a.
HBM에서 on-chip SRAM으로 Kj,Vj\bold{K}_j, \bold{V}_j 로드
b.
SRAM에서 dK~j=(0)Bc×d,dV~j=(0)Bc×d\bold{d}\tilde{\bold{K}}_j = (0)_{B_c \times d}, \bold{d}\tilde{\bold{V}}_j = (0)_{B_c \times d} 초기화
c.
for 1iTr1 \le i \le T_r do
i.
HBM에서 on-chip SRAM으로 Qi,Oi,dOi,dQi,i,mi\bold{Q}_i, \bold{O}_i, \bold{dO}_i, \bold{dQ}_i, \ell_i, m_i 로드
ii.
On chip에서 Sij=τQiKjRBr×Bc\bold{S}_{ij} = \tau\bold{Q}_i\bold{K}_j^\top \in \mathbb{R}^{B_r \times B_c} 계산
forward와 동일하게 Sij\bold{S}_{ij} 계산
iii.
On chip에서 Sijmasked=MASK(Sij)\bold{S}_{ij}^\text{masked} = \text{MASK}(\bold{S}_{ij}) 계산
forward와 동일하게 Sijmasked\bold{S}_{ij}^\text{masked} 계산
iv.
On chip에서 Pij=diag(i)1exp(Sijmaskedmi)RBr×Bc\bold{P}_{ij} = \text{diag}(\ell_i)^{-1}\exp(\bold{S}_{ij}^\text{masked} - m_i) \in \mathbb{R}^{B_r \times B_c} 계산
\ellmm이 forward에서 계산되었으므로 Pij\bold{P}_{ij}를 바로 계산
v.
On chip에서 dropout mask ZijRBr×Bc\bold{Z}_{ij} \in \mathbb{R}^{B_r \times B_c} 계산. 여기서 각 항은 1pdrop1- p_\text{drop}의 확률로 11pdrop{1\over 1- p_\text{drop}}의 값을 갖고 pdropp_\text{drop}의 확률로 00의 값을 가짐
forward의 dropout을 적용하기 위한 dropout mask 계산
vi.
On chip에서 Pijdropped=PijZij\bold{P}_{ij}^\text{dropped} = \bold{P}_{ij} \circ \bold{Z}_{ij} 계산 (pointwise 곱)
Pijdropped\bold{P}_{ij}^\text{dropped} 계산. 여기까지가 forward 계산을 recompute한 부분.
vii.
On chip에서 dV~jdV~j+(Pijdropped)dOiRBc×d\bold{d}\tilde{\bold{V}}_j \leftarrow \bold{d}\tilde{\bold{V}}_j + (\bold{P}_{ij}^\text{dropped})^\top \bold{dO}_i \in \mathbb{R}^{B_c \times d} 계산
dV=PdO\bold{dV} = \bold{P}^\top \bold{dO}을 따라 dvj=iPijdoi=ieqikjLidoidv_j = \sum_i P_{ij}do_i = \sum_i {e^{q_i^\top k_j} \over L_i} do_i 로 업데이트
viii.
On chip에서 dPijdropped=dOiVjRBr×Bc\bold{dP}_{ij}^\text{dropped} = \bold{dO}_i \bold{V}_j^\top \in \mathbb{R}^{B_r \times B_c} 계산
dP=dOV\bold{dP} = \bold{dOV}^\top를 따라 dPij=doivjdP_{ij} = do_i^\top v_j로 업데이트
ix.
On chip에서 dPij=dPijdroppedZij\bold{dP}_{ij} = \bold{dP}_{ij}^\text{dropped} \circ \bold{Z}_{ij} 계산 (pointwise 곱)
dropout 적용
x.
On chip에서 Di=rowsum(dOiOi)RBrD_i = \text{rowsum}(\bold{dO}_i \circ \bold{O}_i) \in \mathbb{R}^{B_r} 계산
softmax(S)\text{softmax}(\bold{S})를 구하기 위해 정의했던 Di=Pi:dPi:=jeqikjLidoivj=doijeqikjLivj=doioiD_i = P_{i:}^\top dP_{i:} = \sum_j {e^{q_i^\top k_j} \over L_i}do_i^\top v_j = do_i^\top \sum_j {e^{q_i^\top k_j} \over L_i}v_j = do_i^\top o_i 계산
xi.
On chip에서 dSij=Pij(dPijDi)RBr×Bc\bold{dS}_{ij} = \bold{P}_{ij} \circ (\bold{dP}_{ij} - D_i) \in \mathbb{R}^{B_r \times B_c} 계산
구한 DiD_i를 이용하여 dSij\bold{dS}_{ij} 계산
xii.
HBM으로 dQidQi+τdSijKjRBr×d\bold{dQ}_i \leftarrow \bold{dQ}_i + \tau\bold{dS}_{ij}\bold{K}_j \in \mathbb{R}^{B_r \times d} 쓰기
dQ=dSK\bold{dQ} = \bold{dSK}를 따라 dqi=jdSijkj=jPij(dPijDi)kj=jeqikjLi(doivjDi)kjdq_i = \sum_j dS_{ij}k_j = \sum_j P_{ij}(dP_{ij}-D_i)k_j = \sum_j {e^{q_i^\top k_j} \over L_i}(do_i^\top v_j - D_i)k_j 업데이트
xiii.
On chip에서 dK~jdK~j+τdSijQiRBc×d\bold{d}\tilde{\bold{K}}_j \leftarrow \bold{d}\tilde{\bold{K}}_j + \tau \bold{dS}_{ij}^\top \bold{Q}_i \in \mathbb{R}^{B_c \times d} 계산
dK=QdS\bold{dK} = \bold{Q}^\top\bold{dS}를 따라 dki=jdSijqi=jPij(dPijDi)qi=ieqikjLi(doivjDi)qidk_i = \sum_j dS_{ij}q_i = \sum_j P_{ij}(dP_{ij}-D_i)q_i = \sum_i {e^{q_i^\top k_j} \over L_i}(do_i^\top v_j - D_i)q_i 업데이트
d.
end for
e.
HBM으로 dKjdK~j,dVjdV~j\bold{dK}_j \leftarrow \bold{d}\tilde{\bold{K}}_j, \bold{dV}_j \leftarrow \bold{d}\tilde{\bold{V}}_j 쓰기
7.
end for
8.
dQ,dK,dV\bold{dQ}, \bold{dK}, \bold{dV} 반환
forward pass와 유사하게 backward pass는 O(N2)O(N^2) FLOPs를 수행하고 입력, 출력, 출력 gradient, 입력 gradient 외에 O(N)O(N) 추가 메모리만 필요하다.
우리는 backward pass의 IO 복잡도를 forward pass(Theorem 2)와 유사하게 분석한다.
Theorem 5.
NN을 시퀀스 길이, dd를 head 차원, MMdMNdd \le M \le Nd인 SRAM의 크기라 하자. 표준 attention(Algorithm 0) backward pass는 HBM 접근에 Θ(Nd+N2)\Theta(Nd + N^2)을 요구하는 반면 FlashAttention backward pass(Algorithm 4)는 HBM 접근에 Θ(N2d2M1)\Theta(N^2d^2M^{-1})만 요구한다.
증명은 부록 C 참조.

B.5 Comparison with Rabe and Staats

우리는 여기서 FlashAttention 알고리즘과 Rabe와 Staats의 알고리즘 사이의 유사성과 차이점을 설명한다.
개념적으로 FlashAttention과 Rabe와 Staats 모두 잘 알려진 tiling(또는 softmax scaling) 기법을 사용하여 attention 행렬의 block을 연산한다. 메모리 사용량을 줄이기 위해 두 접근은 모두 forward pass에서 큰 attention 행렬의 저장을 피하고 backward pass에서 이것을 recompute 한다.
첫 번째 주요한 차이는 Rabe와 Staats는 전체 메모리 사용량을 줄이는데(필요한 최대 GPU 메모리량) 초점을 맞추는 반면 FlashAttnetion은 메모리 접근을 줄이는데(메모리 읽기/쓰기의 수) 초점을 맞춘다는 것이다. 섹션 2에서 언급한대로, 메모리 접근 양이 실행 시간을 결정하는 주요 요인이다. 메모리 접근을 줄이면 필요한 총 메모리량도 줄어든다(예: 연산이 A회의 메모리 접근을 발생시키면, 전체 메모리 요구량은 최대 A). 결과적으로 FlashAttention은 표준 attention 보다 2-4배 빠른 반면 Rabe와 Staats는 표준 attention과 유사하거나 약간 느리다. 전체 메모리 요구량의 측면에서 두 방법은 모두 상당한 메모리 사용량을 줄인다.
두 번째 차이는 각 블록에서 다음 블록으로 정보를 전달하는 방법에 있다. Rabe와 Staats는 각 블록을 임시 출력과 softmax 정규화 통계량으로 요약한다. forward pass의 끝에서 모든 블록의 임시 출력은 통계량을 사용하여 결합되고 최종 출력을 생성한다. 대신 FlashAttention은 각 블록을 처리한 후에 출력을 점진적으로 업데이트 한다(Algorithm 1의 12번째 줄). 따라서 하나의 출력 복사본만 필요하다(KK개 블록에 대한 KK개 복사본이 아니라). 이것은 FlashAttention이 Rabe와 Staats와 비교하여 전체 메모리 요구량이 더 작다는 것을 뜻한다.
마지막 주요한 차이는 backward pass가 계산되는 방법에 있다. Rabe와 Staats는 gradient checkpointing을 사용하여 attention 행렬과 각 블록의 임시 출력을 recompute한다. FlashAttention은 대신 backward pass를 해석적으로 단순화하여(부록 B.2와 B.4) attention 행렬만 recompute하고 각 블록의 임시 출력을 재계산하지 않는다. 이를 통해 backward pass에 대한 메모리 사용량을 줄이고 속도를 개선한다.

C Proofs

Theorem 1의 증명.
우선 FLOPs의 수와 추가 메모리 요구량을 센다.
FLOPs을 지배하는 것은 행렬 곱이다. 내부 반복에서(Algorithm 1의 9번째 줄), QiRBr×d,KjRBr×d\bold{Q}_i \in \mathbb{R}^{B_r \times d}, \bold{K}_j \in \mathbb{R}^{B_r \times d}에 대해 QiKjRBr×Bc\bold{Q}_i\bold{K}_j^\top \in \mathbb{R}^{B_r \times B_c}를 계산하며 이것은 O(BrBcd)O(B_rB_cd) FLOPs를 취한다. 또한 (Algorithm 1의 12번째 줄) P~ijRBr×Bc,VjRBc×d\tilde{\bold{P}}_{ij} \in \mathbb{R}^{B_r \times B_c}, \bold{V}_j \in \mathbb{R}^{B_c \times d}에 대해 P~ijVjRBr×d\tilde{\bold{P}}_{ij}\bold{V}_j \in \mathbb{R}^{B_r \times d}을 계산하며 이것은 O(BrBcd)O(B_rB_cd) FLOPs를 취한다. 내부 반복에서 TcTr=NBcNBrT_cT_r = \lceil{N\over B_c}\rceil \lceil{N\over B_r}\rceil번 실행한다. 그러므로 전체 FLOPs 수는 다음과 같다.
O(N2BcBrBrBcd)=O(N2d)O\left({N^2 \over B_cB_r}B_rB_cd \right) = O(N^2d)
추가 메모리 요구량 측면에서 통계량 (,m)(\ell, m)을 저장하기 위해 O(N)O(N)이 필요함을 볼 수 있다.
이제 0jTc0 \le j \le T_cjj에 대해 귀납으로 알고리즘의 정확성을 증명한다. K:jRjBc×d\bold{K}_{:j} \in \mathbb{R}^{jB_c \times d}K\bold{K}의 첫 번째 jBcjB_c 행이라 하고 유사하게 V:jRjBc×d\bold{V}_{:j} \in \mathbb{R}^{jB_c \times d}V\bold{V}의 첫 번째 jBcjB_c 행이라 하자. S:,:j=QK:jRN×jBc\bold{S}_{:, :j} = \bold{QK}_{:j}^\top \in \mathbb{R}^{N \times jB_c}이고 P:,:j=softmax(S:,:j)RN×jBc\bold{P}_{:,:j} = \text{softmax}(\bold{S}_{:,:j}) \in \mathbb{R}^{N \times jB_c}(softmax는 행별로 적용됨)이라 하자. mj,(j),O(j)m^j, \ell^{(j)}, \bold{O}^{(j)}을 바깥 반복문의 jj-번째 반복 이후 HBM에서 m,,Om, \ell, \bold{O}이라 하자(Algorithm 1의 5번째 줄). (이러한 m,,Om, \ell, \bold{O}의 값들은 바깥 반복문의 각 반복 이후에 업데이트 된다). 바깥 반복문의 jj-번째 반복 이후를 HBM에서 계산된 것을 보인다.
m(j)=rowmax(S:,:j)RN(j)=rowsum(exp(S:,:jm(j)))RNO(j)=P:,:jV:jRN×d\begin{aligned} m^{(j)} &= \text{rowmax}(\bold{S}_{:,:j}) \in \mathbb{R}^N \\ \ell^{(j)} &= \text{rowsum}(\exp(\bold{S}_{:,:j}-m^{(j)})) \in \mathbb{R}^N \\ \bold{O}^{(j)} &= \bold{P}_{:,:j}\bold{V}_{:j} \in \mathbb{R}^{N\times d} \end{aligned}
초기화(Algorithm 1의 2번째 줄)에 기반하여, 이것은 j=0j=0에 대해 사실임을 주장한다(즉, 바깥 반복문의 반복이 실행되기 전에). 이 주장이 어떤 j=0,...,Tc1j=0,...,T_c-1에 대해 유지된다고 가정하자. j+1j+1에 대해 주장이 유지되는지 알기 원한다. 실제로 바깥 반복문의 (j+1)(j+1) 번째 반복에 대해 내부 반복에서 통계량을 업데이트할 때(Algorithm 1의 10번째 줄), m(j+1)=max(m(j),m~)m^{(j+1)} = \max(m^{(j)}, \tilde{m})을 업데이트한다. 여기서 m~RN\tilde{m} \in \mathbb{R}^NS:,j:j+1\bold{S}_{:,j:j+1}의 row-max이고 열 jBcjB_c에서 열 (j+1)Bc1(j+1)B_c - 1까지 S\bold{S}를 slice한다. 이것은 다음을 암시한다.
m(j+1)=rowmax(S:,:j+1)RNm^{(j+1)} = \text{rowmax}(\bold{S}_{:,:j+1})\in \mathbb{R}^N
유사하게 다음을 업데이트 한다.
(j+1)=em(j)m(j+1)(j)+em~m(j+1)~\ell^{(j+1)} = e^{m^{(j)} - m^{(j+1)}}\ell^{(j)} + e^{\tilde{m}-m^{(j+1)}}\tilde{\ell}
여기서 ~=rowsum(exp(S:,j:j+1m~))RN\tilde{\ell} = \text{rowsum}(\exp(\bold{S}_{:,j:j+1} - \tilde{m})) \in \mathbb{R}^N. 섹션 3.1과 동일한 대수 조작을 통해 다음을 얻는다.
(j+1)=rowsum(exp(S:,:j+1m(j+1)))RN\ell^{(j+1)} = \text{rowsum}(\exp(\bold{S}_{:,:j+1} - m^{(j+1)})) \in \mathbb{R}^N
Vj:j+1\bold{V}_{j:j+1}을 열 jBcjB_c에서 열 (j+1)Bc1(j+1)B_c -1까지 V\bold{V}의 slice라 하자. 또한 다음과 같이 업데이트 한다.
O(j+1)=diag((j+1))1(diag((j))em(j)m(j+1)O(j)+em~m(j+1)exp(Sj:j+1m~)Vj:j+1)=diag((j+1))1(diag((j))em(j)m(j+1)P:,:jV:j+em(j+1)exp(Sj:j+1)Vj:j+1)=diag((j+1))1(diag((j))em(j)m(j+1)diag((j))exp(S:,:jm(j))V:j+em(j+1)exp(Sj:j+1)Vj:j+1)=diag((j+1))1(em(j+1)exp(S:,:j)V:j+em(j+1)exp(Sj:j+1)Vj:j+1)=diag((j+1))1(exp(S:,:jm(j+1))V:j+exp(Sj:j+1m(j+1))Vj:j+1)=diag((j+1))1(exp([S:,:jSj:j+1]m(j+1)))[V:jVj:j+1]=softmax(S:j+1)V:j+1\begin{aligned} \bold{O}^{(j+1)} &= \text{diag}(\ell^{(j+1)})^{-1}(\text{diag}(\ell^{(j)})e^{m^{(j)}-m^{(j+1)}}\bold{O}^{(j)} + e^{\tilde{m} - m^{(j+1)}}\exp(\bold{S}_{j:j+1}-\tilde{m})\bold{V}_{j:j+1}) \\ &= \text{diag}(\ell^{(j+1)})^{-1}(\text{diag}(\ell^{(j)})e^{m^{(j)}-m^{(j+1)}}\bold{P}_{:,:j}\bold{V}_{:j} + e^{- m^{(j+1)}}\exp(\bold{S}_{j:j+1})\bold{V}_{j:j+1}) \\ &= \text{diag}(\ell^{(j+1)})^{-1}(\text{diag}(\ell^{(j)})e^{m^{(j)}-m^{(j+1)}}\text{diag}(\ell^{(j)})\exp(\bold{S}_{:,:j}-m^{(j)})\bold{V}_{:j}+ e^{- m^{(j+1)}}\exp(\bold{S}_{j:j+1})\bold{V}_{j:j+1}) \\ &= \text{diag}(\ell^{(j+1)})^{-1}(e^{-m^{(j+1)}}\exp(\bold{S}_{:,:j})\bold{V}_{:j} + e^{-m^{(j+1)}}\exp(\bold{S}_{j:j+1})\bold{V}_{j:j+1}) \\ &= \text{diag}(\ell^{(j+1)})^{-1}(\exp(\bold{S}_{:,:j}-m^{(j+1)})\bold{V}_{:j} + \exp(\bold{S}_{j:j+1}-m^{(j+1)})\bold{V}_{j:j+1}) \\ &= \text{diag}(\ell^{(j+1)})^{-1}(\exp([\bold{S}_{:,:j} \quad \bold{S}_{j:j+1}]-m^{(j+1)}))\begin{bmatrix} \bold{V}_{:j} \\ \bold{V}_{j:j+1} \end{bmatrix} \\ &= \text{softmax}(\bold{S}_{:j+1})\bold{V}_{:j+1} \end{aligned}
그러면 j+1j+1에 대한 주장 또한 사실임을 확인할 수 있다. 귀납에 의해 모든 j=0,...,Tcj=0,...,T_c에 대해 주장은 사실이다.
j=Tcj=T_c일 때 HBM에서 O\bold{O}의 최종 값이 softmax(S)V=softamx(QK)V\text{softmax}(\bold{S})\bold{V} = \text{softamx}(\bold{QK}^\top)\bold{V}라고 결론내릴 수 있다.
Theorem 2의 증명.
우선 표준 attention 구현의 IO 복잡도를 분석한다. HBM에 존재하는 입력 Q,K,VRN×d\bold{Q}, \bold{K}, \bold{V} \in \mathbb{R}^{N \times d}과 알고리즘 끝에서 HBM으로 쓰여지는 출력 ORN×d\bold{O} \in \mathbb{R}^{N \times d}.
행렬 곱 S=QK\bold{S} = \bold{QK}^\top을 계산하는 첫 번째 단계에서, 입력 Q,K\bold{Q}, \bold{K}는 HBM에서 읽고 출력 SRN×N\bold{S} \in \mathbb{R}^{N \times N}은 HBM으로 쓰여진다(Algorithm 0의 1번째 줄). 이것은 Θ(Nd+N2)\Theta(Nd + N^2) HBM 접근을 발생시킨다.
P=softmax(S)\bold{P} = \text{softmax}(\bold{S})를 계산하는 두 번째 단계에서, 입력 S\bold{S}는 HBM에서 읽고, 출력 P\bold{P}는 HBM으로 쓰여진다(Algorithm 0의 2번째 줄). 이것은 Θ(N2)\Theta(N^2) HBM 접근을 발생시킨다.
O=PV\bold{O} = \bold{PV}를 계산하는 마지막 단계에서, 입력 P,V\bold{P}, \bold{V}는 global 메모리에서 읽고 출력 O\bold{O}는 HBM으로 쓰여진다(Algorithm 0의 3번째 줄). 이것은 Θ(Nd+N2)\Theta(Nd + N^2) HBM 접근을 발생시킨다.
전체적으로 표준 attention 구현은 Θ(Nd+N2)\Theta(Nd + N^2) global 메모리 접근이 필요하다.
이제 streaming attention의 IO 복잡도를 분석한다.
Algorithm 1을 따라 K\bold{K}V\bold{V}의 각 요소는 HBM에서 한 번에 로드되는 것을 볼 수 있다(Algorithm 1의 6번째 줄). Q\bold{Q}O\bold{O}에 대해 TcT_c번 통과한다. 각 통과 시 모든 Q\bold{Q}와 모든 O\bold{O}를 HBM으로 전달하는 각각(Algorithm 1의 8번째 줄). 그러므로 HBM의 접근 수는 Θ(Nd+NdTc)=Θ(NdTc)\Theta(Nd + NdT_c) = \Theta(NdT_c)이다.
블록 크기 BcB_cBrB_r에 대한 조건을 유도한다. 우리는 Bc×dB_c \times d 크기의 블록 Kj\bold{K}_jVj\bold{V}_j를 on-chip 메모리에 맞춰야 한다. 이것은 다음과 같이 표현된다.
Bcd=O(M)Bc=O(Md)B_cd = O(M) \Leftrightarrow B_c = O\left({M\over d}\right)
유사하게 Br×dB_r \times d 크기의 블록 Qi,Oi\bold{Q}_i, \bold{O}_i를 on-chip 메모리에 맞춰야 한다. 이것은 다음과 같이 표현된다.
Brd=O(M)Br=O(Md)B_rd = O(M) \Leftrightarrow B_r = O\left({M\over d} \right)
마지막으로 Br×BcB_r \times B_c 크기의 블록 Sij\bold{S}_{ij}을 on-chip 메모리에 맞춰야 한다. 이것은 다음과 같이 표현된다.
BrBc=O(M)B_rB_c = O(M)
그러므로 다음을 설정한다.
Bc=Θ(Md),Br=Θ(min(Md,MBc))=Θ(min(Md,d))B_c = \Theta\left({M \over d} \right), B_r = \Theta\left(\min\left({M \over d},{M \over B_c} \right) \right) = \Theta\left(\min\left({M \over d},d \right) \right)
그러므로 다음이 성립한다.
Tc=NBc=Θ(NdM)T_c = {N \over B_c} = \Theta\left( {Nd \over M}\right)
결과적으로 HBM의 접근 수는 다음과 같다.
Θ(NdTc)=Θ(N2d2M)\Theta(NdT_c) = \Theta\left( {N^2d^2 \over M}\right)
Proposition 3의 증명.
모순을 위해, 정확한 attention을 계산하는 알고리즘이 존재한다고 가정하자. 여기서 모든 M[d,Nd]M \in [d, Nd]에 대해 HBM 접근 수는 다음과 같다.
o(N2d2M)o\left({N^2d^2 \over M} \right)
M=Θ(Nd)M = \Theta(Nd)의 체제에서 HBM 접근의 수가 발생한다.
o(N2d2Nd)=o(Nd)o\left({N^2d^2 \over Nd} \right) = o(Nd)
그러나 attention에 대한 입력(행렬 Q,K,V\bold{Q}, \bold{K}, \bold{V})와 출력 O\bold{O}NdNd 크기를 갖고 HBM에서 시작한다. 따라서 정확한 attention을 계산하는 알고리즘이 있다면, 적어도 Ω(Nd)\Omega(Nd) HBM 접근이 발생해야 한다. 이것은 모순이다.
Theorem 5의 증명.
attention backward의 IO 복잡도는 attention forward의 IO 복잡도와 매우 유사하다(Theorem 2). 여기서 증명의 스케치를 제공한다.
우선 표준 attention backward pass의 IO 복잡도를 분석한다. 입력 Q,K,V,dORN×d\bold{Q}, \bold{K}, \bold{V}, \bold{dO} \in \mathbb{R}^{N \times d}는 HBM에 존재하고 알고리즘의 끝에서 출력 dQ,dK,dVRN×d\bold{dQ}, \bold{dK}, \bold{dV} \in \mathbb{R}^{N \times d}가 HBM에 쓰여진다.
표준 attention backward pass의 각 단계에서 NdNd 또는 N2N^2 크기의 입력을 HBM에서 로드해야 하고 N2N^2 또는 NdNd 크기의 출력을 HBM에 써야 한다. 이것은 Θ(Nd+N2)\Theta(Nd + N^2) HBM 접근을 발생시킨다.
이제 FlashAttention backward pass의 IO 복잡도를 분석한다.
Theorem 2와 유사하게 K\bold{K}V\bold{V}의 각 요소는 HBM에서 한 번에 로드되고 dK\bold{dK}dV\bold{dV}의 각 요소는 HBM으로 한 번에 쓰여진다. 우리는 Q,O,dO\bold{Q}, \bold{O}, \bold{dO}에 대해 TcT_c번 통과를 수행하며, 각 통과 시 모든 Q,O,dO\bold{Q}, \bold{O}, \bold{dO}을 HBM으로 로드한다. 또한 dQ\bold{dQ}에 대해 TcT_c번 통과를 수행하며 각 통과 시 모든 dQ\bold{dQ}를 HBM에서 읽거나 HBM으로 쓴다. 그러므로 HBM 접근의 수는 Θ(Nd+NdTc)=Θ(NdTc)\Theta(Nd + NdT_c) = \Theta(NdT_c)이다.
Theorem 2의 증명에 따라 block 크기에 대한 제약은 다음과 같다.
Bc=Θ(Md),Br=Θ(min(Md,d))B_c = \Theta\left({M \over d} \right), B_r = \Theta\left(\min\left({M \over d},d \right) \right)
그러면 다음이 성립한다.
Tc=NBc=Θ(NdM)T_c = {N \over B_c} = \Theta\left( {Nd \over M}\right)
결과적으로 HBM 접근의 수는 다음이 된다.
Θ(NdTc)=Θ(N2d2M)\Theta(NdT_c) = \Theta \left( {N^2d^2 \over M}\right)

D Extension Details

D.1 Block-sparse FlashAttention

우리는 전체 block-sparse FlashAttention 알고리즘을 Algorithm 5에 설명한다. 이 알고리즘은 skip zero 블록만 제외하면 Algorithm 2와 동일하다.
우리는 block-sparse FlashAttention의 IO 복잡도를 증명한다.
Proposition 4의 증명.
증명은 Theorem 2의 증명과 매우 유사하다. block-sparse 경우에 대해 nonzero 블록에 해당하는 블록만 로드하면 된다는 것에 유의하라. 결과적으로 HBM 접근의 수는 block-sparsity mask에서 nonzero 블록의 비율인 ss에 의해 scaled 된다. 그러나 ss값이 작은 경우 여전히 결과 ORN×d\bold{O} \in \mathbb{R}^{N \times d}를 써야 한다. 그러므로 HBM 접근의 수는 다음과 같다.
Θ(Nd+N2d2Ms)\Theta\left(Nd + {N^2 d^2 \over M}s \right)
Algorithm 5. Block-Sparse FlashAttention Forward Pass
Require: HBM에서 행렬 Q,K,V,dORN×d\bold{Q}, \bold{K}, \bold{V}, \bold{dO} \in \mathbb{R}^{N \times d}, 크기 MM의 on-chip SRAM, softmax scaling 상수 τR\tau \in \mathbb{R}, masking 함수 MASK\text{MASK}, dropout 확률 pdropp_\text{drop}, block size Bc=M4D,Br=min(M4d,d)B_c = \lceil{M \over 4D} \rceil, B_r = \min(\lceil{M \over 4d}\rceil, d), block sparsity mask M{0,1}N/Br×N/BcM \in \{ 0, 1\}^{N/B_r \times N/B_c}
1.
pseudo-random number 생성기 상태를 R\mathcal{R}로 설정하고 HBM으로 저장
2.
HBM에서 O=(0)N×dRN×d,=(0)NRN,m=()NRN\bold{O} = (0)_{N \times d} \in \mathbb{R}^{N \times d}, \ell=(0)_N \in \mathbb{R}^N, m = (-\infty)_N \in \mathbb{R}^N 초기화
3.
Q\bold{Q}를 각각 Br×dB_r \times d 크기의 Tr=NBrT_r = \lceil {N \over B_r} \rceil개 블록 Q1,...,QT\bold{Q}_1, ..., \bold{Q}_T으로 분할하고, K,V\bold{K}, \bold{V}를 각각 Bc×dB_c \times d 크기의 Tc=NBcT_c = \lceil {N \over B_c} \rceil개 블록 K1,...,KTc\bold{K}_1, ..., \bold{K}_{T_c}V1,...,VTc\bold{V}_1, ..., \bold{V}_{T_c}으로 분할
4.
O\bold{O}를 각각 Br×dB_r \times d 크기의 TrT_r개 블록 Oi,...,OTr\bold{O}_i,..., \bold{O}_{T_r}로 분할, \ell을 각각 BrB_r 크기의 TrT_r개 블록 i,...,Tr\ell_i,...,\ell_{T_r}로 분할 mm을 각각 BrB_r 크기의 TrT_r개 블록 m1,...,mTrm_1,..., m_{T_r}로 분할
5.
for ijTci \le j \le T_c do
a.
HBM에서 on-chip SRAM으로 Kj,Vj\bold{K}_j, \bold{V}_j 로드
b.
for 1iTr1 \le i \le T_r do
i.
if Mij0M_{ij} \ne 0 then
1.
HBM에서 on-chip SRAM으로 Qi,Oi,i,mi\bold{Q}_i, \bold{O}_i, \ell_i, m_i 로드
2.
On chip에서 Sij=τQiKjRBr×Bc\bold{S}_{ij} = \tau\bold{Q}_i\bold{K}_j^\top \in \mathbb{R}^{B_r \times B_c} 계산
3.
On chip에서 Sijmasked=MASK(Sij)\bold{S}_{ij}^\text{masked} = \text{MASK}(\bold{S}_{ij}) 계산
4.
On chip에서 다음을 계산
m~ij=rowmax(Sijmasked)RBr\tilde{m}_{ij} = \text{rowmax}(\bold{S}_{ij}^\text{masked}) \in \mathbb{R}^{B_r}
P~ij=exp(Sijmaskedm~ij)RBr×Bc\tilde{\bold{P}}_{ij} = \exp(\bold{S}_{ij}^\text{masked} - \tilde{m}_{ij}) \in \mathbb{R}^{B_r \times B_c} (pointwise)
~ij=rowsum(P~ij)RBr\tilde{\ell}_{ij} = \text{rowsum}(\tilde{\bold{P}}_{ij}) \in \mathbb{R}^{B_r}
5.
On chip에서 다음을 계산
minew=max(mi,m~ij)RBrm_i^\text{new} = \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_r}
inew=emiminewi+em~ijminew~ijRBr\ell_i^\text{new} = e^{m_i - m_i^\text{new}}\ell_i + e^{\tilde{m}_{ij} - m_i^\text{new}}\tilde{\ell}_{ij} \in \mathbb{R}^{B_r}
6.
On chip에서 P~ijdropped=dropout(P~ij,pdrop)\tilde{\bold{P}}_{ij}^\text{dropped} = \text{dropout}(\tilde{\bold{P}}_{ij},p_\text{drop})
7.
HBM에 Oidiag(inew)1(diag(i)emiminewOi+em~ijminewP~ijdroppedVj)\bold{O}_i \leftarrow \text{diag}(\ell_i^\text{new})^{-1}(\text{diag}(\ell_i)e^{m_i - m_i^\text{new}}\bold{O}_i + e^{\tilde{m}_{ij} - m_i^\text{new}}\tilde{\bold{P}}_{ij}^\text{dropped}\bold{V}_j) 쓰기
8.
HBM에 iinew,miminew\ell_i \leftarrow \ell_i^\text{new}, m_i \leftarrow m_i^\text{new} 쓰기
ii.
end if
c.
end for
6.
end for
7.
O,,m,R\bold{O}, \ell, m, \mathcal{R} 반환

D.2 Potential Extensions

딥러닝 학습 속도를 높이기 위한 IO-aware의 몇 가지 잠재적 확장을 논의한다.
Multi-GPU Attention.
LLM은 수백 또는 수천 개의 GPU에서 학습되며 일반적으로 동일한 노드에 4-8개 GPUs 사이에 attention 계산을 분할한다. 이로 인해 메모리 계층에 또 다른 수준이 발생한다. GPU SRAM과 GPU HBM 외에도 다른 GPUs의 HBM도 갖는다. 매우 긴 시퀀스에 대해 동일한 노드에 대해 다른 GPUs가 협력하여 다양한 메모리 계층 레벨의 비대칭을 고려하여 attention 계산을 할 수 있다.
Sparse MLP layers.
일반적인 밀집 MLP 레이어는 compute-bound이지 memory-bound 가 아니다. 효율성을 높이기 위해 희소 가중치 행렬을 갖는 MLP 레이어를 사용할 수 있다. 그러나 많은 희소 MLP 레이어는 대신 memory-bound이고 그들의 속도 향상이 희소성에 비례하지 않는다. 우리는 IO-aware 구현이 이 이슈를 완화하고 희소성의 이점을 현실화 할 수 있으리라 믿는다. 우리는 대형 모델의 계산 요구량을 줄이고 실제 시간 실행을 개선하기 위한 이 방향으로의 미래 작업을 기대한다.
Kernel Machine Learning.
FlashAttention에서 접근은 N×NN \times N attention 행렬이 low-rank 행렬 QK\bold{QK}^\top (랭크 dNd \ll N)의 함수이라는 사실에 의존한다. 결과적으로 입력 Q,K\bold{Q}, \bold{K}을 반복적으로 로드하고 필요한 attention 행렬의 블록을 recompute하여 HBM 접근을 크게 감소시킬 수 있다. 유사한 시나리오가 kernel machine learning에서 일어날 수 있다. N×NN \times N 커널 행렬 K\bold{K}의 각 요소 KijK_{ij}는 두 데이터 포인트 xix_ixjx_j 사이의 유사성을 측정하는 dNd \ll N 크기의 두 벡터의 함수이다. KeOps 라이브러리는 메모리 읽기/쓰기를 줄여서 커널 연산의 속도를 높일 수 있음을 보여주는 성공적인 예이다. 우리는 이것이 단지 FLOPs 대신 IOs의 감소에 더 초점을 맞춘 커널 방법에 동기를 부여할 수 있기를 희망한다.