Search
Duplicate

AI/ Kernel Fusion, Recomputation

Kernel Fusion과 Recomputation은 모두 메모리 사용량을 줄이는 소프트웨어 기법을 의미한다. 이것은 GPU와 같이 계산 자체는 빠르지만, SRAM과 HBM 사이의 메모리 전송이 병목이 되는 하드웨어 환경에서 성능을 높이는 주요한 기법으로 사용된다. 어차피 GPU에서 계산 능력은 남기 때문에 그 남는 계산 능력을 추가로 사용해서 SRAM과 HBM 사이의 메모리 전송을 아끼고 전체 계산 성능을 높인다는 개념이다.
이를 위해 이 기법들은 저장을 최대한 줄이고 코드 블럭 내에서 연산을 많이 수행하여 메모리 사용을 줄이도록 노력한다.

Kernel Fusion

Kernel Fusion이란 여러 커널 호출을 코드 상에서 하나의 커널 호출로 결합해서 최적화하는 소프트웨어 기법이다. 이를 통해 연산의 중간 결과를 메모리에 저장하지 않고 한 번에 연산하여 메모리 접근을 최적화하고 커널 호출 횟수를 줄이고 캐시 효율성을 증가시킬 수 있다.

Example

간단한 예로 다음과 같이 2가지 kernel이 있고 각각 순차적으로 처리된다고 할 때
__global__ void addKernel(float *A, float *B, float *C, int N) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < N) { C[idx] = A[idx] + B[idx]; } } __global__ void funcKernel(float *C, float *D, int N) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < N) { D[idx] = func(C[idx]); } }
C++
복사
다음처럼 하나의 코드로 합쳐서 처리하는 것을 kernel fusion이라고 한다.
__global__ void fusedKernel(float *A, float *B, float *D, int N) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < N) { D[idx] = func(A[idx] + B[idx]); } }
C
복사

Recomputation

Recomputation이란 메모리 사용량을 줄이기 위해 계산을 더하는 소프트웨어 기법을 의미한다. 예컨대 딥러닝에서 backpropagation에서는 forward pass에서 계산된 중간 결과를 저장하여 역전파에 사용하는데, 모든 중간 결과를 저장하는 것이 매우 큰 모델이나 메모리가 제한된 환경에서는 비효율적이다.
이를 개선하기 위해 순전파 단계에서 모든 중간 출력을 저장하지 않고 일부만 저장한 다음, 역전파 단계에서 필요한 부분에 대해 순전파를 다시 계산(recomputation) 하여 gradient를 계산한다.

Example

Pytorch에서 역전파에 대해 recomputation을 사용하기 위해 다음과 같이 function.Function을 상속받는 클래스를 정의하고 forward와 backward 함수를 구현한다.
backward 단계의 grad_output은 RecomputationLayer의 출력을 받은 다음 레이어에서 역전파 단계에서 전파해 준 gradient에 해당한다. 만일 RecomputationLayer의 출력이 2개 였다면 backward 단계에서 2개 파라미터를 받아 처리하도록 구현해야 한다.
class RecomputationLayer(function.Function): @staticmethod def forward(ctx, input, weight, bias): # forward 패스에서 사용된 텐서를 저장 ctx.save_for_backward(input, weight, bias) # forward 단계에서 필요한 계산 수행 return input.mm(weight.t()) + bias @staticmethod def backward(ctx, grad_output): # 저장된 텐서를 복원 input, weight, bias = ctx.saved_tensors with torch.enable_grad(): input.requires_grad_() weight.requires_grad_() bias.requires_grad_() # 역전파 단계에서 recomputation recomputed_output = input.mm(weight.t()) + bias # 역전파 그래디언트를 계산 grad_input = torch.autograd.grad(recomputed_output, input, grad_output, retain_graph=True)[0] grad_weight = torch.autograd.grad(recomputed_output, weight, grad_output, retain_graph=True)[0] grad_bias = torch.autograd.grad(recomputed_output, bias, grad_output, retain_graph=True)[0] # 최종적으로 역전파 결과를 반환 return grad_input, grad_weight, grad_bias
Python
복사
위와 같이 구현된 코드를 모델에서 아래처럼 사용할 수 있다. 이 경우 forward 단계에서 RecomputationLayer.apply()을 사용한 부분은 역전파 때 해당 부분을 이용해서 역전파되고, 그 다음 relu()fc2()는 해당 결과를 받아 일반적인 방법으로 역전파 된다.
class SimpleNN(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(SimpleNN, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.output_size = output_size self.fc1_weight = nn.Parameter(torch.randn(hidden_size, input_size)) self.fc1_bias = nn.Parameter(torch.randn(hidden_size)) self.fc2 = nn.Linear(hidden_size, output_size) def forward(self, x): x = RecomputationLayer.apply(x, self.fc1_weight, self.fc1_bias) x = torch.relu(x) x = self.fc2(x) return x
Python
복사