【GAN】JoJoGAN使ってみた(その1)【Python】
こんにちは!こーたろーです。
以前からきになっていたJoJoGANですが、土日で環境構築して使えるようになりました。
参考のサイトを探していたら、最近私のサイトを見てくださっている、「touch-sp」さんのサイトへたどり着きました。
【PyTorch】JoJoGANというものを使わせて頂きました - パソコン関連もろもろ
こんなところで繋がるとは思っていませんでした!笑
touch-spさんのサイトを参考にしながらやっていきます!
環境構築
ubuntuでもいいのですが、私はWindowsでなおかつAnacondaの仮想環境を使っているので、そちらでやっていきます。
正直、今回の環境構築は苦戦しました。
必要だったもの
- PyTorch
- torchvision
- dlib
この辺のインストールがエラーだらけ。
いろいろ調べながらやりました。
結局やったのは、Anaconda NavigatorからPyTorchはインストールしておいて、あとは仮想環境のターミナルからいくつかインストールしました。
conda install git conda install cmake conda install -c menpo dlib conda install torchvision
この辺りは、使用している環境でインストールの仕方が若干異なるため、苦戦すると思います。
データのダウンロード
JoJoGANは、GitHubにデータが公開されているため、zipファイルをダウンロードして、好きなフォルダに展開しておきます。
GitHub - mchong6/JoJoGAN: Official PyTorch repo for JoJoGAN: One Shot Face Stylization
今回は、学習済みのモデルを使うため、別途モデルをダウンロードして、「model」というフォルダをルートフォルダ内に作成して、そこにモデルを配置します。
用意されているデータはモデルはというと、プログラムを開いたら出てきます。
google_drive_paths = { "models/stylegan2-ffhq-config-f.pt": "https://drive.google.com/uc?id=1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK", "models/dlibshape_predictor_68_face_landmarks.dat": "https://drive.google.com/uc?id=11BDmNKS1zxSZxkgsEvQoKgFd8J264jKp", "models/e4e_ffhq_encode.pt": "https://drive.google.com/uc?id=1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7", "models/restyle_psp_ffhq_encode.pt": "https://drive.google.com/uc?id=1nbxCIVw9H3YnQsoIPykNEFwWJnHVHlVd", "models/arcane_caitlyn.pt": "https://drive.google.com/uc?id=1gOsDTiTPcENiFOrhmkkxJcTURykW1dRc", "models/arcane_caitlyn_preserve_color.pt": "https://drive.google.com/uc?id=1cUTyjU-q98P75a8THCaO545RTwpVV-aH", "models/arcane_jinx_preserve_color.pt": "https://drive.google.com/uc?id=1jElwHxaYPod5Itdy18izJk49K1nl4ney", "models/arcane_jinx.pt": "https://drive.google.com/uc?id=1quQ8vPjYpUiXM4k1_KIwP4EccOefPpG_", "models/disney.pt": "https://drive.google.com/uc?id=1zbE2upakFUAx8ximYnLofFwfT8MilqJA", "models/disney_preserve_color.pt": "https://drive.google.com/uc?id=1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi", "models/jojo.pt": "https://drive.google.com/uc?id=13cR2xjIBj8Ga5jMO7gtxzIJj2PDsBYK4", "models/jojo_preserve_color.pt": "https://drive.google.com/uc?id=1ZRwYLRytCEKi__eT2Zxv1IlV6BGVQ_K2", "models/jojo_yasuho.pt": "https://drive.google.com/uc?id=1grZT3Gz1DLzFoJchAmoj3LoM9ew9ROX_", "models/jojo_yasuho_preserve_color.pt": "https://drive.google.com/uc?id=1SKBu1h0iRNyeKBnya_3BBmLr4pkPeg_L", "models/supergirl.pt": "https://drive.google.com/uc?id=1L0y9IYgzLNzB-33xTpXpecsKU-t9DpVC", "models/supergirl_preserve_color.pt": "https://drive.google.com/uc?id=1VmKGuvThWHym7YuayXxjv0fSn32lfDpE", }
辞書で登録されているのですが、googledriveに格納されているので、各々アクセスしてファイルをダウンロードしておきましょう。
今回は、「stylegan2-ffhq-config-f.pt」と「jojo.pt」を使いました。
ルートフォルダ内はこのようになっています。
「JoJo GAN training.ipynb」は、私が動作確認のために作ったフォルダになります。
ライブラリ・オブジェクトのインポート
必要なライブラリをインポートしていきます。
このライブラリが実行できないと進みません。(当たり前ですが。)
ルートフォルダにある、pyファイルなども呼び出します。
import torch import dlib from torchvision import transforms, utils from util import * from PIL import Image import os from model import * from e4e_projection import projection as e4e_projection
モデル呼び出し
JoJoGANのファイルの中に、predict.pyがあるのですが、JupyterNotebookで実行するため、必要な箇所を取り出して実行します。
device = 'cpu' latent_dim = 512 generator = Generator(1024, latent_dim, 8, 2).to(device) ckpt = torch.load('models/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage) mean_latent = generator.mean_latent(10000) plt.rcParams['figure.dpi'] = 150
「'models/stylegan2-ffhq-config-f.pt'」はオリジナルのGeneratorみたいです。呼び出して使いましょう。
サンプル画像の「iu.jpeg」を使って、JoJoの画風にしていきます。
そのための前処理として、画像の顔をそろえてトリミングを行います。
「align_face」は、util.pyで定義された関数で、顔の処理を行ってくれます。
filename = 'iu.jpg' filepath = f'test_input/{filename}' name = strip_path_extension(filepath) +'.pt' aligned_face = align_face(filepath) my_w = e4e_projection(aligned_face, name, device).unsqueeze(0)
参考にする画像の学習モデルを呼び出し、画像を生成
今回は、JoJoの画像を使ったものを作成していきます。
サンプル画像としては、
こちらの画像になります。GitHubでダウンロードしたらついてきます。
「jojo.pt」が学習済みモデルになりますので、それを呼び出します。
ckpt = torch.load(os.path.join('models', jojo.pt), map_location=lambda storage, loc: storage) generator.load_state_dict(ckpt["g"], strict=False) n_sample = 1 seed = 3000 torch.manual_seed(seed) with torch.no_grad(): generator.eval() my_sample = generator(my_w, input_is_latent=True) tensor = utils.make_grid(my_sample, normalize=True, range=(-1, 1), nrow=1) result_image = transforms.ToPILImage()(tensor) if pretrained == 'arcane_multi': style_path = f'style_images_aligned/arcane_jinx.png' else: style_path = f'style_images_aligned/{pretrained}.png' style_image = Image.open(style_path) aligned_face.show() style_image.show() result_image.show()
出力した結果がこちらです。
元画像
生成画像
なるほど。
JoJoGANの他のpyファイルを見てみると、参考とするptファイルを作成することもできるようでしたので、近くそちらでも試してみたいと思います。
ではでは。
実践GAN ~敵対的生成ネットワークによる深層学習~ (Compass Booksシリーズ)