본문 바로가기
AI/- Library

[scikit-learn] KFold 와 StratifiedKFold 의 차이

by Yoojacha 2023. 3. 9.

크게 중요하지는 않지만 알아두면 좋을 것 같아서 글을 적어둡니다! 인터넷에 널린 코드들을 보다보니 StratifiedKFold 라는 함수가 보여서 AIVLE School 에서 배우지 않았던 개념이었습니다!


KFold

  • 회귀 문제(회귀의 결정값은 연속된 숫자값이기 때문에 결정값 별로 분포를 정하는 의미가 없기 때문

 StratifedKFold

  • 레이블 데이터가 왜곡됬을 경우 반드시.
  • 일반적으로 분류에서의 교차 검증
참고: https://velog.io/@ohxhxs/%ED%8C%8C%EC%9D%B4%EC%8D%AC-%EB%A8%B8%EC%8B%A0%EB%9F%AC%EB%8B%9D-%EA%B5%90%EC%B0%A8%EA%B2%80%EC%A6%9D-KFold-StratifiedKFold-crossvalscoreGridSearchCV

cross_val_score() 

cross_val_score() 의 cv 파라미터는 내부적으로 stratifiedKFold 함수를 사용한다.

from sklearn.model_selection import KFold # 회귀에 사용

cv_score = cross_val_score(
                            models[key], 
                            x_train, 
                            y_train, 
                            cv=KFold(n_splits=5)
                          ) 

print('=' * 20, key, '=' * 20)
print(cv_score)
print('평균:', cv_score.mean())
print('표준편차:', cv_score.std())

GridSearchCV()

from sklearn.model_selection import StratifiedKFold # 분류에 사용

params = {'max_depth': range(5,20), 
          'n_estimators': range(100, 200, 10),
          'min_child_samples': [20, 22, 24]}

model = GridSearchCV(
                    model,                  # 기본 모델 이름
                    params,                 # 앞에서 선언한 튜닝용 파라미터 변수
                    cv=StratifiedKFold(n_splits=5), # default=5
                    refit=True,             # 기본값 True
                    scoring=make_scorer(f1_score, average='micro'),   # 평가 방법
                    n_jobs=-1,              # cpu 전부 사용
                    verbose=2               # 학습 진행 상황 보기
                    )

댓글