福岡の社会人データサイエンティストの部屋

データサイエンスを極めるため、日々の学習を綴っています。

学習済みGAN使ってみた【図解速習Deep Learning】#009

こんにちは!こーたろーです。

本日はついにGANに取り掛かっていきます。

今日はdemoですが(笑

本日も【図解速習DEEP LEARNING】をやっていきます。

皆さんもう買いました? すべてを理解するのは結構難しいですよね。。汗



それでは早速本日分!
サンプルコードではTensorFlow1.Xですが、今回もTensorFlow2.4.0でやっていきます。
コードが変わっていますのでご注意ください。

この辺のバグ取りなんかは、Python、TensorFlowの勉強になって、とても為になっています。


GANをやっていきますが、今回は学習済みのGANのコレクションを使って、画像生成を行います。
洗剤空間上のランダムなベクトルを選び、それをGANへ入力し、画像を生成するという流れです。

ではソースコードを見ていきましょう。


1.ライブラリのインポート

from google.colab import output

import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow_hub as hub

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


2.モデルを取得

辞書型で入れているGANモデルを選択すると、Tensorflow_hubからモデルをダウンロードします。

module_metadata_dict = {'dataset': ['CelebA HQ (128x128)', 'CelebA HQ (128x128)', 'LSUN Bedroom', 'LSUN Bedroom', 'CelebA HQ (128x128)', 'CelebA HQ (128x128)', 'LSUN Bedroom', 'LSUN Bedroom', 'CelebA HQ (128x128)', 'LSUN Bedroom', 'CIFAR10', 'CIFAR10', 'CIFAR10', 'CIFAR10', 'CIFAR10'], 'penalty': ['-', '-', '-', '-', '-', '-', '-', '-', 'DRAGAN (lambda=1.000)', 'WGAN (lambda=0.145)', '-', '-', '-', '-', 'WGAN (lambda=1.000)'], 'architecture': ['ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet19', 'ResNet CIFAR', 'ResNet CIFAR', 'ResNet CIFAR', 'ResNet CIFAR', 'ResNet CIFAR'], 'beta1': ['0.375', '0.500', '0.585', '0.195', '0.500', '0.500', '0.500', '0.102', '0.500', '0.711', '0.500', '0.500', '0.500', '0.500', '0.500'], 'beta2': ['0.998', '0.999', '0.990', '0.882', '0.999', '0.999', '0.999', '0.998', '0.900', '0.979', '0.999', '0.999', '0.999', '0.999', '0.999'], 'module_url': ['https://tfhub.dev/google/compare_gan/model_1_celebahq128_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_2_celebahq128_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_3_lsun_bedroom_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_4_lsun_bedroom_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_5_celebahq128_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_6_celebahq128_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_7_lsun_bedroom_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_8_lsun_bedroom_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_9_celebahq128_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_10_lsun_bedroom_resnet19/1', 'https://tfhub.dev/google/compare_gan/model_11_cifar10_resnet_cifar/1', 'https://tfhub.dev/google/compare_gan/model_12_cifar10_resnet_cifar/1', 'https://tfhub.dev/google/compare_gan/model_13_cifar10_resnet_cifar/1', 'https://tfhub.dev/google/compare_gan/model_14_cifar10_resnet_cifar/1', 'https://tfhub.dev/google/compare_gan/model_15_cifar10_resnet_cifar/1'], 'disc_iters': [1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 5, 5, 5, 5, 5], 'model': ['Non-saturating GAN', 'Non-saturating GAN', 'Least-squares GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Least-squares GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN', 'Non-saturating GAN'], 'inception_score': ['2.38', '2.59', '4.23', '4.10', '2.38', '2.54', '3.64', '3.58', '2.34', '3.92', '7.57', '7.47', '7.74', '7.74', '7.70'], 'disc_norm': ['none', 'none', 'none', 'none', 'layer_norm', 'layer_norm', 'spectral_norm', 'spectral_norm', 'layer_norm', 'layer_norm', 'none', 'none', 'spectral_norm', 'spectral_norm', 'spectral_norm'], 'fid': ['34.29', '35.85', '102.74', '112.92', '30.02', '32.05', '41.60', '42.51', '29.13', '40.36', '28.12', '30.08', '22.91', '23.22', '22.73'], 'ms_ssim_score': ['0.32', '0.29', 'N/A', 'N/A', '0.29', '0.28', 'N/A', 'N/A', '0.30', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A'], 'learning_rate': ['3.381e-05', '1.000e-04', '3.220e-05', '1.927e-05', '1.000e-04', '1.000e-04', '2.000e-04', '2.851e-04', '1.000e-04', '1.281e-04', '2.000e-04', '1.000e-04', '2.000e-04', '2.000e-04', '2.000e-04']}
MODULE_METADATA = pd.DataFrame.from_dict(module_metadata_dict)

MIN_FID_MODULE = MODULE_METADATA.loc[
    MODULE_METADATA['fid'].astype(float).idxmin()]

SELECTED_MODULE = MIN_FID_MODULE['module_url']
SELECTED_MODULE_DATASET = MIN_FID_MODULE['dataset']

def display_images(images, captions=None):
  batch_size, dim1, dim2, channels = images.shape
  num_horizontally = 8
  
  figsize = (20, 20) if dim1 > 32 else (10, 10)
  f, axes = plt.subplots(
      len(images) // num_horizontally, num_horizontally, figsize=figsize)
  for i in range(len(images)):
    axes[i // num_horizontally, i % num_horizontally].axis("off")
    if captions is not None:
      axes[i // num_horizontally, i % num_horizontally].text(0, -3, captions[i])
    axes[i // num_horizontally, i % num_horizontally].imshow(images[i])
  f.tight_layout()
  

class ShowModuleTable(object):
  def __init__(self, callback):
    self._callback = callback

  def _repr_html_(self):
    template = """
    <style>
       table {
         font-size: 15px;
         font-family: Inconsolata, monospace;
         border-collapse: collapse;
         border: 1px solid #444444;
       }
       th {
         font-size: 18px;
         background-color: #DDDDDD;
         border: 1px solid #AAAAAA;
         white-space: nowrap;
       }
       tr {
         cursor: pointer;
         white-space: nowrap;
       }
       td {
         padding: 6px 6px 6px 6px;
         border: 1px solid #AAAAAA;
       }
      .selected-row {
        font-weight: bold;
        background-color: #B0BED9;
      }
    </style>
    <table>"""
    
    table_headers = [
      ('dataset', 'Dataset'),
      ('architecture', 'Architecture'),
      ('fid', 'FID'),
      ('inception_score', 'IS'),
      ('ms_ssim_score', 'MS-SSIM'),
      ('model', 'Model'),
      ('learning_rate', 'Learning rate'),
      ('beta1', '&beta;<sub>1</sub>'),
      ('beta2', '&beta;<sub>2</sub>'),
      ('disc_iters', 'n<sub>disc</sub>'),
      ('disc_norm', 'Disc norm'),
      ('penalty', 'Penalty'),
      ('module_url', 'Module name'),
    ]
    header_template = "<tr>"
    for _, header_name in table_headers:
      header_template += "<th>" + header_name + "</th>"
    header_template += "</tr>"
    template += header_template
    
    for i, (_, row) in enumerate(MODULE_METADATA.iterrows()):
      uuid = "row-%s" % i
      
      output.register_callback(uuid, self._callback)
      
      selected_class = ""
      if row['module_url'] == MIN_FID_MODULE['module_url']:
        selected_class = "class=\"selected-row\""

      row_template = "<tr id=\"" + uuid + "\"" + selected_class + ">"
      for key, _ in table_headers:
        row_template += "<td>" + str(row[key]) + "</td>"
      row_template += "</tr>"
      template += row_template
      
    template += """
      </table>
      <script>"""
    
    for i, (_, row) in enumerate(MODULE_METADATA.iterrows()):
      uuid = "row-%s" % i
      m = row['module_url']
      d = row['dataset']
      template += """
        document.querySelector(\"#""" + uuid + """\").onclick = function() {
          google.colab.kernel.invokeFunction('""" + uuid + """', ['""" + m +"""', '""" + d + """'], {});
          var selected = document.getElementsByClassName('selected-row');
          for (var i = 0; i < selected.length; i++) {
            selected[i].classList.remove('selected-row');
          }
          this.classList.toggle("selected-row");
          e.preventDefault();
        };
        """
    template += """</script>"""
    return template


def set_selected_module(module_name, dataset):
  global SELECTED_MODULE
  SELECTED_MODULE = module_name
  global SELECTED_MODULE_DATASET
  SELECTED_MODULE_DATASET = dataset
ShowModuleTable(set_selected_module)

一覧表示したらこんな感じです。
f:id:dsf-kotaro:20210203214217p:plain


下記の「assert」の使い方は覚えておいた方がいいです!
今度解説したいと思います。

assert SELECTED_MODULE is not None and SELECTED_MODULE_DATASET is not None, \
  'You must run the above cell and select a module from the table to generate images.'

print('Using module: "%s"' % SELECTED_MODULE)
print('Generating images like dataset: "%s"' % SELECTED_MODULE_DATASET)

batch_size = 64
z_dim = 128

with tf.Graph().as_default():
  gan = hub.Module(SELECTED_MODULE)
  z_input = tf.compat.v1.placeholder(dtype=tf.float32, shape=(batch_size, z_dim))
  image_output = gan(z_input, signature="generator") 
  
  with tf.compat.v1.train.MonitoredSession() as session:
    z_values = np.random.uniform(-1, 1, size=(batch_size, z_dim))
    images = session.run(image_output, feed_dict={z_input: z_values})

    display_images(images)

Z_input という潜在空間ベクトルを定義しています。
そこからGANを使って画像を生成しています。

結果がこちら↓↓↓↓↓

f:id:dsf-kotaro:20210203214300p:plain


メタデータから別の学習済みGANを選択するには、上記のコードのうち、

MIN_FID_MODULE = MODULE_METADATA.loc[
    MODULE_METADATA['fid'].astype(float).idxmin()]

の部分を変更してみてください。
モジュールダイレクト入力なんかでも大丈夫です。
一覧表示のコードの記述が面倒なので、一つずつ使ってみてもいいかもしれませんね。

「Dataset : LSUN Bedroom」「model : model_4_lsun_bedroom_resnet19」の場合
f:id:dsf-kotaro:20210203213359p:plain

こんな感じです。

いかがでしたでしょうか。
よくわかりませんよね。。。汗
入力から画像を生成しただけなので、GAN本来の特性が出ていない感じがします。
入門編なのでこんなもんなのかな? GANも後々作成できたらと思っています。

ではでは。