multi-class classification

다중 클래스 분류

  • 이진(Binary Class) 분류 : 종속변수의 클래스가 2개인 분류 문제

  • 다중 클래스(Multi-Class) 분류 : 종속변수의 클래스가 3 개 이상인 분류문제

    • OvO 혹은 OvR 방법을 통해 여러 개의 이진 클래스 분류문제로 변환해서 푼다

OvO (One-vs-One)

: K개의 타겟 클래스가 존재할 때, 그 중 2개씩 선택해 이진 클래스 분류 문제를 풀고, 그 결과로 가장 많은 판별값을 얻은 클래스를 선택하는 방법

  • 풀어야 하는 이진 클래스분류 문제의 수 : $ _KC_2 $

두 개의 클래스씩 비교했을 때 선택받은 횟수의 총합으로 비교하면 여러 클래스가 동점이 나오는 tie case가 발생할 수 있기 때문에 각 클래스가 얻은 조건부 확률값을 모두 더한 값으로 비교하면 그 문제가 해결된다.

OneVsOneClassifier 클래스를 사용하면 이진 클래스용 모형을 OvO 방법으로 다중 클래스용 모형으로 변환한다.

1
2
3
4
5
from sklearn.multiclass import OneVsOneClassifier
from sklearn.linear_model import LogisticRegression

model_ovo = OneVsOneClassifier(LogisticRegression()).fit(iris.data, iris.target)
#2진 모형 인스턴스를 만들고 OvO로 wrapping하면 내부에서 세 번 경합 실시

각 클래스가 얻는 조건부 확률값을 합한 값을 decision_function으로 출력한다.

1
2
3
4
5
6
7
8
ax1 = plt.subplot(211)
pd.DataFrame(model_ovo.decision_function(iris.data)).plot(ax=ax1, legend=False)
plt.title("판별 함수")
ax2 = plt.subplot(212)
pd.DataFrame(model_ovo.predict(iris.data), columns=["prediction"]).plot(marker='o', ls="", ax=ax2)
plt.title("클래스 판별")
plt.tight_layout()
plt.show()

​ 0(파란색), 1(주황색), 2(초록색)의 총 3개 클래스로 데이터가 판별되었고, 총 세 개의 데이터가 잘못 예측되었음이 확인 가능하다.

OvR (One-vs-the-Rest)

: 클래스 개수가 K개이면 풀어야할 이진분류 문제가 K의 제곱에 비례하여 많아지는 OvO와 달리, OvR은 K개의 문제를 풀면 되기 때문에 훨씬 빠르고 효율적이다.

클래스 A, B, C가 있을 때,

  • A vs B,C 즉, A vs A$^C$
  • B vs A,C 즉, B vs B$^C$
  • C vs A,B 즉, C vs C$^C$

이렇게 3 번 이진문제를 푸는데, OvR에서도 판별 결과의 수가 같은 동점 문제가 발생할 수가 있기 때문에 각 클래스가 얻은 조건부 확률값을 더하여 그 값이 +가 나오면 해당 클래스고 -가 나오면 해당 클래스가 아니라고 판단한다. 결과적으로는 +값이 나온 클래스들 중 가장 그 값이 큰 클래스를 정답으로 예측한다.

OneVsRestClassifier 클래스를 사용하면 이진 클래스용 모형을 OvR 방법으로 다중 클래스용 모형으로 변환한다.

1
2
3
4
5
6
7
8
9
10
11
12
13
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression

model_ovr = OneVsRestClassifier(LogisticRegression()).fit(iris.data, iris.target)

ax1 = plt.subplot(211)
pd.DataFrame(model_ovr.decision_function(iris.data)).plot(ax=ax1, legend=False)
plt.title("판별 함수")
ax2 = plt.subplot(212)
pd.DataFrame(model_ovr.predict(iris.data), columns=["prediction"]).plot(marker='o', ls="", ax=ax2)
plt.title("클래스 판별")
plt.tight_layout()
plt.show()

​ 그래프를 보면 클래스 판별 예측에 실패한 데이터의 수가 약 6개로 OvO보다 예측 성능이 조금 떨어지는 것을 볼 수 있다. 하지만 현실적으로 클래스가 많아지면 OvO는 아예 쓸 수가 없기 때문에 OvR을 쓰는 것이 일반적이다.

Share