[Python] 의사결정나무(DecisionTree) 구현 - 분류(Classifier)/회귀(Regressor)/가지치기(Pruning)

2020. 7. 18. 21:35ML in Python/Python

 

 의사결정 나무는 간단하게 말해서 if~else와 같이 특정 조건을 기준으로 O/X로 나누어 분류/회귀를 진행하는 tree구조의 분류/회귀 데이터마이닝 기법이다. 

 이해도가 매우 높고 직관적이라는 장점이 있다. 그렇기에 많이 사용되며, 의사결정나무도 많은 머신러닝 기법과 동일하게 종속변수의 형태에 따라 분류와 회귀 문제로 나뉜다.

 종속변수가 범주형일 경우 Decision Tree Classification으로 분류를 진행하고, 종속변수가 연속형일 경우 Decision Tree Regression으로 회귀를 진행한다. 

 

 상세한 원리와 수학적/직관적 이해는 아래 링크를 통해서 학습하길 바란다. 

https://todayisbetterthanyesterday.tistory.com/39

 

[Data Analysis 개념] Decision Tree(의사결정나무) 모형 - Classification/Regression Tree의 직관적/수학적 이해

 이 게시글에서는 Decision Tree의 개념만 다룰 것이다. Python으로 구현하고자 한다면 아래 실습링크를 통해서 학습하길 바란다. https://todayisbetterthanyesterday.tistory.com/38 [Python] 의사결정나무(Dec..

todayisbetterthanyesterday.tistory.com


 

1. 기본적인 의사결정나무의 형태 

# sklearn 모듈의 tree import
from sklearn import tree

# 간단한 데이터셋 생성
X = [[0, 0], [1, 1]]               
Y = [0, 1]

# 의사결정나무 적합 및 학습데이터 예측
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)
clf.predict([[1, 1]])

 위의 코드블럭은 가장 간단한 의사결정나무의 예시이다. 학습데이터에 2개의 설명변수만 사용하여 2행짜리 데이터를 생성하였다. 종속변수 또한 설명변수가 2행이기에 2개밖에 안된다. 

 이제 sklearn.datasets에 존재하는 iris 데이터를 통해서 구현하고 시각화 작업을 함께 해보자.


2. 라이브러리 import & 실습 데이터 로드

# sklearn 모듈의 tree import
from sklearn import tree
from sklearn.datasets import load_iris
from os import system                 # graphviz 라이브러리 설치를 위함

# graphviz 라이브러리 설치 // 아래 예제에서 오류나는 경우 anaconda prompt에서 설치바람
system("pip install graphviz")

# graphviz 사용에 있어서 error발생원인이 환경변수일 경우 환경변수 추가 필요
# 환경변수 추가 후 환경변수 설정 아래코드
# os.environ["PATH"] += os.pathsep + 'C:\\Program Files (x86)\\Graphviz2.38\\bin\\' 

# iris 실습데이터 로드
iris = load_iris()

 iris 데이터는 이전 게시글의 예제에서 계속적으로 다루어왔다. 4개의 feature 변수가 있으며, 3개의 target 변수가 있다. 

자세하게 알아보는 과정은 그렇기에 생략하겠다. 


2. Decision Tree Classifier ( 의사결정분류나무 )

 

기본적인 의사결정 나무 : Information Gain - Gini

# 의사결정나무 분류 
clf = tree.DecisionTreeClassifier()               # 종속변수가 현재 범주형
clf = clf.fit(iris.data, iris.target)             # feature, target

 이것은 의사결정나무 분류모형을 적합시킨 것이다. tree를 시각화시켜서 자세하게 살펴보자

# 시각화
dot_data = tree.export_graphviz(clf,   # 의사결정나무 모형 대입
                               out_file = None,  # file로 변환할 것인가
                               feature_names = iris.feature_names,  # feature 이름
                               class_names = iris.target_names,  # target 이름
                               filled = True,           # 그림에 색상을 넣을것인가
                               rounded = True,          # 반올림을 진행할 것인가
                               special_characters = True)   # 특수문자를 사용하나

graph = graphviz.Source(dot_data)              
graph

 graphviz는 tree를 도식화하는 라이브러리이다. 각 매개변수에 대한 설명을 적어놓았으니 읽어보길 바란다.

 의사결정나무는 맨 위의 1개 root노드부터 맨 아래 여러가지 노드들 즉, leaf노드들로 구성된다. 그리고 Decision Tree를 이용할 때, 가장 기본적인(아무런 매개변수를 주지 않았을 때) Information Gain 방식은 지니계수를 이용한다. 도식화된 Tree를 살펴보면 gini = xxx 가 써있음을 통해 알 수 있다. 이 gini계수는 엔트로피와 마찬가지로 낮을 수록 분류가 잘 된것으로 판단하며 기본적으로 의사결정나무는 이 Information Gain을 낮추는 방향으로 분류를 진행한다. 


Information Gain - entropy 의사결정나무

# 의사결정나무 분류 
clf2 = tree.DecisionTreeClassifier(criterion = "entropy")  # Information Gain - entropy
clf2 = clf2.fit(iris.data, iris.target)                    # feature, target
# 시각화
dot_data2 = tree.export_graphviz(clf2,        # 의사결정나무 모형 대입
                               out_file = None,     # file로 변환할 것인가
                               feature_names = iris.feature_names, # feature 이름
                               class_names = iris.target_names,   # target 이름
                               filled = True,        # 그림에 색상을 넣을것인가
                               rounded = True,       # 반올림을 진행할 것인가
                               special_characters = True)    # 특수문자를 사용하나

graph2 = graphviz.Source(dot_data2)              
graph2

 과정은 지니와 모두 동일하나 DecisionTreeClassifier()을 생성할 때 매개변수로 criterion = "entropy"만 추가하였다. 

 즉, 의사결정나무의 분류 기준을 entropy로 한다는 것이다. 이제 위의 결과표를 보면 gini가 아니라 entropy가 쓰인 것을 확인할 수 있다.

 

 위의 두 의사결정나무 모형은 너무 많은 노드들이 존재한다. 게다가 마지막 노드에서 gini와 entropy 모두 0.0을 출력한다. 이는 완벽하게 분리시켰다고 말할 수 도 있지만, 사실 억지로 분류시킨 것에 가깝다. 그렇기에 과적합(Overfitting)이 발생한 것이다. 

 추가적으로 한 가지 더 알아야 할 것이 있다. 위의 색은 3가지 색의 계열로 이루어져 있다. 같은 색 계열이면 같은 집단으로 분류를 한 것이며, 색이 진할수록 Information Gain(entropy, gini .. )이 낮은 것이다. 즉, 정확하게 분류를 했다는 것이다. 이는 상대적이기에 depth가 작으면 entropy가 높아도 진하게 출력될 수 있다.  

이제 pruning(가지치기)라는 기법을 배워서 과적합을 방지하도록 학습해보자.


Pruning - 가지치기

# Pruning
clf3 = tree.DecisionTreeClassifier(criterion = "entropy", max_depth = 2)
clf3.fit(iris.data, iris.target)

 이번 실습에서는 가지치기를 최대 깊이를 제한시켜서 실습하고자 한다. 의사결정나무에서 깊이란 맨 마지막 leaf노드들이 root노드까지 바로 가는데 걸리는 조건(edge)의 개수이다. 이번 실습에서는 깊이제한을 2로 하였다. 

 사실 이 실습의 가지치기의 기준은 올바른 방식이 아니다. 가지치기를 진행하는 방법은 여러 기준이 있는데

1) 지니계수/엔트로피와 같은 Information Gain의 값이 일정 수준 이하로 안내려가도록

2) 가지의 개수 자체를 제한하는 방법

3) 이 실습과 같이 깊이를 제한하는 방법

등이 있다.

 실제 통계용 언어로 많이 사용되는 R의 함수에서는 내부적으로 1)을 활용하여 가지치기를 진행한다. 그리고 이번 실습과 같이 3)을 활용하여 가지치기를 진행할 경우 Cross Validation등을 통해서 보다 정확한 깊이를 찾아낼 수 있다. 그렇기에 일반적으로 X라고 단정짓는 것은 좋은 방법은 아니다.  

 여튼 시각화를 통해서 확인해보자.

# 시각화
dot_data3 = tree.export_graphviz(clf3,               # 의사나무 모형 대입
                               out_file = None,        # file로 변환할 것인가
                               feature_names = iris.feature_names,  # feature 이름
                               class_names = iris.target_names,   # target 이름
                               filled = True,          # 그림에 색상을 넣을것인가
                               rounded = True,         # 반올림을 진행할 것인가
                               special_characters = True)  # 특수문자를 사용하나

graph3 = graphviz.Source(dot_data3)              
graph3

 

 가지치기의 기준으로 max_depth를 2로 주었더니, 트리의 깊이가 2로 변했다. 그리고 entropy 또한 0.4/0.151로 많이 높아졌다. 위의 DecisionTree의 gini/entropy는 0.0이었는데 분류가 너무 안된 것이 아닌가? 라고 생각할 수 있다.

 하지만, 우리는 여태 train데이터를 예측했다. 그렇기에 학습데이터의 경우 가지가 무한정 많아지면 정확해질 수 밖에 없다. 만약 새로운 test데이터가 주어진다면, 오히려 과적합된 DecisionTree가 학습데이터 내에서 너무 이상값들에 집중해서 일반적인 새로운 test데이터를 제대로 예측하지 못할 수도 있다. 

 게다가 이 짧은 트리가 거창한 트리보다 훨씬 직관적이고 이해도가 높다. Decision Tree를 사용하는 가장 큰 이유중 하나가 바로 "직관적인 이해"인데 가지치기를 하지않고 무한한 가지를 만들면 Decision Tree를 사용하는 의미 또한 퇴색된다. 

 

이제 Confusion matrix를 활용하여 3가지 분류기의 학습데이터를 분류하는 정확도를 확인해보자.


Confusion matrix를 활용한 정확도 비교

from sklearn.metrics import confusion_matrix

# 1번 의사결정나무 - 지니계수 활용
confusion_matrix(iris.target, clf.predict(iris.data))

# 2번 의사결정나무 - entropy 활용
confusion_matrix(iris.target, clf2.predict(iris.data))

# 3번 의사결정나무 - 가지치기 작업
confusion_matrix(iris.target, clf3.predict(iris.data))

 위의 결과를 보면 가지치기를 한 의사결정나무의 정확도가 가장 떨어진다. 하지만 학습데이터를 분류한 것이라는 사실을 염두해 두어야 한다. 만약 새로운 데이터가 들어오면 말했듯이 맨 마지막 가지치기의 의사결정나무가 일반화된 특징을  잡을 가능성이 높다.


Traing Set / Test Set 구분

 

 여태까진 전부 학습데이터로 분류를 진행했지만 실제 데이터가 주어졌을 때, 데이터는 Train/(Validation)/Test로 나누어 학습할 가능성이 크다. 그렇기에 Train set과 Test set을 나누어 실습해보자.

 

# 데이터셋 분리 함수
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(iris.data     #feature 
                                                    , iris.target     #target
                                                    , stratify = iris.target #층화추출법
                                                    , random_state = 1)  #난수고정
												

 위의 train_test_split의 매개변수중 stratify매개변수가 들어간 것을 볼 수 있다. 이것은 필수적으로 들어가야할 요소는 아니다. iris데이터셋의 경우 150개의 데이터밖에 없기에 무작위 추출이 진행된다면 target데이터가 치우쳐질 수도 있다. 

 그리고 이러한 사용은 제약/임상실험에서도 마찬가지다.(병에 안걸린사람이 걸린사람보다 훨씬 많으니 과소평가될 가능성이 있음). 그렇기에 데이터가 적은 이유로 고루고루 데이터를 추출시키기 위해 층화추출법을 사용한 것이다. 

# train dataset
clf4 = tree.DecisionTreeClassifier(criterion = "entropy")
clf4.fit(X_train, y_train)

# test set predict confusion matrix
confusion_matrix(y_test,clf4.predict(X_test))

 위의 confusion matrix결과를 보면 가지치기를 하지 않았는데도 불구하고 3번에서 한 경우에 오분류가 발생했다. 이는 train data set과 test data set의 특성이 어느정도 달라서 학습의 분류결과가 완전하게 맞을 순 없다는 것을 보여준다. 


3. Decision regression Tree ( 의사결정회귀나무 )

 

 의사결정 회귀나무는 종속변수가 연속형 변수일때 진행한다. 기본적인 방식은 의사결정 분류나무와 동일하나 사용하는 함수가 다르다. 실습을 통해서 알아보자.

# 필요 라이브러리 
import numpy as np
from sklearn.tree import DecisionTreeRegressor    # 회귀나무 함수
import matplotlib.pyplot as plt

# 실습용 데이터셋 만들기
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()                        # sin함수의 예측을 목표로한다
y[::5] += 3 * (0.5 - rng.rand(16))           # 이상치를 발생시킨다.

 위의 코드를 통해 필요한 라이브러리를 로드하고, 실습에서 사용할 학습용 데이터셋을 만들었다. 기본적인 target변수의 형태는 sin함수를 따르도록 만들었으나, 이상치를 주었다. 

# X_test set 생성
X_test = np.arange(0.0,5.0,0.01)[:,np.newaxis]

 test를 진행하기 위해 X_test셋을 만들었다. 이를 가지고 예측을 하는 작업을 진행할 것이다. 

 


Regression  Tree 구축

# 깊이가 다른 두 Regression 나무 생성
regr1 = tree.DecisionTreeRegressor(max_depth = 2)
regr2 = tree.DecisionTreeRegressor(max_depth = 5)

# 두 가지 회귀나무 적합
regr1.fit(X,y)
regr2.fit(X,y)
# 예측
y_1 = regr1.predict(X_test)
y_2 = regr2.predict(X_test)
# 예측 결과물
y_1

# depth가 다른 두 회귀나무 도식화

plt.figure()
plt.scatter(X, y, s=20, edgecolor="black",
            c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue",
         label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()

 위의 그림을 확인해보면 max_depth = 5인 의사결정회귀나무는 이상값에 영향을 더 크게 받았음을 확인할 수 있다. 오히려 max_depth = 2의 의사결정회귀나무가 이상값을 무시하고 전체적인 추세를 더 잘 잡는 것을 확인할 수 있다. 

 하지만 만약 sin함수에서 떨어져있는 점들이 이상값이 아니었다면, 저런 점들 또한 고려할 필요가 생긴다. 그렇기에 가지치기의 적절한 기준을 찾는 것 또한 분석가의 안목에 달려있다. 

 

# depth = 5 의사결정 회귀나무 시각화

dot_data4 = tree.export_graphviz(regr2, out_file=None, 
                                filled=True, rounded=True,  
                                special_characters=True)
graph4 = graphviz.Source(dot_data4) 
graph4

 위를 보면 depth가 5인 의사결정나무의 부분이다. 이미지가 너무 커서 짤렸으나 한 가지 확인하고 가야할 것이 있다. 

위에서 보면 value가 낮을 수록 같은 색계열에서 연한 색을 띈다. 그리고 value값이 높을 수록 진한 색을 띈다.

 

# depth =2 의사결정 회귀나무 시각화

dot_data5 = tree.export_graphviz(regr1, out_file=None, 
                                filled=True, rounded=True,  
                                special_characters=True)
graph5 = graphviz.Source(dot_data5) 
graph5

 위의 경우는 depth = 2의 의사결정 회귀나무이다. 회귀나무에서 보았듯이 분류나무와 분류하는 기준이 다르다. 분류나무에서는 Information Gain으로 entropy/gini지수를 사용했다면, 회귀나무에서는 회귀에 많이 사용하는 mse(mean squared error)가 기본적인 기준으로 작동한다. 

 즉, 회귀나무는 mse를 낮추는 방향으로 가지를 뻗어나아간다는 것이다. 

 

 이상 파이썬을 활용하여 의사결정분류나무(DecisionTreeClassifier)와 의사결정회귀나무(DecisionTreeRegression)에 대해서 알아보았고 이를 데이터 셋을 나누며 실습을 진행하고 시각화 또한 진행해 보았다.