본문 바로가기
Experience/- KT AIVLE School

KT AIVLE School 6주차 정리 - FI, PFI, SHAP

by Yoojacha 2023. 3. 10.

FI (Feature Importance)

계산되는 방식이 2가지가 존재하기에 중요도 값의 단위를 확인할 필요가 있다.

 

# 반복문을 돌면서 모델에 내장된 feature_importance_ 시각화
# results 는 딕셔너리 형태로 모델을 저장함
for key in results:
    if key in ['RL', 'KNNR', 'SVR']:
        continue
    tmp = pd.DataFrame({'feature_importance': results[key].feature_importances_, 'feature_names': list(x_train)}).sort_values('feature_importance', ascending=False)[:20]
    plt.figure(figsize=(16, 6))
    sns.barplot(x='feature_importance', y='feature_names', data = tmp)
    plt.title(key)
    plt.show()

# 제공되는 모듈 사용
plot_importance(model)
# 모델 1개만 시각화

key = 'LGBMC'
tmp = pd.DataFrame({'feature_importance': results[key].feature_importances_, 'feature_names': list(x_train)}).sort_values('feature_importance', ascending=False)

plt.figure(figsize=(16, 6))
sns.barplot(x='feature_importance', y='feature_names', data = tmp)
plt.title(key)
plt.show()

PFI (순열 특성 중요도, Permutation Feature Importance)

피처값의 순서를 변경해서 생기는 모델의 예측 오차 증가량을 가지고 Feature Importance를 도출한다.

피처간의 상관관계에 유의하면서 중요도를 바라봐야 한다.

피처를 섞다보면 무의미한 중요도가 될 수 있다

모듈 불러오기

from sklearn.inspection import permutation_importance
# 피처별 박스플롯 찍어보기

pfi = permutation_importance(model, x_val_s, y_val, n_repeats=10, 
                             scoring = 'r2', random_state=20)
sorted_idx = pfi.importances_mean.argsort()

plt.figure(figsize = (16, 10))
plt.boxplot(pfi.importances[sorted_idx].T, vert=False, labels=x.columns[sorted_idx])
plt.axvline(x=0, color='r', linestyle='--', linewidth=1)
plt.grid()
plt.show()

SHAP

  • f(x) : 실제 예측값
  • base value: 전체 예측값들의 평균
  • 컬럼별 기여도의 합: base value와 f(x) 의 차이

shap 초기 설정

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(x_train)

시각화

shap.initjs()

# shap.force_plot(전체평균, 기여도 한 행, 실제값 한 행)
shap.force_plot(explainer.expected_value, shap_values[index, :], x.iloc[index,:])

행별 기여도 분포를 통한 변수 중요도 이해

shap.initjs()
shap.summary_plot(shap_values, x_train)

shap.initjs()
shap.dependence_plot('Feature 1', shap_values, x_train, interaction_index = 'Feature 6')


 

출처

 

shap.summary_plot — SHAP latest documentation

© Copyright 2018, Scott Lundberg Revision c22690f3.

shap-lrjball.readthedocs.io

 

댓글