福岡の社会人データサイエンティストの部屋

データサイエンスを極めるため、日々の学習を綴っています。

動画の分類やってみた【図解速習DEEP LEARNING】#007

こんにちは!こーたろーです。

今日もこちら!
図解速習DEEP LEARNING

やっていきます。

今日は動画の分類です。

サンプルとして、学習済みのモデルに動画を判別してもらいます。


1.環境設定、ライブラリのインポート

※教材では、tensorflow 1.xで使用されていましたが、今回は現在のVer2.4.0でやっていきます。

import tensorflow as tf
import tensorflow_hub as hub
import random
import re
import os
import tempfile
import cv2
import numpy as np
import imageio
from IPython import display
from urllib import request 

今回は、Tensorflow Hubというライブラリ内のI3D(Inflated 3D Convnet)モジュールを使用していきます。

2.データセットのダウンロード

UCF101データセットというものを使用します。
こちらは、動画とその動画がなんの動作をしているところなのかをまとめたものとなっています。

動画取得のための変数を定義します。
変数前に「_」がついているものは、Pythonの特徴ともいれるソースの書き方だと思います。
今後、pythonの基礎的な部分もブログにしたいと思っていますので、その際に触れてみたいと思います。
pythonを行っているときに躓く一つの要因だと勝手に思っていますので。。。

UCF_ROOT = "http://crcv.ucf.edu/THUMOS14/UCF101/UCF101/"
_VIDEO_LIST = None
_CACHE_DIR = tempfile.mkdtemp()


動画をリスト化するための関数を定義

def list_ucf_videos():
  """Lists videos available in UCF101 dataset."""
  global _VIDEO_LIST
  if not _VIDEO_LIST:
    index = request.urlopen(UCF_ROOT).read().decode("utf-8")
    videos = re.findall("(v_[\w_]+\.avi)", index)
    _VIDEO_LIST = sorted(set(videos))
  return list(_VIDEO_LIST)

動画を読み込み、キャッシュを残す

def fetch_ucf_video(video):
  """Fetchs a video and cache into local filesystem."""
  cache_path = os.path.join(_CACHE_DIR, video)
  if not os.path.exists(cache_path):
    urlpath = request.urljoin(UCF_ROOT, video)
    print("Fetching %s => %s" % (urlpath, cache_path))
    data = request.urlopen(urlpath).read()
    open(cache_path, "wb").write(data)
  return cache_path

CV2を使用して、動画をロードする。

def crop_center_square(frame):
  y, x = frame.shape[0:2]
  min_dim = min(y, x)
  start_x = (x // 2) - (min_dim // 2)
  start_y = (y // 2) - (min_dim // 2)
  return frame[start_y:start_y+min_dim,start_x:start_x+min_dim]

def load_video(path, max_frames=0, resize=(224, 224)):
  cap = cv2.VideoCapture(path)
  frames = []
  try:
    while True:
      ret, frame = cap.read()
      if not ret:
        break
      frame = crop_center_square(frame)
      frame = cv2.resize(frame, resize)
      frame = frame[:, :, [2, 1, 0]]
      frames.append(frame)
      
      if len(frames) == max_frames:
        break
  finally:
    cap.release()
  return np.array(frames) / 255.0

def animate(images):
  converted_images = np.clip(images * 255, 0, 255).astype(np.uint8)
  imageio.mimsave('./animation.gif', converted_images, fps=25)
  with open('./animation.gif','rb') as f:
      display.display(display.Image(data=f.read(), height=300))

3.kinetics-400のラベルを取得する

KINETICS_URL = "https://raw.githubusercontent.com/deepmind/kinetics-i3d/master/data/label_map.txt"
with request.urlopen(KINETICS_URL) as obj:
  labels = [line.decode("utf-8").strip() for line in obj.readlines()]
print("Found %d labels." % len(labels))

f:id:dsf-kotaro:20210127200815p:plain


UCF101データセットの中身を確認してみます。

ucf_videos = list_ucf_videos()
  
categories = {}
for video in ucf_videos:
  category = video[2:-12]
  if category not in categories:
    categories[category] = []
  categories[category].append(video)
print("Found %d videos in %d categories." % (len(ucf_videos), len(categories)))

for category, sequences in categories.items():
  summary = ", ".join(sequences[:2])
  print("%-20s %4d videos (%s, ...)" % (category, len(sequences), summary))

f:id:dsf-kotaro:20210127200917p:plain


4.クリケットの動画を取得

sample_video = load_video(fetch_ucf_video("v_CricketShot_g04_c02.avi"))

print("sample_video is a numpy array of shape %s." % str(sample_video.shape))
animate(sample_video)

f:id:dsf-kotaro:20210127201013p:plain


5.i3dモデルの評価

TensorFlow Hubにある、学習済みモデルi3dを使用して、
クリケット動画がどのように判別されるか評価してみましょう。

model_input = np.expand_dims(sample_video, axis=0)

with tf.Graph().as_default():
  i3d = hub.Module("https://tfhub.dev/deepmind/i3d-kinetics-400/1")
  input_placeholder = tf.compat.v1.placeholder(shape=(None, None, 224, 224, 3), dtype=tf.float32)
  logits = i3d(input_placeholder)
  probabilities = tf.nn.softmax(logits)
  with tf.compat.v1.train.MonitoredSession() as session:
    [ps] = session.run(probabilities,
                       feed_dict={input_placeholder: model_input})

print("Top 5 actions:")
for i in np.argsort(ps)[::-1][:5]:
  print("%-22s %.2f%%" % (labels[i], ps[i] * 100))

f:id:dsf-kotaro:20210127201053p:plain


結果として、こちらの動画は、97%以上の確率で「クリケット動画である」という判別結果が出ました。
i3dは非常に高い識別ができるモデルとなっていることが分かります。

今日のプログラムは、ライブラリ内のモジュールの扱いが多く、知らないものもあったので、後日詳細解説したいと思います。

ではでは。。