動画の分類やってみた【図解速習DEEP LEARNING】#007
こんにちは!こーたろーです。
今日もこちら!
【図解速習DEEP LEARNING】
やっていきます。
今日は動画の分類です。
サンプルとして、学習済みのモデルに動画を判別してもらいます。
1.環境設定、ライブラリのインポート
※教材では、tensorflow 1.xで使用されていましたが、今回は現在のVer2.4.0でやっていきます。
import tensorflow as tf import tensorflow_hub as hub import random import re import os import tempfile import cv2 import numpy as np import imageio from IPython import display from urllib import request
今回は、Tensorflow Hubというライブラリ内のI3D(Inflated 3D Convnet)モジュールを使用していきます。
2.データセットのダウンロード
UCF101データセットというものを使用します。
こちらは、動画とその動画がなんの動作をしているところなのかをまとめたものとなっています。
動画取得のための変数を定義します。
変数前に「_」がついているものは、Pythonの特徴ともいれるソースの書き方だと思います。
今後、pythonの基礎的な部分もブログにしたいと思っていますので、その際に触れてみたいと思います。
pythonを行っているときに躓く一つの要因だと勝手に思っていますので。。。
UCF_ROOT = "http://crcv.ucf.edu/THUMOS14/UCF101/UCF101/" _VIDEO_LIST = None _CACHE_DIR = tempfile.mkdtemp()
動画をリスト化するための関数を定義
def list_ucf_videos(): """Lists videos available in UCF101 dataset.""" global _VIDEO_LIST if not _VIDEO_LIST: index = request.urlopen(UCF_ROOT).read().decode("utf-8") videos = re.findall("(v_[\w_]+\.avi)", index) _VIDEO_LIST = sorted(set(videos)) return list(_VIDEO_LIST)
動画を読み込み、キャッシュを残す
def fetch_ucf_video(video): """Fetchs a video and cache into local filesystem.""" cache_path = os.path.join(_CACHE_DIR, video) if not os.path.exists(cache_path): urlpath = request.urljoin(UCF_ROOT, video) print("Fetching %s => %s" % (urlpath, cache_path)) data = request.urlopen(urlpath).read() open(cache_path, "wb").write(data) return cache_path
CV2を使用して、動画をロードする。
def crop_center_square(frame): y, x = frame.shape[0:2] min_dim = min(y, x) start_x = (x // 2) - (min_dim // 2) start_y = (y // 2) - (min_dim // 2) return frame[start_y:start_y+min_dim,start_x:start_x+min_dim] def load_video(path, max_frames=0, resize=(224, 224)): cap = cv2.VideoCapture(path) frames = [] try: while True: ret, frame = cap.read() if not ret: break frame = crop_center_square(frame) frame = cv2.resize(frame, resize) frame = frame[:, :, [2, 1, 0]] frames.append(frame) if len(frames) == max_frames: break finally: cap.release() return np.array(frames) / 255.0 def animate(images): converted_images = np.clip(images * 255, 0, 255).astype(np.uint8) imageio.mimsave('./animation.gif', converted_images, fps=25) with open('./animation.gif','rb') as f: display.display(display.Image(data=f.read(), height=300))
3.kinetics-400のラベルを取得する
KINETICS_URL = "https://raw.githubusercontent.com/deepmind/kinetics-i3d/master/data/label_map.txt" with request.urlopen(KINETICS_URL) as obj: labels = [line.decode("utf-8").strip() for line in obj.readlines()] print("Found %d labels." % len(labels))
UCF101データセットの中身を確認してみます。
ucf_videos = list_ucf_videos() categories = {} for video in ucf_videos: category = video[2:-12] if category not in categories: categories[category] = [] categories[category].append(video) print("Found %d videos in %d categories." % (len(ucf_videos), len(categories))) for category, sequences in categories.items(): summary = ", ".join(sequences[:2]) print("%-20s %4d videos (%s, ...)" % (category, len(sequences), summary))
4.クリケットの動画を取得
sample_video = load_video(fetch_ucf_video("v_CricketShot_g04_c02.avi")) print("sample_video is a numpy array of shape %s." % str(sample_video.shape)) animate(sample_video)
5.i3dモデルの評価
TensorFlow Hubにある、学習済みモデルi3dを使用して、
クリケット動画がどのように判別されるか評価してみましょう。
model_input = np.expand_dims(sample_video, axis=0) with tf.Graph().as_default(): i3d = hub.Module("https://tfhub.dev/deepmind/i3d-kinetics-400/1") input_placeholder = tf.compat.v1.placeholder(shape=(None, None, 224, 224, 3), dtype=tf.float32) logits = i3d(input_placeholder) probabilities = tf.nn.softmax(logits) with tf.compat.v1.train.MonitoredSession() as session: [ps] = session.run(probabilities, feed_dict={input_placeholder: model_input}) print("Top 5 actions:") for i in np.argsort(ps)[::-1][:5]: print("%-22s %.2f%%" % (labels[i], ps[i] * 100))
結果として、こちらの動画は、97%以上の確率で「クリケット動画である」という判別結果が出ました。
i3dは非常に高い識別ができるモデルとなっていることが分かります。
今日のプログラムは、ライブラリ内のモジュールの扱いが多く、知らないものもあったので、後日詳細解説したいと思います。
ではでは。。