왜 크로스-엔트로피를 사용할까?

(본 글은 유투브 영상을 글로 재구성 한 것입니다.)

머신러닝/딥러닝을 하다보면 크로스-엔트로피(cross-entropy)가 자주 등장합니다. 정확히는, 분류(classification) 문제를 풀 때 크로스-엔트로피를 이용해서 손실(loss, cost) 함수를 정의하곤 하죠.

우선 크로스-엔트로피 식을 한번 살펴볼까요? 정의역이 같은 두 랜덤변수 $X$, $Y$가 있고, 각각의 확률질량함수(PMF)를 $p$, $q$라 할 때, 크로스-엔트로피 $\text{CE}(X, Y)$는 다음과 같이 정의됩니다:

$$ \begin{align} \text{CE}(X, Y) &= - \mathbb{E}_{x \sim p(x)} \left[ \log q(x) \right] \\ &= - \sum_{x} p(x) \log q(x). \end{align} $$

이 식, 그리고 엔트로피 자체에 대해선 여러가지 관점으로 해석 할 수 있습니다. 왜 엔트로피 관련 식에는 항상 확률의 로그가 등장하는가? 같은 얘기들 말이죠. 이는 언젠가 기회가 되면 글로 써보도록 하겠습니다.

여기서는 어떠한 분포간의 "차이"에 집중을 해보도록 하겠습니다. 머신러닝/딥러닝은 대부분 정답 분포를 모사하는것이 목표가 되고, 얼마나 근접했는지를 수치화해서 피드백을 주는 방식으로 구성되어있습니다. 즉, 어떤 문제를 풀고 싶고, 그 문제엔 (정확히는 알 수 없으나) 어떠한 미지의 진분포라는게 존재하고, 우리의 머신러닝/딥러닝 모델은 그 진분포를 모사해주는 확률 모델인 것이죠.

🐍 사족을 조금 더 붙이자면, 정답 분포가 신경망으로 구성된 분포인지 아닌지는 모르겠으나, 우리는 신경망을 이용해서 해당 분포를 잘 모사할 수 있다고 생각을 한 것이죠 (이를 '보편 근사 정리'라고 합니다). 모델의 세부 파라미터는 학습 하면 되겠지 하는 것이죠. 이 '모델링'이라는 개념이 굉장히 중요합니다. 예컨대, 힘에 관련된 물리법칙을 우린 보통 $F = ma$로 '모델링' 합니다. 현실이 정말 그런 원리인지는 모르겠지만, 이 '모델'로 충분히 근사가 되다면 우리 머리로 이해하고 다룰 수 있게 해주는 강력한 도구 역할을 해줍니다.

다시 분포의 차이 얘기로 돌아와서, 보통 두 분포간의 차이를 얘기할땐 $\text{KL}$-발산(Kullback-Leibler divergence)이라는 식을 자주 사용합니다. 이 $\text{KL}$-발산의 식은 다음과 같이 정의됩니다:

$$ D_{\text{KL}}(p \parallel q) = \sum_{x} p(x) \log \frac{p(x)}{q(x)}. $$

두 확률의 비율에다가 로그를 씌웠습니다. 어떻게 해석해야할지 조금 난감한 식이죠. 보통은 이를 "가짜분포 $q$가 진분포 $p$를 얼마나 잘 따라했는가"에 대한 값을 나타낸다고 해석을 합니다. "차이" 느낌의 값이기 때문에, 작을수록 좋은 상황입니다. 실제로 항상 $p(x) = q(x)$라면, 위 값은 0이 됩니다. 다만 이 식이 담고 있는 함의를 느끼기엔 다소 어색한 감이 좀 남아있죠. 잘 살펴보면, 확률 비율과 관련된 무언가의 기댓값이라고 생각할 수도 있습니다:

$$ \sum_{x} p(x) \log \frac{p(x)}{q(x)} = \mathbb{E}_{x \sim p(x)} \left[ \log \frac{p(x)}{q(x)} \right]. $$

그러나 여전히 뭔가 와닿진 않네요. 예를 한번 들어봅시다. $x \in \{a, b\}$라 하고, $p(a) = 1/3, p(b) = 2/3, q(a) = 2/3, q(b) = 1/3$로 둡시다. 이때 $\text{KL}$-발산을 계산하면 다음과 같은 계산이 됩니다:

$$ \begin{align} D_{\text{KL}}(p \parallel q) &= \sum_{x} p(x) \log \frac{p(x)}{q(x)} \\ &= \frac{1}{3} \log \frac{1}{2} + \frac{2}{3} \log 2 \\ &= - \frac{1}{3} \log 2 + \frac{2}{3} \log 2. \end{align} $$

보면, $q(x)$가 $p(x)$보다 클 때, 즉 넉넉하게 예상했을때는 패널티를 깎아줍니다 (얼마나 큰 양수인가가 패널티인 상황입니다). 반면, $q(x)$가 $p(x)$보다 작은 상황, 즉 부족하게 예상을 하면 패널티가 발생합니다. 그리고 $p(x)$나 $q(x)$나 확률이라서 모든 $x$에 대해서 합을 구하면 항상 1이기 때문에, 필연적으로 이 확률값을 적절히 분배해야하는 상황이 되고, 따라서 어딘가에선 $q(x)$가 $p(x)$보다 크고 또 어딘가에선 작게 됩니다. 그리고 패널티의 가중치는 $p(x)$로 주는것이기 때문에, 우리는 지금 "부족하게 예상을 함"에 대해서 큰 패널티를 주고 있는 셈입니다. 이게 꼭 중요한가? 라고 생각할수도 있지만, 일단 높은 확률값이 나와야할때 잘 나와주는게 중요하다고 생각합시다. 그게 더 '잘 모사된 상황이다'라고 두는것이죠. 그럼 어쨌거나 이 값을 작게 하는것은 충분한 목표가 되어줄 것입니다.

🐍 이 $\text{KL}$-발산은 두 분포간의 차이를 나타내주는데, 그렇다고 거리(metric)가 되진 못합니다. $\text{KL}$-발산이 항상 0 이상이라는 것은 젠센 부등식을 사용해서 증명 할 수 있지만 (직접 한번 해보세요!), 대칭적이지 않기 때문에 거리의 조건을 만족하지 않습니다. 그러나 때로는 꼭 거리가 아니더라도 충분할때가 있는 법이죠.

위 식을 살짝 바꿔봅시다:

$$ \begin{align} D_{\text{KL}}(p \parallel q) &= \sum_{x} p(x) \log \frac{p(x)}{q(x)} \\ &= \sum_{x} p(x) \log p(x) - \sum_{x} p(x) \log q(x) \\ &= - \text{H}(X) + \text{CE}(X, Y). \end{align} $$

약간의 변형을 통해 이 $\text{KL}$-발산과 엔트로피($\text{H}$), 그리고 크로스-엔트로피($\text{CE}$)가 서로 깊은 관계가 있는 식들임을 알 수 있습니다.

머신러닝/딥러닝의 원래 목표인 '분포 따라하기'를 생각해보면, 진분포 $p$가 있고 이를 우리의 확률 모델 $q$가 따라하도록 해야합니다. 즉, 기가막히게도 $\text{KL}(p \parallel q)$가 줄어들면 되는, 딱 들어맞는 상황이죠. 그래서 이 $\text{KL}(p \parallel q)$가 작아지도록 경사 하강법(gradient descent) 등을 사용해서 파라미터들을 학습하면 됩니다.

그런데 분류 문제를 풀 때엔 상황이 조금 재밌어집니다. 보통 분류 문제를 딥러닝으로 풀려고 할 땐, 뭔가 입력 $i$가 있고, 신경망이 마구 나오고, 마지막에 가서는 분류 해야 하는 클래스 수 만큼의 차원으로 보낸 뒤, 소프트맥스 함수를 취합니다. 이러한 구조가 일반적인데, 소프트맥스의 역할은 $C$개의 실수값을 합이 1인 분포값들로 바꿔주는 것입니다. 즉, 가장 마지막 레이어의 결과값은 곧 이 인풋이 각 분류 클래스일 확률을 알려주는 것이죠!

분류 클래스가 1번부터 $C$번까지 있고, 고정된 인풋 $i$에 대해, 소프트맥스 레이어까지 포함된 우리의 신경망 모델 MLP의 분류 결과가 곧 (1부터 $C$까지의 값을 가질 수 있는) 랜덤변수 $Y_{i}$라고 해봅시다. 조금 복잡해보일 수 있는데, 수식으로 쓰면 다음과 같습니다:

$$ \text{MLP}(i) = \begin{pmatrix} P(Y_{i} = 1) \\ P(Y_{i} = 2) \\ \vdots \\ P(Y_{i} = C) \\ \end{pmatrix} . $$

다시 말하자면, 일반적인 분류 신경망 모델의 가장 마지막에 있는 소프트맥스 레이어를 통과하면 보통 $C$개의 확률값들이 나오고, 이는 현 인풋이 각 분류 클래스일 확률을 나타내주는 일종의 분포가 됩니다. 여기서 확률값이 가장 높은 번호를 우리는 $i$의 클래스로 예측하는것이죠.

그런데 잘 생각해보면, $i$의 진짜 클래스는 이미 정해져있습니다. 가령, 고양이 같은거죠. 고양이 사진을 주고 우리의 신경망 모델이 이를 고양이라고 답하도록 학습하고 있었으니, 정답은 이미 정해져있습니다. 즉, $i$의 진분포는 정답이 하나인 분포입니다!

🐍 자꾸 사족이 튀어나오네요.. 정답이 하나인 분포는 연속 확률 분포에선 디랙 델타 함수로밖에 표현 할 수 없습니다. 다행히도 분류 문제는 가산(countable) 이상의 범위를 갖는 경우가 없죠.

🐍 또 사족을 붙이자면, '정답은 이미 정해져있다'라고 했는데, 과연 그럴까요? 머신러닝/딥러닝에서 가장 자주 나오는 얘기가 바로 데이터 정제입니다. 오답으로 태깅된 데이터가 너무 많다는것이죠. 근데 그 중에는, 인간의 실수가 아니라 인간의 합의가 문제인 경우도 있습니다. 패션이 대표적인데요, 어떤 옷 스타일이 클래식한걸까요 내추럴한걸까요? 애매합니다. 이렇게 인간 사이의 '합의'를 수치로 나타내는 것이 바로 카파 상관계수(Cohen's Kappa Coefficient)이며, 보통 어떤 학습 모델의 정확도가 이 값을 넘어버리면 큰 의미가 없다고 판단을 합니다. 인간끼리도 합의가 안되었는데 기계가 답을 내버리다니 말도 안되죠 (농답입니다). 이에 대해선 별도로 글을 쓸 예정입니다.

다시 $i$의 진분포 얘기로 돌아가자면, 예컨대 $i$의 분류 클래스는 1번이라고 합시다. 통일성 있는 표기를 위해, $i$의 분류 클래스가 될 수 있는 랜덤변수를 $X_{i}$라고 합시다. 그럼 이때의 $i$의 클래스에 대한 진분포는 다음과 같습니다:

$$ \begin{align} P(X_{i} = 1) = 1 &, \\ P(X_{i} = c) = 0 & \quad (c \ne 1). \end{align} $$

분포라고 하기도 뭐하죠. 엔트로피로 치면 0이 나올겁니다. 헉 제가 방금 뭐라고 그랬죠? 엔트로피가 0이 나옵니다:

$$ \begin{align} \text{H}(X_{i}) &= - \sum_{c = 1}^{C} P(X_{i} = c) \log P(X_{i} = c) \\ &= - P(X_{i} = 1) \log P(X_{i} = 1) \\ &= -1 \cdot \log 1 \\ &= 0. \end{align} $$

위에서 분명 $\text{KL}$-발산과 엔트로피 관련된 식을 유도한 적이 있었는데...

$$ D_{\text{KL}}(p \parallel q) = - \text{H}(X) + \text{CE}(X, Y). $$

그렇습니다. 지금 상황에선, $X_{i}$가 진짜 클래스값이고 $Y_{i}$는 가짜 클래스값, 즉 예측값이죠. 그렇기 때문에 우리가 원하는 상황은, 두 랜덤변수의 클래스 분포 차이 즉 $\text{KL}$-발산이 줄어들어야하는 상황인데, 진분포 역할인 $X_{i}$의 클래스 분포의 엔트로피가 0이 나와버립니다. 진분포는 정답이 한개밖에 없는 상황이었기 때문이죠!

그럼 이 상황에선, 저 엔트로피 항을 없애도 됩니다:

$$ D_{\text{KL}}(p \parallel q) = \text{CE}(X, Y). $$

깔끔하게도, 그토록 많이 나오던 크로스-엔트로피 식으로 귀결되었습니다.

저 크로스-엔트로피 식도 한번 전개를 해볼까요? 고맙게도 $X_{i}$는 1번 클래스를 제외하고는 전부 확률값이 0이기 때문에, 클래스가 1번인 경우를 제외하고는 항이 모두 사라집니다:

$$ \begin{align} \text{CE}(X_{i}, Y_{i}) &= - \sum_{c = 1}^{C} P(X_{i} = c) \log P(Y_{i} = c) \\ &= - P(X_{i} = 1) \log P(Y_{i} = 1) \\ &= - \log P(Y_{i} = 1). \end{align} $$

신경망의 손실 함수를 계산하다보면 자주 보게 되는 식이죠. "정답 클래스의 확률값만 냅두고 나머진 버린 후, 해당 확률값은 로그를 씌워서 음수로 만든다!"를 코딩해본 기억이 있을것입니다. PyTorch의 NLLLoss가 해주는것이 바로 이 식이죠. 방금 말한대로, 이 식은 음수를 빼고 말하자면 로그-가능도(log-likelihood)가 됩니다. 이 가능도를 최대화 하는것은 어디선가 본적이 있지 않나요? 베이지안 통계를 한다는 것은.

한번 정리를 하고 가자면, 우리의 목표는 자연에 있는 어떠한 진분포를, 신경망이 표현해줄 가짜분포를 사용해서 모사 하고 싶고, 얼마나 모사를 잘 했는지는 $\text{KL}$-발산이 얼마나 낮은가로 표현할 수 있으며, 분류 문제에선 진분포가 특이한 모양이라서 크로스-엔트로피를 사용해도 동일한 값이 나온다, 가 됩니다. 여기까지 이해 하셨나요?

그런데 왜 꼭 크로스-엔트로피의 형태를 고집한걸까요? 이를 알아보기 위해선 우선 젠센 부등식(Jensen's Inequallity)이 필요한데, 이 부등식에 대한 자세한 내용은 여기서는 다루지 않겠습니다. 이미 긴 글이 너무나도 길어질것 같아서...

젠센 부등식은 보통 (1) 두 점 사이의 점 버전, (2) 2번 미분이 가능한 함수 버전 등이 있지만 조금 더 일반화된 버전으로는 (3) 기댓값을 사용하는 버전이 있습니다. 볼록함수 $f$와 확률질량함수 $p$에 대해 다음이 성립합니다:

$$ f(\mathbb{E}_{x \sim p(x)} \left[ x \right]) \leq \mathbb{E}_{x \sim p(x)} \left[ f(x) \right]. $$

좀 더 익숙한 버전은 역시 두 점 사이의 평균 버전이죠:

$$ f\left( \frac{a + b}{2} \right) \leq \frac{f(a) + f(b)}{2}. $$

위 기댓값 버전은 이 평균 버전을 조금 일반화 시킨 버전에 불과합니다.

🐍 그럼 볼록함수라는건 뭘까요? 대충 봐서는 그래프로 그렸을때, 아래에서 콕 찌르면 볼록하게 반응하는게 볼록함수 입니다. 이런 느낌으로, 평균점의 함숫값보다 함숫값의 평균이 더 큰걸 볼록함수라고 합니다. 뭔가 정의가 재귀적이죠? 사실 젠센 부등식을 만족하는걸 볼록함수로 정의하기도 합니다. 수학에서의 정의는 이처럼 때때로 왔다갔다 합니다. 본질만 바뀌지 않는다면요! 이에 대해선 나름의 역사가 있지만, '볼록하다'라는것의 본질을 추상화 하다보니 껍데기는 가고 저런 부등식만 남았다고 이해하셔도 됩니다.

🐍 정보이론(information theory)은 부등식의 학문이다라고 해도 과언이 아닐겁니다. 그런데 이런 부등식을 공부할때 가장 중요한건, 어떤 상황에서 부등식이 등식이 되는가? 하는 물음입니다. 젠센 부등식은 어떤 상황에서 등식이 될까요? 저 일반화된 버전에서 한번 고민해봅시다. 등식이 되는 상황이 뜻하는 바가 무엇인지, 그 함의가 핵심입니다.

크로스-엔트로피 얘기가 더 중요하니까 젠센 부등식 얘기는 여기까지만 하도록 하겠습니다. 중요한건, 우리가 다룰 함수 $f(x) = -\log(x)$는 볼록함수라는 것입니다.

크로스-엔트로피 식을 다시 써봅시다. 이번엔 이 $-\log(x)$가 볼록이라는 것에 집중을 해서 식을 변형 할 건데요, 이런 사고 방식을 일명 wishful thinking이라고 하죠 (정말 사족이 많네요 제가 생각해도). 함수 $f(x) = -\log(x)$로 나타내봅시다. 편의상 $X_{i}$의 확률질량함수를 $p$, $Y_{i}$의 확률질량함수를 $q$라고 하면:

$$ \begin{align} \text{CE}(X_{i}, Y_{i}) &= - \sum_{c = 1}^{C} P(X_{i} = c) \log P(Y_{i} = c) \\ &= \sum_{c = 1}^{C} p(c) \left( - \log q(c) \right) \\ &= \sum_{c = 1}^{C} p(c) f \left( q(c) \right) \\ &= \mathbb{E}_{c \sim p(c)} \left[ f \left( q(c) \right) \right] \\ &\geq f \left( \mathbb{E}_{c \sim p(c)} \left[ q(c) \right] \right) \quad (\text{Jensen}) \\ &= f \left( \sum_{c = 1}^{C} p(c)q(c) \right). \end{align} $$

뭔가 많은 일을 했습니다. 중간에 젠센 부등식을 사용한 부분이 보이시나요? 분포 $p$를 사용해서 기댓값을 걸어준것입니다. 젠센 부등식을 사용 할 수 있도록, 식을 바라보는 시선을 약간 바꾼것이죠!

맨 아래에 저 함수 $f$ 안에 들어있는 부분은 과연 무엇일까요? 이게 핵심인데, 이 식을 잘 읽어보면 다음과 같은 뜻이 됩니다: 분포 $p$와 $q$가 같은 사건일때의 확률값 (곱셈공식 → 덧셈공식), 즉, $X_{i}$와 $Y_{i}$가 같은 값을 가질 확률.

엄청나지 않나요? 즉, 다음과 같다는 것입니다:

$$ f \left( \sum_{c = 1}^{C} p(c)q(c) \right) = f \left( P(X_{i} = Y_{i}) \right). $$

함수 $f$가 계속 나오는것이 조금 지저분하니까, 크로스-엔트로피를 이용해서 식을 정리 해봅시다:

$$ \begin{align} & \text{CE}(X_{i}, Y_{i}) \geq f \left( P(X_{i} = Y_{i}) \right) \\ \Leftrightarrow \quad & \text{CE}(X_{i}, Y_{i}) \geq - \log P(X_{i} = Y_{i}) \\ \Leftrightarrow \quad & - \text{CE}(X_{i}, Y_{i}) \leq \log P(X_{i} = Y_{i}) \\ \Leftrightarrow \quad & e^{- \text{CE}(X_{i}, Y_{i})} \leq P(X_{i} = Y_{i}). \end{align} $$

마지막에 엄청난 식이 나와버렸습니다. 바로 두 랜덤변수가 같을 확률의 최솟값을 구해버린것이죠! 랜덤변수 $X_{i}$는 진짜 클래스 번호, $Y_{i}$는 우리의 모델이 예측한 클래스 번호입니다. 그 두 번호가 같을 확률, 즉 정답을 맞췄을 확률의 하한(lower bound)을 알아낸 것이죠. 정답률에 하한이 있다는 것이고, 그 하한은 바로 크로스-엔트로피에 의해 결정된다는 것입니다. 지수 함수의 모양을 생각해보면, 저 크로스-엔트로피 값이 낮을수록 좌변이 커집니다. 그리고 크로스-엔트로피가 0이 되는 순간, 좌변이 1이 되죠. 이 때 두 랜덤변수가 같은 값을 가질 확률 역시 1이 됩니다. 진분포를 완벽하게 모사해낸 순간이 됩니다.

놀랍지 않나요? 크로스-엔트로피를 낮춘다는 목표는 추상적인 손실 함수가 아닙니다. 정말로 분류 문제에서 클래스 번호를 정확하게 맞출 확률의 하한을 담당하는, 굉장히 중요한 값입니다. 하한이 있다는 것은, 정답률을 어느정도 보장해준다는 것입니다. 이제 신경망 학습을 할 때 손실 함수 값이 1이 나온다면, 현재 정답률은 적어도 $1/e$, 즉 대략 36%는 되겠구나 하는 것을 알 수 있겠죠? 왜 손실 함수 값의 '단위'가 중요한지도요.

참고로 머신러닝에선 이처럼 확률값의 하한을 학습의 목표로 삼는 경우가 종종 있습니다. 더 궁금하다면 ELBO(Evidence Lower Bound)를 한번 검색해보세요!