福岡の社会人データサイエンティストの部屋

データサイエンスを極めるため、日々の学習を綴っています。

Universal Encoder使ってみた!(自然言語処理)【図解速習DeepLearning】#014

こんにちは!こーたろーです。

本日は、また課題テキストの【図解速習DEEP LEARNING】に戻って、課題を進めていきます!

今回は、TF-Hubが提供しているUniversal Sentence Encoderという学習済みモデルを使って、文章の類似度を判定していきます!





1.必要なライブラリーのインポート


eager_execution は disableにしておきましょう。 こちらはTensorflowのバージョンアップに伴う対応です。


import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns

tf.compat.v1.disable_eager_execution()



2.モデルの取得




今回は、TF-Hubから学習済みのモデルを取得して活用するため、モデルのロードが必要となります。


embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/1")



3.単語・文章・段落を分散表現に変更する



word = "Elephant"
sentence = "I am a sentence for which I would like to get its embedding."
paragraph = (
    "Universal Sentence Encoder embeddings also support short paragraphs. "
    "There is no hard limit on how long the paragraph is. Roughly, the longer "
    "the more 'diluted' the embedding will be.")
messages = [word, sentence, paragraph]

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

with tf.compat.v1.Session() as session:
  session.run([tf.compat.v1.global_variables_initializer(), tf.compat.v1.tables_initializer()])
  message_embeddings = session.run(embed(messages))

  for i, message_embedding in enumerate(np.array(message_embeddings).tolist()):
    print("Message: {}".format(messages[i]))
    print("Embedding size: {}".format(len(message_embedding)))
    message_embedding_snippet = ", ".join(
        (str(x) for x in message_embedding[:3]))
    print("Embedding: [{}, ...]\n".format(message_embedding_snippet))

f:id:dsf-kotaro:20210210205557p:plain




4.文意の類似度を判定し、可視化


def plot_similarity(labels, features, rotation):
  corr = np.inner(features, features)
  sns.set(font_scale=1.2)
  g = sns.heatmap(
      corr,
      xticklabels=labels,
      yticklabels=labels,
      vmin=0,
      vmax=1,
      cmap="YlOrRd")
  g.set_xticklabels(labels, rotation=rotation)
  g.set_title("Semantic Textual Similarity")


def run_and_plot(session_, input_tensor_, messages_, encoding_tensor):
  message_embeddings_ = session_.run(
      encoding_tensor, feed_dict={input_tensor_: messages})
  plot_similarity(messages_, message_embeddings_, 90)


入力の「massage」を定義して、確認していきます。



messages = [
    "I like my phone",
    "My phone is not good.",
    "Your cellphone looks great.",

    "Will it snow tomorrow?",
    "Recently a lot of hurricanes have hit the US",
    "Global warming is real",

    "An apple a day, keeps the doctors away",
    "Eating strawberries is healthy",
    "Is paleo better than keto?",

    "How old are you?",
    "what is your age?",
]

similarity_input_placeholder = tf.compat.v1.placeholder(tf.string, shape=(None))
similarity_message_encodings = embed(similarity_input_placeholder)
with tf.compat.v1.Session() as session:
  session.run(tf.compat.v1.global_variables_initializer())
  session.run(tf.compat.v1.tables_initializer())
  run_and_plot(session, similarity_input_placeholder, messages,
               similarity_message_encodings)



f:id:dsf-kotaro:20210210205645p:plain




ヒートマップで類似度を表示させています。


3つずつで類似した文章になっているのが色で確認んできますね。



【発展】STSベンチマークで評価



STSベンチマークというものを使うと、分の埋め込みによって計算された類似度スコアが、どの程度人間の判断と一致するかが評価できるそうです。


import pandas
import scipy
import math


def load_sts_dataset(filename):
  sent_pairs = []
  with tf.io.gfile.GFile(filename, "r") as f:
    for line in f:
      ts = line.strip().split("\t")
      sent_pairs.append((ts[5], ts[6], float(ts[4])))
  return pandas.DataFrame(sent_pairs, columns=["sent_1", "sent_2", "sim"])


def download_and_load_sts_data():
  sts_dataset = tf.keras.utils.get_file(
      fname="Stsbenchmark.tar.gz",
      origin="http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz",
      extract=True)

  sts_dev = load_sts_dataset(
      os.path.join(os.path.dirname(sts_dataset), "stsbenchmark", "sts-dev.csv"))
  sts_test = load_sts_dataset(
      os.path.join(
          os.path.dirname(sts_dataset), "stsbenchmark", "sts-test.csv"))

  return sts_dev, sts_test


sts_dev, sts_test = download_and_load_sts_data()
text_a = sts_dev['sent_1'].tolist()
text_b = sts_dev['sent_2'].tolist()
dev_scores = sts_dev['sim'].tolist()
sts_input1 = tf.compat.v1.placeholder(tf.string, shape=(None))
sts_input2 = tf.compat.v1.placeholder(tf.string, shape=(None))

sts_encode1 = tf.nn.l2_normalize(embed(sts_input1))
sts_encode2 = tf.nn.l2_normalize(embed(sts_input2))
sim_scores = tf.reduce_sum(tf.multiply(sts_encode1, sts_encode2), axis=1)
def run_sts_benchmark(session):
  emba, embb, scores = session.run(
      [sts_encode1, sts_encode2, sim_scores],
      feed_dict={
          sts_input1: text_a,
          sts_input2: text_b
      })
  return scores


with tf.compat.v1.Session() as session:
  session.run(tf.compat.v1.global_variables_initializer())
  session.run(tf.compat.v1.tables_initializer())
  scores = run_sts_benchmark(session)

pearson_correlation = scipy.stats.pearsonr(scores, dev_scores)
print('Pearson correlation coefficient = {0}\np-value = {1}'.format(
    pearson_correlation[0], pearson_correlation[1]))

f:id:dsf-kotaro:20210210205805p:plain



自然言語処理については、あと3回みたいです!
次回もお楽しみに♪

ではでは。。



Pythonスタートブック [増補改訂版]

Pythonスタートブック [増補改訂版]

  • 作者:辻 真吾
  • 発売日: 2018/04/12
  • メディア: 大型本