본문 바로가기
AI/Deep-Learning

[Keras] 다중 클래스 분류 손실 함수

by Wikinist 2023. 9. 12.

sparse_categorical_crossentropy와 categorical_crossentropy는 딥러닝 모델을 훈련시킬 때 주로 사용되는 손실 함수(loss function) 중 두 가지입니다. 이 두 손실 함수는 주로 다중 클래스 분류 문제를 해결할 때 쓰이며, 모델의 출력과 실제 레이블 간의 차이를 측정하여 모델을 훈련시키는 데 사용됩니다. 둘 다 크로스 엔트로피(cross-entropy)를 기반으로 합니다.

categorical_crossentropy

categorical_crossentropy는 주로 원-핫 인코딩(One-Hot Encoding) 형식의 레이블을 사용하는 다중 클래스 분류 문제에 적합합니다.
원-핫 인코딩은 각 클래스에 대해 하나의 레이블을 선택하고, 선택한 클래스의 인덱스에 1을 할당하고 나머지 클래스에는 0을 할당하는 방식입니다.
이 손실 함수는 모델의 출력 확률 분포와 실제 원-핫 인코딩된 레이블 간의 차이를 측정합니다.
모델의 출력은 클래스에 속할 확률 분포로, 예를 들어 10개의 클래스가 있는 경우, 출력 벡터의 크기는 10이 됩니다.

sparse_categorical_crossentropy

sparse_categorical_crossentropy는 정수 형태의 레이블을 사용하는 다중 클래스 분류 문제에 주로 사용됩니다.
정수 레이블은 클래스의 인덱스를 직접 나타냅니다. 예를 들어, 0부터 9까지의 클래스를 나타내는 정수 레이블을 사용할 때 유용합니다.
이 손실 함수는 모델의 출력 확률 분포와 실제 정수 형태의 레이블 간의 차이를 측정합니다.
모델의 출력은 여전히 클래스에 속할 확률 분포로, 출력 벡터의 크기는 클래스의 개수와 같습니다.
예를 들어, 이미지 분류 문제에서 categorical_crossentropy는 원-핫 인코딩된 레이블을 사용하고, sparse_categorical_crossentropy는 정수 형태의 레이블을 사용합니다. 선택하는 것은 데이터의 레이블 표현 방식에 따라 달라집니다.

categorical_crossentropy 사용 예제

이미지 분류:
예를 들어, 10개의 다른 종류의 과일을 분류하는 이미지 분류 모델을 만든다고 가정합니다.
각 이미지에 대한 레이블은 원-핫 인코딩으로 표현됩니다. 예를 들어, 사과는 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], 바나나는 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]과 같이 표현됩니다.
이 모델을 훈련할 때 categorical_crossentropy 손실 함수를 사용하여 모델의 출력과 원-핫 인코딩된 레이블 간의 차이를 최소화합니다.

from keras.models import Sequential
from keras.layers import Dense
from keras.losses import categorical_crossentropy

model = Sequential()
model.add(Dense(10, activation='softmax', input_shape=(input_dim,)))
model.compile(optimizer='adam', loss=categorical_crossentropy, metrics=['accuracy'])

sparse_categorical_crossentropy 사용 예제

텍스트 분류: 
문서를 여러 범주 또는 토픽으로 분류하는 텍스트 분류 모델을 만든다고 가정합니다.
각 문서에는 해당하는 범주 또는 토픽의 정수 레이블이 할당됩니다.
이 모델을 훈련할 때 sparse_categorical_crossentropy 손실 함수를 사용하여 모델의 출력과 정수 레이블 간의 차이를 최소화합니다.

from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense
from keras.losses import sparse_categorical_crossentropy

model = Sequential()
model.add(Embedding(input_dim=vocab_size, output_dim=embedding_dim, input_length=max_seq_length))
model.add(LSTM(64))
model.add(Dense(num_classes, activation='softmax'))
model.compile(optimizer='adam', loss=sparse_categorical_crossentropy, metrics=['accuracy'])

위 예제에서 categorical_crossentropy는 원-핫 인코딩된 레이블에 사용되고, sparse_categorical_crossentropy는 정수 형태의 레이블에 사용됩니다. 데이터의 레이블 표현 방식에 따라 적절한 손실 함수를 선택합니다.

해당 게시글은 ChatGPT의 도움을 받아 작성되었습니다.