Loss function
Loss function이란 손실 함수로써 모델의 Output이 얼마나 틀렸는지 나타내는 척도이다. 즉, Loss가 작을 수록 좋은 것인데, 모델의 학습과정에서 이 Loss의 미분값을 통해 Back Propagation을 진행한다.
Loss function은 관점에 따라 Cost function이나 Objective Function으로도 부르는데, 구분할 필요가 있다.
- Loss function : 하나의 데이터에 대한 오차를 최소화하기 위해 정의된 함수
- Cost function : 모든 오차를 일반적으로 최소하하기 위해 정의된 함수
- Objective Function : 어떤 값을 최대화 혹은 최소화하기 위해 정의된 함수.
따라서 "Loss function ⊂ Cost function ⊂ Objective function"의 관계이다.
1. MSE(Mean, Squared Error, L2 loss)
회귀 알고리즘에서 많이 쓰이는 MSE는 평균 제곱 오차이다. 즉 모델의 예측 값과 실제 값의 오차를 제곱해 평균한 값이다.
2. MAE(Mean Absolute Error, L1 loss)
MAE는 오차 절대값 평균으로, loss가 크던 작던 항상 gradient가 일정하다. loss가 작아도 gradient는 작지 않기 때문에 경사하강법에서 최적값을 찾는 데에 어려움이 있다. 따라서 MSE가 최적값을 구하기에 보다 적절하지만, MAE가 이상치에 덜 영향을 받는다는 장점이 있다.
3. CEE (Cross Entropy Error)
CEE는 일반적으로 분류 알고리즘에서 사용한다. Pytorch에서의 CEE는내부적으로 softmax activation 이후 Log transformation(Logsoftmax+NLLloss)이 구현되어 있어 raw prediction value를 기대한다.
Entropy에 대해서 먼저 짚고 넘어가면, "불확실성에 대한 척도"이다. 불확실성이 커지면 Entropy가 증가한다.
예를 들어 위 그림처럼 공이 들어있는 두 상자가 있다고 했을 때, 첫 번째 상자는 어떤 공을 꺼내도 파란색이기 때문에 불확실성이 없고, Entropy가 0이다. 하지만 두 번째 상자는 주황색 공과 파란색 공이 나올 확률이 50%이기 때문에 어떤 공이 나올지 예측하기 힘들다.
CEE는 다음과 같이 계산된다.
ex) [개, 고양이, 사자] - 3종의 다른 동물을 분류하는 알고리즘의 Loss function은(CEE에서는 자연로그를 주로 사용한다.)?
정답이 사자일때, 실제값은 [0, 0, 1]이 되고, 예측값은 각 class 별 확률값으로 [0.5, 0.1, 0.3]으로 나왔다. 이 때 각 class 별 실제값과 예측확률값을 사용해 CEE를 계산하면, -{0 * ln(0.5) + 0 * ln(0.1) + 1 * ln(0.3)} = 1.204가 나온다.
직관적으로 봤을 때 실제값이 0인 항들은 0으로 없어지고, 실제값이 1인 class의 probability가 자연로그를 거친다. 이 때 전체 식에 음수가 취해져있어 결국 최종적인 식은 실제값이 1인 class의 probability가 커지도록 학습하고, 다른 class의 probability는 결국 합이 1이기 때문에 작아지게 된다.
4. BCEE(Binary Cross Entropy Error)
BCEE는 주로 이진 분류에 사용된다.
실제값이 0인 경우, 뒷 항만 남게될 때 1xlog(1-probability) 인데, 전체가 음수이므로 이 값이 커지도록 학습하게 되면 결국 실제값이 0인 class의 probability가 작아지도록 학습된다.
실제값이 1인 경우, 앞 항만 남게될 때 1xlog(probabiltiy)인데, 전체가 음수이므로 이 값이 커지도록 학습하게 되면 결국 실제값이 1인 class의 probability가 커지도록 학습된다.
'DL > modules' 카테고리의 다른 글
FLOPs (1) | 2024.01.25 |
---|---|
BPE(Byte Pair Encoding), WordPiece Tokenization (2) | 2024.01.25 |
Normalization의 종류 (0) | 2024.01.16 |