B Algorithm Details
우리는 우선 attention의 forward와 backward pass를 유도하고 이들이 메모리 효율적인 방법(시퀀스 길이에 따라 2차 대신 선형적인 추가 메모리 요구)에서 계산될 수 있음을 보인다. 비록 추가 메모리 요구의 양이 줄지만 단순히 구현하면 여전히 2차적 HBM 접근이 발생하여 결과적으로 느린 실행 속도를 갖는다. 우리는 GPU에서 HBM 접근을 줄이기 위해 forward와 backward pass 모두를 구현하는 FlashAttention 알고리즘을 설명한다. 이것은 더 따른 실행 속도와 더 적은 메모리 사용량을 이끈다.
B.1 Memory-efficient forward pass
attention을 메모리 효율적으로 만드는 주요 도전은 의 컬럼(과 의 컬럼)을 커플링하는 softmax이다. 우리의 접근은 softmax normalization 상수를 별도로 계산하여 컬럼을 decouple 한다. 이 기법은 몇 문헌에서 attention 계산에 2차적 추가 메모리가 필요하지 않음을 보이기 위해 사용되었다(그러나 HBM 접근 수가 여전히 2차적이어서 실행은 느리다.)
단순성을 위해 softmax 중 max-shifting 단계를 생략한다. 전체 알고리즘은 모든 단계를 포함하는 부록 B.3 참조.
입력 시퀀스 가 주어질 때, attention 출력 를 계산하기를 원한다는 것을 떠올려라.
를 갖는다. 여기서 와 는 각각 와 의 -번째 행과 -번째 열이다. softmax의 정규화 상수를 다음과 같이 정의한다.
를 의 -번째 열이라고 하자. 그러면 출력의 -번째 열은 다음과 같다.
가 한 번 계산되면 추가 메모리 없이 를 반복적으로 합산하여 를 계산할 수 있음을 볼 수 있다. 그러므로 forward 패스는 추가 메모리를 사용하여 계산될 수 있다.
1.
방정식 (1)을 따라 모든 에 대해 을 계산한다. 이것은 의 추가 메모리를 취한다.
2.
방정식 (2)를 따라 모든 에 대해 를 계산한다. 이것은 의 추가 메모리를 취한다.
B.2 Memory-efficient backward pass
우리는 attention의 backward pass를 유도하고 이것 또한 선형 메모리로 계산될 수 있음을 보인다. Rabe와 Staats는 메모리 효율적인 forward pass에 gradient checkpointing을 적용하여 backward pass가 2차적 추가 메모리 없이 계산될 수 있다고 제안했다. 우리는 대신 backward pass를 명시적으로 유도하고 메모리 효율적인 방법으로 계산할 수 있는 방법을 보인다.
스칼라 손실 함수 가 있다고 하고 출력 gradient가 라고 하자. (여기서 는 를 표기한다.) 우리는 입력 gradient 를 계산하기 원한다. (여기서 는 각각 를 표기한다.)
gradient 는 쉽게 볼 수 있다. 역방향 자동미분을 수동으로(chain rule 이라고도 함) 적용하면 (행렬표기에서) 를 얻는다. 따라서
임의의 행렬곱 에 대해 gradient는 각각 와 으로 주어진다. 즉, 역행렬이 아니라 전치행렬로 gradient를 구할 수 있다.
가 이미 계산되었기 때문에 를 합산을 반복하여 추가 메모리 없이 계산할 수 있다.
gradient 와 는 약간 더 복잡하다. 우선 gradient 와 를 살펴보자. 방정식 (2)에서 이 성립한다. 따라서
임을 떠올려라. 의 야코비안이 라는 점을 사용하여 다음을 갖는다.
여기서 는 pointwise 곱을 표기한다.
이므로 의 gradient를 구하려면 softmax의 야코비안 과 의 gradient 를 모두 구해서 곱해야 한다.
다음을 정의한다.
그러면
그러므로
이제 와 의 gradient를 얻을 수 있다. 를 떠올려라. 따라서
에 대해 gradient는 각각 와 으로 주어진다.
유사하게
그러므로 backward pass도 추가 메모리로 계산될 수 있다.
1.
방정식 (3)을 따라 모든 에 대해 를 계산한다. 이것은 의 추가 메모리를 취한다.
2.
방정식 (4)을 따라 모든 에 대해 를 계산한다. 이것은 의 추가 메모리를 취한다.
3.
방정식 (5)을 따라 모든 에 대해 를 계산한다. 이것은 의 추가 메모리를 취한다.
4.
방정식 (6)을 따라 모든 에 대해 를 계산한다. 이것은 의 추가 메모리를 취한다.
B.3 FlashAttention: Forward Pass
FlashAttention forward pass의 전체 세부사항을 설명한다. 입력 시퀀스 가 주어지면 attention 출력 을 계산하기를 원한다.
여기서 은 어떤 softmax scaling(일반적으로 )이고 MASK는 입력의 일부 항목을 로 설정하고 나머지는 동일하게 두는 어떤 masking 함수이다(예: batch의 시퀀스 길이가 동일하지 않을 때 key padding mask). 은 에 요소별로 dropout을 적용한다(예: 각 요소 에 대해 확률 의 출력 와 확률 의 출력 )
전체 알고리즘은 알고리즘 2 참조. 우리는 출력 , softmax 통계량 과 과 backward pass를 위한 pseudo-random number 생성기 상태 을 저장한다.
Algorithm 2. FlashAttention Forward Pass
Require: HBM에서 행렬 , 크기 의 on-chip SRAM, softmax scaling 상수 , masking 함수 , dropout 확률
1.
pseudo-random number 생성기 상태 을 초기화하고 HBM에 저장
2.
block 크기 설정
3.
HBM에 초기화
4.
를 각각 크기의 개 블록 으로 분할하고, 를 각각 크기의 개 블록 과 으로 분할
5.
를 각각 크기의 개 블록 로 분할, 을 각각 크기의 개 블록 로 분할 을 각각 크기의 개 블록 로 분할
6.
for do
a.
를 HBM에서 on-chip SRAM으로 로드
b.
for do
i.
를 HBM에서 on-chip SRAM으로 로드
ii.
On chip에서 계산
iii.
On chip에서 계산
iv.
On chip에서 다음을 계산
•
•
(pointwise)
•
v.
On chip에서 다음을 계산
•
•
vi.
On chip에서
vii.
HBM에 쓰기
viii.
HBM에 쓰기
c.
end for
7.
end for
8.
반환
B.4 FlashAttention: Backward Pass
FlashAttention의 Backward pass의 전체 세부사항을 설명한다. 입력 시퀀스 , 출력 과 출력 gradient 가 주어지면 입력 gradient 를 계산하기를 원한다.
완결성을 위해 우선 Algorithm 3에서 표준 attention backward pass를 설명한다.
Algorithm 3. 표준 Attention Backward pass
Require: HBM에서 행렬
1.
HBM에서 블록 별로 를 로드하고 를 계산하고 를 HBM에 쓴다.
2.
HBM에서 블록 별로 를 로드하고 를 계산하고 를 HBM에 쓴다.
3.
HBM에서 를 읽고 을 계산하고 (여기서 , 를 HBM에 쓴다.
4.
HBM에서 블록별로 와 를 로드하고, 를 계산하고, 를 HBM에 쓴다.
5.
HBM에서 블록별로 와 를 로드하고, 를 계산하고, 를 HBM에 쓴다.
6.
반환
이제 FlashAttention backward pass에 관한 2가지 관찰을 한다.
1.
forward pass에서 크기의 dropout mask를 저장할 필요가 없다. 대신 forward pass에서 pseudo-random number 생성기 상태를 저장하고 backward pass에서 dropout mask를 re-generate 할 수 있다. 이것은 오직 의 추가 메모리만 사용한다.
2.
softmax gradient를 계산할 때, 방정식 (4)를 사용하여 크기 의 와 에 대한 축소하지 않고 를 계산한다(그것들은 SRAM으로 맞춰지지 않는다). 대신 로 재작성하고 크기 의 벡터 사이의 점곱을 계산할 수 있다.
FlashAttention backward pass 알고리즘은 Algorhtm 4에 있다. 개념적으로 이것은 부록 B.2의 파생의 block 버전이다.
Algorithm 4. FlashAttention Backward pass
Require: HBM에서 행렬 , HBM에서 벡터 , 크기 의 on-chip SRAM, softmax scaling 상수 , masking 함수 , dropout 확률 , forward pass에서 pseudo-random number 생성기 상태
1.
pseudo-random number 생성기 상태를 로 설정
2.
block 크기 설정
3.
를 각각 크기의 개 블록 으로 분할하고, 를 각각 크기의 개 블록 과 으로 분할
4.
를 각각 크기의 개 블록 로 분할, 를 각각 크기의 개 블록 로 분할, 을 각각 크기의 개 블록 로 분할 을 각각 크기의 개 블록 로 분할
5.
HBM에서 를 초기화하고 각각 크기의 블록 로 분할. HBM에서 를 초기화하고 각각 크기의 개 블록 와 로 분할.
6.
for do
a.
HBM에서 on-chip SRAM으로 로드
b.
SRAM에서 초기화
c.
for do
i.
HBM에서 on-chip SRAM으로 로드
ii.
On chip에서 계산
•
forward와 동일하게 계산
iii.
On chip에서 계산
•
forward와 동일하게 계산
iv.
On chip에서 계산
•
과 이 forward에서 계산되었으므로 를 바로 계산
v.
On chip에서 dropout mask 계산. 여기서 각 항은 의 확률로 의 값을 갖고 의 확률로 의 값을 가짐
•
forward의 dropout을 적용하기 위한 dropout mask 계산
vi.
On chip에서 계산 (pointwise 곱)
•
계산. 여기까지가 forward 계산을 recompute한 부분.
vii.
On chip에서 계산
•
을 따라 로 업데이트
viii.
On chip에서 계산
•
를 따라 로 업데이트
ix.
On chip에서 계산 (pointwise 곱)
•
dropout 적용
x.
On chip에서 계산
•
를 구하기 위해 정의했던 계산
xi.
On chip에서 계산
•
구한 를 이용하여 계산
xii.
HBM으로 쓰기
•
를 따라 업데이트
xiii.
On chip에서 계산
•
를 따라 업데이트
d.
end for
e.
HBM으로 쓰기
7.
end for
8.
반환
forward pass와 유사하게 backward pass는 FLOPs를 수행하고 입력, 출력, 출력 gradient, 입력 gradient 외에 추가 메모리만 필요하다.
우리는 backward pass의 IO 복잡도를 forward pass(Theorem 2)와 유사하게 분석한다.
Theorem 5.
을 시퀀스 길이, 를 head 차원, 을 인 SRAM의 크기라 하자. 표준 attention(Algorithm 0) backward pass는 HBM 접근에 을 요구하는 반면 FlashAttention backward pass(Algorithm 4)는 HBM 접근에 만 요구한다.
증명은 부록 C 참조.
B.5 Comparison with Rabe and Staats
우리는 여기서 FlashAttention 알고리즘과 Rabe와 Staats의 알고리즘 사이의 유사성과 차이점을 설명한다.
개념적으로 FlashAttention과 Rabe와 Staats 모두 잘 알려진 tiling(또는 softmax scaling) 기법을 사용하여 attention 행렬의 block을 연산한다. 메모리 사용량을 줄이기 위해 두 접근은 모두 forward pass에서 큰 attention 행렬의 저장을 피하고 backward pass에서 이것을 recompute 한다.
첫 번째 주요한 차이는 Rabe와 Staats는 전체 메모리 사용량을 줄이는데(필요한 최대 GPU 메모리량) 초점을 맞추는 반면 FlashAttnetion은 메모리 접근을 줄이는데(메모리 읽기/쓰기의 수) 초점을 맞춘다는 것이다. 섹션 2에서 언급한대로, 메모리 접근 양이 실행 시간을 결정하는 주요 요인이다. 메모리 접근을 줄이면 필요한 총 메모리량도 줄어든다(예: 연산이 A회의 메모리 접근을 발생시키면, 전체 메모리 요구량은 최대 A). 결과적으로 FlashAttention은 표준 attention 보다 2-4배 빠른 반면 Rabe와 Staats는 표준 attention과 유사하거나 약간 느리다. 전체 메모리 요구량의 측면에서 두 방법은 모두 상당한 메모리 사용량을 줄인다.
두 번째 차이는 각 블록에서 다음 블록으로 정보를 전달하는 방법에 있다. Rabe와 Staats는 각 블록을 임시 출력과 softmax 정규화 통계량으로 요약한다. forward pass의 끝에서 모든 블록의 임시 출력은 통계량을 사용하여 결합되고 최종 출력을 생성한다. 대신 FlashAttention은 각 블록을 처리한 후에 출력을 점진적으로 업데이트 한다(Algorithm 1의 12번째 줄). 따라서 하나의 출력 복사본만 필요하다(개 블록에 대한 개 복사본이 아니라). 이것은 FlashAttention이 Rabe와 Staats와 비교하여 전체 메모리 요구량이 더 작다는 것을 뜻한다.
마지막 주요한 차이는 backward pass가 계산되는 방법에 있다. Rabe와 Staats는 gradient checkpointing을 사용하여 attention 행렬과 각 블록의 임시 출력을 recompute한다. FlashAttention은 대신 backward pass를 해석적으로 단순화하여(부록 B.2와 B.4) attention 행렬만 recompute하고 각 블록의 임시 출력을 재계산하지 않는다. 이를 통해 backward pass에 대한 메모리 사용량을 줄이고 속도를 개선한다.
C Proofs
Theorem 1의 증명.
우선 FLOPs의 수와 추가 메모리 요구량을 센다.
FLOPs을 지배하는 것은 행렬 곱이다. 내부 반복에서(Algorithm 1의 9번째 줄), 에 대해 를 계산하며 이것은 FLOPs를 취한다. 또한 (Algorithm 1의 12번째 줄) 에 대해 을 계산하며 이것은 FLOPs를 취한다. 내부 반복에서 번 실행한다. 그러므로 전체 FLOPs 수는 다음과 같다.
추가 메모리 요구량 측면에서 통계량 을 저장하기 위해 이 필요함을 볼 수 있다.
이제 의 에 대해 귀납으로 알고리즘의 정확성을 증명한다. 를 의 첫 번째 행이라 하고 유사하게 를 의 첫 번째 행이라 하자. 이고 (softmax는 행별로 적용됨)이라 하자. 을 바깥 반복문의 -번째 반복 이후 HBM에서 이라 하자(Algorithm 1의 5번째 줄). (이러한 의 값들은 바깥 반복문의 각 반복 이후에 업데이트 된다). 바깥 반복문의 -번째 반복 이후를 HBM에서 계산된 것을 보인다.
초기화(Algorithm 1의 2번째 줄)에 기반하여, 이것은 에 대해 사실임을 주장한다(즉, 바깥 반복문의 반복이 실행되기 전에). 이 주장이 어떤 에 대해 유지된다고 가정하자. 에 대해 주장이 유지되는지 알기 원한다. 실제로 바깥 반복문의 번째 반복에 대해 내부 반복에서 통계량을 업데이트할 때(Algorithm 1의 10번째 줄), 을 업데이트한다. 여기서 은 의 row-max이고 열 에서 열 까지 를 slice한다. 이것은 다음을 암시한다.
유사하게 다음을 업데이트 한다.
여기서 . 섹션 3.1과 동일한 대수 조작을 통해 다음을 얻는다.
을 열 에서 열 까지 의 slice라 하자. 또한 다음과 같이 업데이트 한다.
그러면 에 대한 주장 또한 사실임을 확인할 수 있다. 귀납에 의해 모든 에 대해 주장은 사실이다.
일 때 HBM에서 의 최종 값이 라고 결론내릴 수 있다.
Theorem 2의 증명.
우선 표준 attention 구현의 IO 복잡도를 분석한다. HBM에 존재하는 입력 과 알고리즘 끝에서 HBM으로 쓰여지는 출력 .
행렬 곱 을 계산하는 첫 번째 단계에서, 입력 는 HBM에서 읽고 출력 은 HBM으로 쓰여진다(Algorithm 0의 1번째 줄). 이것은 HBM 접근을 발생시킨다.
를 계산하는 두 번째 단계에서, 입력 는 HBM에서 읽고, 출력 는 HBM으로 쓰여진다(Algorithm 0의 2번째 줄). 이것은 HBM 접근을 발생시킨다.
를 계산하는 마지막 단계에서, 입력 는 global 메모리에서 읽고 출력 는 HBM으로 쓰여진다(Algorithm 0의 3번째 줄). 이것은 HBM 접근을 발생시킨다.
전체적으로 표준 attention 구현은 global 메모리 접근이 필요하다.
이제 streaming attention의 IO 복잡도를 분석한다.
Algorithm 1을 따라 와 의 각 요소는 HBM에서 한 번에 로드되는 것을 볼 수 있다(Algorithm 1의 6번째 줄). 와 에 대해 번 통과한다. 각 통과 시 모든 와 모든 를 HBM으로 전달하는 각각(Algorithm 1의 8번째 줄). 그러므로 HBM의 접근 수는 이다.
블록 크기 와 에 대한 조건을 유도한다. 우리는 크기의 블록 와 를 on-chip 메모리에 맞춰야 한다. 이것은 다음과 같이 표현된다.
유사하게 크기의 블록 를 on-chip 메모리에 맞춰야 한다. 이것은 다음과 같이 표현된다.
마지막으로 크기의 블록 을 on-chip 메모리에 맞춰야 한다. 이것은 다음과 같이 표현된다.
그러므로 다음을 설정한다.
그러므로 다음이 성립한다.
결과적으로 HBM의 접근 수는 다음과 같다.
Proposition 3의 증명.
모순을 위해, 정확한 attention을 계산하는 알고리즘이 존재한다고 가정하자. 여기서 모든 에 대해 HBM 접근 수는 다음과 같다.
의 체제에서 HBM 접근의 수가 발생한다.
그러나 attention에 대한 입력(행렬 )와 출력 는 크기를 갖고 HBM에서 시작한다. 따라서 정확한 attention을 계산하는 알고리즘이 있다면, 적어도 HBM 접근이 발생해야 한다. 이것은 모순이다.
Theorem 5의 증명.
attention backward의 IO 복잡도는 attention forward의 IO 복잡도와 매우 유사하다(Theorem 2). 여기서 증명의 스케치를 제공한다.
우선 표준 attention backward pass의 IO 복잡도를 분석한다. 입력 는 HBM에 존재하고 알고리즘의 끝에서 출력 가 HBM에 쓰여진다.
표준 attention backward pass의 각 단계에서 또는 크기의 입력을 HBM에서 로드해야 하고 또는 크기의 출력을 HBM에 써야 한다. 이것은 HBM 접근을 발생시킨다.
이제 FlashAttention backward pass의 IO 복잡도를 분석한다.
Theorem 2와 유사하게 와 의 각 요소는 HBM에서 한 번에 로드되고 와 의 각 요소는 HBM으로 한 번에 쓰여진다. 우리는 에 대해 번 통과를 수행하며, 각 통과 시 모든 을 HBM으로 로드한다. 또한 에 대해 번 통과를 수행하며 각 통과 시 모든 를 HBM에서 읽거나 HBM으로 쓴다. 그러므로 HBM 접근의 수는 이다.
Theorem 2의 증명에 따라 block 크기에 대한 제약은 다음과 같다.
그러면 다음이 성립한다.
결과적으로 HBM 접근의 수는 다음이 된다.
D Extension Details
D.1 Block-sparse FlashAttention
우리는 전체 block-sparse FlashAttention 알고리즘을 Algorithm 5에 설명한다. 이 알고리즘은 skip zero 블록만 제외하면 Algorithm 2와 동일하다.
우리는 block-sparse FlashAttention의 IO 복잡도를 증명한다.
Proposition 4의 증명.
증명은 Theorem 2의 증명과 매우 유사하다. block-sparse 경우에 대해 nonzero 블록에 해당하는 블록만 로드하면 된다는 것에 유의하라. 결과적으로 HBM 접근의 수는 block-sparsity mask에서 nonzero 블록의 비율인 에 의해 scaled 된다. 그러나 값이 작은 경우 여전히 결과 를 써야 한다. 그러므로 HBM 접근의 수는 다음과 같다.
Algorithm 5. Block-Sparse FlashAttention Forward Pass
Require: HBM에서 행렬 , 크기 의 on-chip SRAM, softmax scaling 상수 , masking 함수 , dropout 확률 , block size , block sparsity mask
1.
pseudo-random number 생성기 상태를 로 설정하고 HBM으로 저장
2.
HBM에서 초기화
3.
를 각각 크기의 개 블록 으로 분할하고, 를 각각 크기의 개 블록 과 으로 분할
4.
를 각각 크기의 개 블록 로 분할, 을 각각 크기의 개 블록 로 분할 을 각각 크기의 개 블록 로 분할
5.
for do
a.
HBM에서 on-chip SRAM으로 로드
b.
for do
i.
if then
1.
HBM에서 on-chip SRAM으로 로드
2.
On chip에서 계산
3.
On chip에서 계산
4.
On chip에서 다음을 계산
•
•
(pointwise)
•
5.
On chip에서 다음을 계산
•
•
6.
On chip에서
7.
HBM에 쓰기
8.
HBM에 쓰기
ii.
end if
c.
end for
6.
end for
7.
반환
D.2 Potential Extensions
딥러닝 학습 속도를 높이기 위한 IO-aware의 몇 가지 잠재적 확장을 논의한다.
Multi-GPU Attention.
LLM은 수백 또는 수천 개의 GPU에서 학습되며 일반적으로 동일한 노드에 4-8개 GPUs 사이에 attention 계산을 분할한다. 이로 인해 메모리 계층에 또 다른 수준이 발생한다. GPU SRAM과 GPU HBM 외에도 다른 GPUs의 HBM도 갖는다. 매우 긴 시퀀스에 대해 동일한 노드에 대해 다른 GPUs가 협력하여 다양한 메모리 계층 레벨의 비대칭을 고려하여 attention 계산을 할 수 있다.
Sparse MLP layers.
일반적인 밀집 MLP 레이어는 compute-bound이지 memory-bound 가 아니다. 효율성을 높이기 위해 희소 가중치 행렬을 갖는 MLP 레이어를 사용할 수 있다. 그러나 많은 희소 MLP 레이어는 대신 memory-bound이고 그들의 속도 향상이 희소성에 비례하지 않는다. 우리는 IO-aware 구현이 이 이슈를 완화하고 희소성의 이점을 현실화 할 수 있으리라 믿는다. 우리는 대형 모델의 계산 요구량을 줄이고 실제 시간 실행을 개선하기 위한 이 방향으로의 미래 작업을 기대한다.
Kernel Machine Learning.
FlashAttention에서 접근은 attention 행렬이 low-rank 행렬 (랭크 )의 함수이라는 사실에 의존한다. 결과적으로 입력 을 반복적으로 로드하고 필요한 attention 행렬의 블록을 recompute하여 HBM 접근을 크게 감소시킬 수 있다. 유사한 시나리오가 kernel machine learning에서 일어날 수 있다. 커널 행렬 의 각 요소 는 두 데이터 포인트 와 사이의 유사성을 측정하는 크기의 두 벡터의 함수이다. KeOps 라이브러리는 메모리 읽기/쓰기를 줄여서 커널 연산의 속도를 높일 수 있음을 보여주는 성공적인 예이다. 우리는 이것이 단지 FLOPs 대신 IOs의 감소에 더 초점을 맞춘 커널 방법에 동기를 부여할 수 있기를 희망한다.