Search
Duplicate

AI/ Focal Loss

Focal Loss

Focal Loss는 데이터셋에 Class 불균형이 발생하는 경우에 Cross Entropy Loss를 개선한 Loss 함수이다. 논문의 저자들의 설명에 따라 우선 Cross entropy에 대한 설명부터 시작한다.

Cross Entropy

우선 binary classification의 cross entropy 식은 다음과 같이 정의된다는 것을 떠올리자.
CE=(ylog(p)+(1y)log(1p))\text{CE} = -(y\log (p) + (1-y) \log (1-p))
여기서 yy는 실제 라벨이고, pp는 모델이 입력에 대해 해당 라벨일 확률을 예측한 값이다.
한편 클래스 multi-class인 경우 one-hot 인코딩 방식을 이용하여 cross entropy를 다음과 같이 정의한다.
CE=i=1Cyilog(p^i)\text{CE} = -\sum_{i=1}^C y_i \log (\hat{p}_i)
여기서 yiy_i는 one-hot 벡터 y\bold{y}ii번째 요소로 해당 클래스가 정답이면 11 그렇지 않으면 00이 된다. p^i\hat{p}_i는 모델이 입력에 대해 해당 클래스일 확률을 예측한 결과가 된다.
따라서 위 식은 ii가 정답 클래스가 아니면 모두 00이 되어 loss에 영향을 미치지 않고 ii가 정답 클래스인 경우에만 loss가 계산된다. 때문에 cross entropy 식에서 yiy_i를 제거하고 target에 대한 모델의 예측 확률 ptp_t만 남겨서 아래와 같이 간단히 표기하는 경우가 있는데 (Focal Loss 저자들이 위와 같이 표기함) 여기서는 설명을 위해 전체 식을 기준으로 적상한다.
CE(p,y)=CE(pt)=log(pt)\text{CE}(p,y) = \text{CE}(p_t) = -\log (p_t)

α\alpha-balanced Cross Entropy

학습 데이터셋의 Class가 균형잡혀 있다면 cross entropy loss는 잘 동작하겠지만, 현실에서 class가 불균형한 경우가 많으며, 이 경우 모델은 제대로 학습되지 못한다. 예컨대 현실에서 머리가 긴 남자는 꽤 드물기 때문에 일반적으로 샘플링된 남자/여자 데이터셋으로 모델을 학습하면 모델은 다른 특징을 보지 않고 머리가 길면 여자로 분류할 가능성이 높아진다.
이를 해결하기 위해 클래스에 대해 별도의 파라미터 α\alpha를 도입한 α\alpha-balanced CE loss를 정의할 수 있다. 여기서 α\alpha는 클래스 빈도의 역수이거나 cross validation에 의해 설정될 수 있다.
CE=αi=1Cyilog(p^i)\text{CE} = -\alpha\sum_{i=1}^C y_i \log (\hat{p}_i)
α\alpha가 positive/negative 예제의 균형을 잡지만, easy/hard 예제를 구별하지 않기 때문에 쉽게 분류되는 negative 샘플들이 log의 대부분을 차지할 수 있다.

Focal Loss

easy 예제의 가중치를 낮추고, hard negative를 학습하는데 초점을 맞추기 위해 Focal Loss의 저자들은 다음과 같이 정의되는 Focal Loss 함수를 제안했다.
FL=i=1Cyi(1p^i)γlog(p^i)\text{FL} = -\sum_{i=1}^C y_i (1-\hat{p}_i)^\gamma \log (\hat{p}_i)
이것은 α\alpha-balanced CE loss에서의 α\alpha 대신 저자들이 modulating factor라 부른 (1p^i)γ(1-\hat{p}_i)^\gamma을 추가한 형식이다. 여기서 (1p^i)(1-\hat{p}_i)은 모델이 입력을 잘못 예측한 확률이고 γ0\gamma \ge 0는 focusing 파라미터로 조절 가능하다.
위 식에서 만일 정답 클래스를 맞추면 즉 p^i1\hat{p}_i \to 1로 감에 따라 (1p^i)(1-\hat{p}_i)00에 가까워지므로(여기서 γ\gamma(1p^i)(1-\hat{p}_i)가 줄어드는 속도를 조절하는 역할을 한다) 최종 loss는 cross entropy 보다 훨씬 더 작아진다. 이것은 confidence가 높은 easy 예제의 loss를 빠르게 줄이는 효과를 발생시킨다.
한편 위 식에서 정답 클래스를 틀리면 즉 p^i0\hat{p}_i \to 0으로 감에 따라 (1p^i)(1-\hat{p}_i)11에 가까워지므로 최종 loss는 cross entropy와 유사한 크기가 된다. 이것은 confidence가 낮은 hard 예제의 loss를 높게 유지하는 효과를 발생시킨다.
이로 인해 focal loss는 cross entropy에 비해 easy 예제의 loss를 더 작게하고, hard 예제의 loss를 유지하여 학습의 초점이 hard 예제에 맞춰지도록 한다.
저자들은 modulating factor가 쉬운 예제의 loss 기여를 줄이면서 낮은 loss를 받는 예제의 범위를 확장하는 역할을 한다고 설명한다. 예컨대 γ=2\gamma=2일 때 pt=0.9p_t = 0.9로 분류되는 예제는 CE와 비교하여 100배 낮은 loss를 갖고 pt0.968p_t \approx 0.968이면 CE와 비교하여 1000배 낮은 loss를 갖는다.
위의 Focal Loss 형식에 α\alpha-balanced를 적용한 다음과 같은 형태도 가능하다.
FL=αi=1Cyilog(p^i)(1p^i)γ\text{FL} = -\alpha\sum_{i=1}^C y_i \log (\hat{p}_i)(1-\hat{p}_i)^\gamma
저자들은 이것이 원래의 FL에 비해 약간의 개선이 있었지만 FL의 정확한 형식은 크게 중요하지 않고 이것과 유사한 다른 형식도 얼마든지 가능하다고 얘기한다.

참고