11wk-40: Medical Cost / Visualizing Decision Trees

Author

최규빈

Published

November 16, 2023

1. Lecture Video

2. Imports

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.tree
import graphviz
#---#
import warnings
warnings.filterwarnings('ignore')

3. Data Preparation + Fitting

df_train = pd.read_csv('https://raw.githubusercontent.com/guebin/MP2023/main/posts/insurance.csv')
df_train
      age     sex     bmi  children  smoker     region      charges
0      19  female  27.900         0     yes  southwest  16884.92400
1      18    male  33.770         1      no  southeast   1725.55230
2      28    male  33.000         3      no  southeast   4449.46200
3      33    male  22.705         0      no  northwest  21984.47061
4      32    male  28.880         0      no  northwest   3866.85520
...   ...     ...     ...       ...     ...        ...          ...
1333   50    male  30.970         3      no  northwest  10600.54830
1334   18  female  31.920         0      no  northeast   2205.98080
1335   18  female  36.850         0      no  southeast   1629.83350
1336   21  female  25.800         0      no  southwest   2007.94500
1337   61  female  29.070         0     yes  northwest  29141.36030

1338 rows × 7 columns

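# step1: features (age..region, text columns one-hot encoded) and target
X = pd.get_dummies(df_train.loc[:,'age':'region'])
y = df_train['charges']
# step2: regression tree limited to depth 3
predictr = sklearn.tree.DecisionTreeRegressor(max_depth=3)
# step3: fit the tree
predictr.fit(X,y)
# step4 -- pass (no prediction step needed here)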
DecisionTreeRegressor(max_depth=3)
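Step 1 relies on pd.get_dummies to turn the text columns (sex, smoker, region) into numeric indicator columns, since the tree only accepts numeric input. A minimal sketch on the first three rows of the table above shows the resulting column names:

```python
import pandas as pd

# The first three rows of insurance.csv, columns age..region only
toy = pd.DataFrame({
    "age": [19, 18, 28],
    "sex": ["female", "male", "male"],
    "bmi": [27.9, 33.77, 33.0],
    "children": [0, 1, 3],
    "smoker": ["yes", "no", "no"],
    "region": ["southwest", "southeast", "southeast"],
})

# Numeric columns are kept as-is; each text column becomes one indicator column per category
X_toy = pd.get_dummies(toy)
print(list(X_toy.columns))
# ['age', 'bmi', 'children', 'sex_female', 'sex_male', 'smoker_no', 'smoker_yes',
#  'region_southeast', 'region_southwest']
```

Only categories that actually occur get a column, which is why region yields just two columns here. In the full dataset region takes four values, so X ends up with 11 columns (3 numeric + 2 + 2 + 4).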

4. Visualization with matplotlib

A. Basic Visualization

sklearn.tree.plot_tree(predictr);

  • Hard to read: at the default figure size the node boxes and their text come out too small.
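One common remedy is a larger canvas plus an explicit font size. A minimal sketch, assuming synthetic data as a stand-in for insurance.csv (it is not the real data); figsize=(16, 8), fontsize=8 and filled=True are arbitrary illustrative choices, not settings from the lecture:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs outside Jupyter
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.tree

# Hypothetical synthetic stand-in for insurance.csv (not the real data)
rng = np.random.default_rng(0)
n = 200
X = pd.DataFrame({"age": rng.integers(18, 65, n), "bmi": rng.normal(30, 6, n)})
y = 250 * X["age"] + 300 * X["bmi"] + rng.normal(0, 1000, n)
predictr = sklearn.tree.DecisionTreeRegressor(max_depth=3).fit(X, y)

# A wide figure plus a fixed font size keeps the (up to 15) node boxes from overlapping
fig, ax = plt.subplots(figsize=(16, 8))
anns = sklearn.tree.plot_tree(
    predictr,
    feature_names=list(X.columns),
    filled=True,   # shade nodes by predicted value
    fontsize=8,
    ax=ax,
)
```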

B. Adjusting max_depth

sklearn.tree.plot_tree(
    predictr,
    max_depth=0
);
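The max_depth argument of plot_tree only limits how many levels are drawn; it does not refit or prune the model. With max_depth=0 the root is drawn in full and the nodes below it appear as (...) placeholder boxes. A sketch that checks this on synthetic stand-in data (hypothetical values, not insurance.csv):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs outside Jupyter
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.tree

# Hypothetical synthetic stand-in for insurance.csv (not the real data)
rng = np.random.default_rng(0)
n = 200
X = pd.DataFrame({"age": rng.integers(18, 65, n), "bmi": rng.normal(30, 6, n)})
y = 250 * X["age"] + 300 * X["bmi"] + rng.normal(0, 1000, n)
predictr = sklearn.tree.DecisionTreeRegressor(max_depth=3).fit(X, y)
pred_before = predictr.predict(X)

# Draw only the root; deeper nodes are collapsed into "(...)" boxes
fig, ax = plt.subplots()
anns_root = sklearn.tree.plot_tree(predictr, max_depth=0, ax=ax)
texts = [a.get_text() for a in anns_root]

# The plotting argument leaves the fitted model untouched
print(predictr.get_depth())                              # 3
print(np.array_equal(pred_before, predictr.predict(X)))  # True
```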

C. Adding Feature Names

sklearn.tree.plot_tree(
    predictr,
    max_depth=0,
    feature_names=X.columns
);
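Without feature_names, plot_tree labels splits as x[0], x[1], and so on, because the fitted tree stores split variables as column indices. The sketch below reads the root split from the fitted tree_ attribute (tree_.feature and tree_.threshold) and maps the index back through X.columns, which is essentially the lookup feature_names enables in the plot. The data is a synthetic stand-in in which smoking dominates the cost, an assumption made purely for illustration:

```python
import numpy as np
import pandas as pd
import sklearn.tree

# Hypothetical synthetic stand-in: smoking adds a large fixed cost (not the real data)
rng = np.random.default_rng(0)
n = 200
df = pd.DataFrame({"age": rng.integers(18, 65, n), "smoker": rng.choice(["yes", "no"], n)})
X = pd.get_dummies(df)  # columns: age, smoker_no, smoker_yes
y = 250 * df["age"] + 20000 * (df["smoker"] == "yes") + rng.normal(0, 1000, n)
predictr = sklearn.tree.DecisionTreeRegressor(max_depth=3).fit(X, y)

# Node 0 is the root; the tree stores the split variable as a column index
root_idx = predictr.tree_.feature[0]
root_thr = predictr.tree_.threshold[0]

# Mapping the index through X.columns recovers the readable label
root_name = X.columns[root_idx]
print(f"{root_name} <= {root_thr:.3f}")
```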

D. The fig Object

- Use plt.gcf() to grab the current fig object

sklearn.tree.plot_tree(
    predictr,
    max_depth=1,
    feature_names=X.columns
);
fig = plt.gcf()
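When no ax is passed, plot_tree draws on the current axes (plt.gca()), which is why plt.gcf() called right afterwards returns the figure holding the tree. A sketch that checks this, plus the explicit alternative of creating the figure first and passing ax=; the data is a synthetic stand-in:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs outside Jupyter
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.tree

# Hypothetical synthetic stand-in for insurance.csv (not the real data)
rng = np.random.default_rng(0)
n = 200
X = pd.DataFrame({"age": rng.integers(18, 65, n), "bmi": rng.normal(30, 6, n)})
y = 250 * X["age"] + 300 * X["bmi"] + rng.normal(0, 1000, n)
predictr = sklearn.tree.DecisionTreeRegressor(max_depth=3).fit(X, y)

# Implicit style: plot_tree uses the current axes, so gcf() returns its figure
plt.figure()
anns = sklearn.tree.plot_tree(predictr, max_depth=1)
fig = plt.gcf()
print(anns[0].axes.figure is fig)  # True

# Explicit style: create the figure and axes first, then pass ax=
fig2, ax2 = plt.subplots()
anns2 = sklearn.tree.plot_tree(predictr, max_depth=1, ax=ax2)
print(anns2[0].axes is ax2)  # True
```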

- Could we add a title with fig.suptitle?

fig.suptitle("title??")
Text(0.5, 0.98, 'title??')
fig
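The echoed Text(0.5, 0.98, 'title??') is the return value of fig.suptitle: a Text artist at x=0.5, y=0.98 in figure-fraction coordinates, i.e. centered horizontally just below the top edge. Since it is an ordinary matplotlib artist, it can be edited after the fact. A small sketch; the replacement title and font size are arbitrary examples:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs outside Jupyter
import matplotlib.pyplot as plt

fig = plt.figure()
title = fig.suptitle("title??")
print(title.get_position())  # (0.5, 0.98): centered, near the top of the figure

# The returned Text object can be restyled in place
title.set_text("Decision tree for charges")
title.set_fontsize(14)
```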

- Adjusting the dpi

fig.set_dpi(250)
fig
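set_dpi changes the number of pixels per inch, so the figure keeps its size in inches but renders with more pixels: pixels = inches × dpi. A sketch of the arithmetic, with figsize pinned at 6.4 × 4.8 inches (matplotlib's usual default, fixed here so the result does not depend on local settings):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs outside Jupyter
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(6.4, 4.8))
fig.set_dpi(250)

# Pixel dimensions = size in inches * dots per inch
width_px, height_px = fig.get_size_inches() * fig.get_dpi()
print(round(width_px), round(height_px))  # 1600 1200
```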

E. Drawing on a matplotlib ax

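fig = plt.figure()
ax = fig.subplots(2,1)
# top: observed charges (x) vs. predicted charges (y); the dashed line is y = x (perfect prediction)
ax[0].plot(y,y,'--')
ax[0].plot(y,predictr.predict(X),'o',alpha=0.1)
# bottom: the tree itself, drawn on the second axes via ax=
sklearn.tree.plot_tree(predictr,feature_names=X.columns,ax=ax[1],max_depth=0);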
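In the top panel the predicted charges line up in horizontal bands. A regression tree predicts the mean target value of the training rows in each leaf, and a tree of depth 3 has at most 2³ = 8 leaves, so there are at most 8 distinct predictions. A sketch that checks both facts on synthetic stand-in data (not insurance.csv):

```python
import numpy as np
import pandas as pd
import sklearn.tree

# Hypothetical synthetic stand-in for insurance.csv (not the real data)
rng = np.random.default_rng(0)
n = 200
X = pd.DataFrame({"age": rng.integers(18, 65, n), "bmi": rng.normal(30, 6, n)})
y = (250 * X["age"] + 300 * X["bmi"] + rng.normal(0, 1000, n)).to_numpy()
predictr = sklearn.tree.DecisionTreeRegressor(max_depth=3).fit(X, y)

pred = predictr.predict(X)
leaf_id = predictr.apply(X)  # index of the leaf each row lands in

# At most 2**3 = 8 leaves, hence at most 8 distinct predicted values
print(len(np.unique(pred)) <= 2 ** 3)  # True

# Each leaf predicts the mean of the training targets that fell into it
for leaf in np.unique(leaf_id):
    in_leaf = leaf_id == leaf
    assert np.isclose(pred[in_leaf][0], y[in_leaf].mean())
```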

5. Using GraphViz

g = sklearn.tree.export_graphviz(
    predictr,
    feature_names=X.columns
)
graphviz.Source(g)
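With out_file left at its default of None, export_graphviz returns the tree as a DOT-language string, which graphviz.Source then renders. The string can be inspected directly, and optional flags such as filled=True and rounded=True change the styling. A sketch on synthetic stand-in data that needs only scikit-learn:

```python
import numpy as np
import pandas as pd
import sklearn.tree

# Hypothetical synthetic stand-in for insurance.csv (not the real data)
rng = np.random.default_rng(0)
n = 200
X = pd.DataFrame({"age": rng.integers(18, 65, n), "bmi": rng.normal(30, 6, n)})
y = 250 * X["age"] + 300 * X["bmi"] + rng.normal(0, 1000, n)
predictr = sklearn.tree.DecisionTreeRegressor(max_depth=3).fit(X, y)

dot = sklearn.tree.export_graphviz(
    predictr,
    feature_names=list(X.columns),
    filled=True,   # shade each node by its predicted value
    rounded=True,  # rounded node boxes
)

# The first line opens the DOT graph definition
print(dot.splitlines()[0])  # digraph Tree {
```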

  • Much cleaner to read than the matplotlib version.

- How to export the tree and save it to a file

g = sklearn.tree.export_graphviz(
    predictr,
    feature_names=X.columns
)
graphviz.Source(g).render('tree',format="pdf")
'tree.pdf'
  • This creates tree.pdf; opening it makes the structure of the tree easy to follow. render also saves the DOT source itself to a file named tree.
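A two-step alternative: pass a file name as out_file so export_graphviz writes the DOT source to disk, then convert it with the Graphviz dot command-line tool. A sketch on synthetic stand-in data; the file name tree.dot is an arbitrary choice, and the conversion step assumes the Graphviz binaries are installed and on PATH (graphviz.Source.render needs them as well):

```python
import shutil
import subprocess

import numpy as np
import pandas as pd
import sklearn.tree

# Hypothetical synthetic stand-in for insurance.csv (not the real data)
rng = np.random.default_rng(0)
n = 200
X = pd.DataFrame({"age": rng.integers(18, 65, n), "bmi": rng.normal(30, 6, n)})
y = 250 * X["age"] + 300 * X["bmi"] + rng.normal(0, 1000, n)
predictr = sklearn.tree.DecisionTreeRegressor(max_depth=3).fit(X, y)

# With a file name as out_file, the DOT source is written to disk (the call returns None)
sklearn.tree.export_graphviz(predictr, out_file="tree.dot", feature_names=list(X.columns))

# Convert to PDF with the Graphviz CLI, only if `dot` is found on PATH
if shutil.which("dot") is not None:
    subprocess.run(["dot", "-Tpdf", "tree.dot", "-o", "tree.pdf"], check=True)
```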

6. HW

Submit the tree.pdf file.