福岡人データサイエンティストの部屋

データサイエンスを極めるため、日々の学習を綴っています。

CNNの転移学習その2【図解速習DEEP LEARNING】#006

こんにちは!こーたろーです。
図解速習DEEP LEARNING】の転移学習の続きをやっていきます!


前回のブログでは、
1.学習済みモデルを特徴抽出につかう
をやってみました。
CNNの転移学習【図解速習DEEP LEARNING】#005 - 福岡の社会人データサイエンティストの部屋



今回は
2.学習済みのモデルのファインチューニングを行います。

前回、データセットの準備やモデルの作成は記載しましたので、続きから書いていきます。
※同じワークブックに、続けて記載していけば動きます。

前回は、学習済みのMobileNetV2のベースモデルにプーリング層と予測層を重ねて、予測層のみを学習させました。
今回は、学習済みモデルの層を一部重みの更新をさせるという方法でファインチューニングしていきます。




1.ベースモデルの層をチェックする。

base_model.trainable = True
print("Number of layers in the base model: ", len(base_model.layers))

fine_tune_at = 100

for layer in base_model.layers[:fine_tune_at]:
  layer.trainable =  False

out
f:id:dsf-kotaro:20210125170117p:plain

全部で155層あり、そのうちの100層は重みの更新がなし、50層は更新可能な状態へ変更しました。


2.モデルのコンパイル

model.compile(loss='binary_crossentropy',
              optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()

out
f:id:dsf-kotaro:20210125170156p:plain


3.モデルの学習

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_batches, 
                         epochs=total_epochs, 
                         initial_epoch = initial_epochs,
                         validation_data=validation_batches)

f:id:dsf-kotaro:20210125170931p:plain

4.結果を表示する

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1], 
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1], 
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

out
f:id:dsf-kotaro:20210125171010p:plain

今回は20エポックのうち、後半10回がファインチューニングとなっています。
その結果、転移した学習結果を使ってファインチューニングを行ったケース2の方が、前回のブログでおこなったケース1の追加した層だけ学習したが場合よりも精度が向上したことが分かる。