본문 바로가기
Experience/- KT AIVLE School

KT AIVLE School 7주차 정리 - 전이 학습과 파인 튜닝

by Yoojacha 2023. 3. 15.

전이 학습과 파인 튜닝을 하는 방법은 정말 쉬웠습니다. 배우면서도 엄청 간단해서 이래도 되나 싶을 정도였습니다. 하지만 저의 수준과 경험으로는 파인 튜닝을 통해 성능을 올리는 것은 어렵다는 것을 알게 되었고, 인터넷을 봐도 이러한 정보는 다 꽁꽁 숨겨놓거나 데이터마다 다르기 때문에 여러방면에 경험을 많이 해야겠습니다!


전이 학습하는 코드 틀

https://keras.io/guides/sequential_model/
# Load a convolutional base with pre-trained weights
base_model = keras.applications.Xception(
    weights='imagenet',
    include_top=False,
    pooling='avg')

# Freeze the base model
base_model.trainable = False

# Use a Sequential model to add a trainable classifier on top
model = keras.Sequential([
    base_model,
    layers.Dense(1000),
])

# Compile & train
model.compile(...)
model.fit(...)

 

전이 학습 및 파인튜닝 해보기 (ResNet50V2)

import tensorflow as tf
from tensorflow.keras.backend import clear_session
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Flatten, Input, GlobalAveragePooling2D
from tensorflow.keras.applications import ResNet50V2
clear_session()

# Pretrained Model을 가져올 때의 설정
kwargs = {
          include_top=False,          # classification layer 포함 여부
          input_shape=(224, 224, 3),  # 학습할 이미지 사이즈 정의, 보통 224로 이용됨
          weights='imagenet'          # imagenet을 학습했을 때의 가중치 포함 여부
         }

# ResNet50V2 불러오기
base_model = ResNet50V2(**kwargs)

# ResNet50V2의 레이어 동결
base_model.trainable = False 

# 모델 설계
inputs = Input(shape=(224, 224, 3))

x = base_model(inputs, training=False)
x = GlobalAveragePooling2D()(x)
x = Flatten()(x)
x = Dense(512, activation='relu')(x)

outputs = Dense(1, activation = 'sigmoid')(x)

model = Model(inputs, outputs)
model.compile(
              optimizer = 'adam', 
              loss = 'binary_crossentropy', 
              metrics = ['accuracy']
             )
model.summary()

 

callback 함수들은 이 게시글 참고

 

history = model.fit(
                    x_train, y_train,
                    validation_data=(x_valid, y_valid),
                    epochs=50, batch_size=256, 
                    callbacks=[es], verbose=1
                   )

 

Pretrained model 의 레이어 동결 범위 지정

print(len(base_model.layers)) # 레이어 개수 확인

for layer in model.layers[:200]: # 200층 까지 재동결, 200층부터 끝까지는 동결 해제
    layer.trainable = False

댓글