福岡人データサイエンティストの部屋

データサイエンスを極めるため、日々の学習を綴っています。

【GAN】JoJoGAN使ってみた(その1)【Python】



こんにちは!こーたろーです。


以前からきになっていたJoJoGANですが、土日で環境構築して使えるようになりました。


参考のサイトを探していたら、最近私のサイトを見てくださっている、「touch-sp」さんのサイトへたどり着きました。


【PyTorch】JoJoGANというものを使わせて頂きました - パソコン関連もろもろ


こんなところで繋がるとは思っていませんでした!笑


touch-spさんのサイトを参考にしながらやっていきます!




環境構築



ubuntuでもいいのですが、私はWindowsでなおかつAnacondaの仮想環境を使っているので、そちらでやっていきます。


正直、今回の環境構築は苦戦しました。


必要だったもの

  • PyTorch
  • torchvision
  • dlib



この辺のインストールがエラーだらけ。


いろいろ調べながらやりました。


結局やったのは、Anaconda NavigatorからPyTorchはインストールしておいて、あとは仮想環境のターミナルからいくつかインストールしました。

conda install git
conda install cmake
conda install -c menpo dlib
conda install torchvision



この辺りは、使用している環境でインストールの仕方が若干異なるため、苦戦すると思います。

データのダウンロード



JoJoGANは、GitHubにデータが公開されているため、zipファイルをダウンロードして、好きなフォルダに展開しておきます。


GitHub - mchong6/JoJoGAN: Official PyTorch repo for JoJoGAN: One Shot Face Stylization





今回は、学習済みのモデルを使うため、別途モデルをダウンロードして、「model」というフォルダをルートフォルダ内に作成して、そこにモデルを配置します。


用意されているデータはモデルはというと、プログラムを開いたら出てきます。

google_drive_paths = {
    "models/stylegan2-ffhq-config-f.pt": "https://drive.google.com/uc?id=1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK",
    "models/dlibshape_predictor_68_face_landmarks.dat": "https://drive.google.com/uc?id=11BDmNKS1zxSZxkgsEvQoKgFd8J264jKp",
    "models/e4e_ffhq_encode.pt": "https://drive.google.com/uc?id=1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7",
    "models/restyle_psp_ffhq_encode.pt": "https://drive.google.com/uc?id=1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd",
    "models/arcane_caitlyn.pt": "https://drive.google.com/uc?id=1gOsDTiTPcENiFOrhmkkxJcTURykW1dRc",
    "models/arcane_caitlyn_preserve_color.pt": "https://drive.google.com/uc?id=1cUTyjU-q98P75a8THCaO545RTwpVV-aH",
    "models/arcane_jinx_preserve_color.pt": "https://drive.google.com/uc?id=1jElwHxaYPod5Itdy18izJk49K1nl4ney",
    "models/arcane_jinx.pt": "https://drive.google.com/uc?id=1quQ8vPjYpUiXM4k1_KIwP4EccOefPpG_",
    "models/disney.pt": "https://drive.google.com/uc?id=1zbE2upakFUAx8ximYnLofFwfT8MilqJA",
    "models/disney_preserve_color.pt": "https://drive.google.com/uc?id=1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi",
    "models/jojo.pt": "https://drive.google.com/uc?id=13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4",
    "models/jojo_preserve_color.pt": "https://drive.google.com/uc?id=1ZRwYLRytCEKi__eT2Zxv1IlV6BGVQ_K2",
    "models/jojo_yasuho.pt": "https://drive.google.com/uc?id=1grZT3Gz1DLzFoJchAmoj3LoM9ew9ROX_",
    "models/jojo_yasuho_preserve_color.pt": "https://drive.google.com/uc?id=1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L",
    "models/supergirl.pt": "https://drive.google.com/uc?id=1L0y9IYgzLNzB-33xTpXpecsKU-t9DpVC",
    "models/supergirl_preserve_color.pt": "https://drive.google.com/uc?id=1VmKGuvThWHym7YuayXxjv0fSn32lfDpE",
}



辞書で登録されているのですが、googledriveに格納されているので、各々アクセスしてファイルをダウンロードしておきましょう。


今回は、「stylegan2-ffhq-config-f.pt」と「jojo.pt」を使いました。





ルートフォルダ内はこのようになっています。


JoJo GAN training.ipynb」は、私が動作確認のために作ったフォルダになります。

ライブラリ・オブジェクトのインポート



必要なライブラリをインポートしていきます。


このライブラリが実行できないと進みません。(当たり前ですが。)


ルートフォルダにある、pyファイルなども呼び出します。

import torch
import dlib
from torchvision import transforms, utils
from util import *
from PIL import Image
import os
from model import *
from e4e_projection import projection as e4e_projection


モデル呼び出し



JoJoGANのファイルの中に、predict.pyがあるのですが、JupyterNotebookで実行するため、必要な箇所を取り出して実行します。

device = 'cpu'
latent_dim = 512

generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
mean_latent = generator.mean_latent(10000)

plt.rcParams['figure.dpi'] = 150



「'models/stylegan2-ffhq-config-f.pt'」はオリジナルのGeneratorみたいです。呼び出して使いましょう。




サンプル画像の「iu.jpeg」を使って、JoJoの画風にしていきます。


そのための前処理として、画像の顔をそろえてトリミングを行います。


「align_face」は、util.pyで定義された関数で、顔の処理を行ってくれます。

filename = 'iu.jpg'
filepath = f'test_input/{filename}'

name = strip_path_extension(filepath) +'.pt'

aligned_face = align_face(filepath)

my_w = e4e_projection(aligned_face, name, device).unsqueeze(0)


参考にする画像の学習モデルを呼び出し、画像を生成



今回は、JoJoの画像を使ったものを作成していきます。


サンプル画像としては、





こちらの画像になります。GitHubでダウンロードしたらついてきます。


jojo.pt」が学習済みモデルになりますので、それを呼び出します。

ckpt = torch.load(os.path.join('models', jojo.pt), map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt["g"], strict=False)

n_sample =  1

seed = 3000
torch.manual_seed(seed)

with torch.no_grad():
    generator.eval()
    my_sample = generator(my_w, input_is_latent=True)

tensor = utils.make_grid(my_sample, normalize=True, range=(-1, 1), nrow=1)
result_image = transforms.ToPILImage()(tensor)

if pretrained == 'arcane_multi':
    style_path = f'style_images_aligned/arcane_jinx.png'
else:   
    style_path = f'style_images_aligned/{pretrained}.png'

style_image = Image.open(style_path)

aligned_face.show()
style_image.show()
result_image.show()



出力した結果がこちらです。


元画像





生成画像





なるほど。


JoJoGANの他のpyファイルを見てみると、参考とするptファイルを作成することもできるようでしたので、近くそちらでも試してみたいと思います。


ではでは。


実践GAN ~敵対的生成ネットワークによる深層学習~ (Compass Booksシリーズ)