AI

[AI] Mixed Precision Training 이란?

청귤파파 2024. 11. 1. 00:22

Mixed Precision 이란?
처리속도를 높이기 위한 FP16 과, 정확도 유지를 위한 FP32 를 섞어서 학습하는 방법

1. Intro

대부분의 LLM 학습 시 기본으로 사용되고 있는 테크닉으로, FP32(Single Precision) 과 FP16(Half Precision) 을 함께 사용하게 될 때 발생하는 오버플로우 혹은 언더플로우 현상을 해결하기 위한 방법이다.

2. Floating Point

기본적으로 우리가 실수를 표현하는 데 사용하는 방식은 FP32 방식이다. 그 중 제일 많이 쓰는 방식은 1(부호) + 8(지수 - exponent) + 23(가수 - fraction) 의 형태로 총 32bit 를 사용하여 실수를 표현한다. 하지만 메모리 및 계산량을 줄이기 위해서 FP16 을 사용하는 방향을 고려하게 되는데, 당연하게도 기존 FP32 방식보다 실수를 표현할 수 있는 범위가 줄어들게 된다. 이는 주로 값이 너무 큰 경우나, 너무 작은 경우에 문제가 발생하게 되는데 ML 학습 시에는 gradient 값이 매우 작은 경우에 문제가 발생하게 된다.

3. Mixed Precision Training

3.1 Problem

기존에 FP32 로 학습을 진행하면 학습이 잘 되던 모델들이 FP16 으로 했을 때 발산하는 현상이 발생한다. 학습에 관여하는 Tensor 는 아래 4가지로 나눌 수 있는데,

  • activations
  • activation gradients
  • weights
  • weight gradients

이 중 FP16 으로 변경하여 학습을 진행할 때 weights 와 weight gradients 는 대체로 FP16 범위 내에 잘 들어오는 편이었다고 한다. 문제가 생기는 activation gradients 의 경우, 굉장히 크기가 작은 일부 값들이 관측되었고, FP16 으로 표현할 수 있는 범위를 넘어서면서 강제적으로 0이 되어버리는 현상이 발생하였다(underflow). 아래 그래프에서 빨간선보다 왼쪽에 있는 값들이 FP16 으로 변환 시 0이 되어버리는 값들이다. 상당히 많은 양이 영향을 받고 있음을 알 수 있다.

3.2 Solution

가장 간단하게 생각할 수 있는 방법은 값이 작은 것이 문제이니 gradients 에 큰 수를 곱하여 그 값들을 오른쪽으로 shift 하는 방식일 것이다. 그래서 간단하게 8 (scaling factor) 를 곱하여 실험을 돌려보니 SSD(Single Shot MultiBox Detector) 가 학습이 잘 된다는 것을 관측하였다. 이를 바탕으로 상세 procedure 를 나타내면 아래와 같다.

FP32 값은 저장하는 상태로, FP16 값을 이용하여 forward/backward 진행 (+ scale factor)
FP16 결과로 얻은 gradients 를 기반으로 FP32 weights 업데이트

  1. Make an FP16 copy of the weights
  2. Forward propagate using FP16 weights and activations
  3. Multiply the resulting loss by the scale factor S
  4. Backward propagate using FP16 weights, activations, and their gradients
  5. Multiply the weight gradients by 1/S
  6. Optionally process the weight gradients (gradient clipping, weight decay, etc.)
  7. Update the master copy of weights in FP32

4. Results

다양한 tasks 에 대하여 실험을 진행하였는데, cls task 의 경우 accuracy 를 거의 유지하였고, SSD 의 경우 기존 발산하던 것이 잡히는 것을 확인하였다. FP16 을 적용했을 때 메모리는 대략 2배 정도 절감하였고, 수렴을 향한 속도는 2 ~ 4배 정도 향상되었다.

5. References