Universal Encoder使ってみた!(自然言語処理)【図解速習DeepLearning】#014
こんにちは!こーたろーです。
本日は、また課題テキストの【図解速習DEEP LEARNING】に戻って、課題を進めていきます!
今回は、TF-Hubが提供しているUniversal Sentence Encoderという学習済みモデルを使って、文章の類似度を判定していきます!
1.必要なライブラリーのインポート
eager_execution は disableにしておきましょう。 こちらはTensorflowのバージョンアップに伴う対応です。
import tensorflow as tf import tensorflow_hub as hub import matplotlib.pyplot as plt import numpy as np import os import pandas as pd import re import seaborn as sns tf.compat.v1.disable_eager_execution()
2.モデルの取得
今回は、TF-Hubから学習済みのモデルを取得して活用するため、モデルのロードが必要となります。
embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/1")
3.単語・文章・段落を分散表現に変更する
word = "Elephant" sentence = "I am a sentence for which I would like to get its embedding." paragraph = ( "Universal Sentence Encoder embeddings also support short paragraphs. " "There is no hard limit on how long the paragraph is. Roughly, the longer " "the more 'diluted' the embedding will be.") messages = [word, sentence, paragraph] tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) with tf.compat.v1.Session() as session: session.run([tf.compat.v1.global_variables_initializer(), tf.compat.v1.tables_initializer()]) message_embeddings = session.run(embed(messages)) for i, message_embedding in enumerate(np.array(message_embeddings).tolist()): print("Message: {}".format(messages[i])) print("Embedding size: {}".format(len(message_embedding))) message_embedding_snippet = ", ".join( (str(x) for x in message_embedding[:3])) print("Embedding: [{}, ...]\n".format(message_embedding_snippet))
4.文意の類似度を判定し、可視化
def plot_similarity(labels, features, rotation): corr = np.inner(features, features) sns.set(font_scale=1.2) g = sns.heatmap( corr, xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, cmap="YlOrRd") g.set_xticklabels(labels, rotation=rotation) g.set_title("Semantic Textual Similarity") def run_and_plot(session_, input_tensor_, messages_, encoding_tensor): message_embeddings_ = session_.run( encoding_tensor, feed_dict={input_tensor_: messages}) plot_similarity(messages_, message_embeddings_, 90)
入力の「massage」を定義して、確認していきます。
messages = [ "I like my phone", "My phone is not good.", "Your cellphone looks great.", "Will it snow tomorrow?", "Recently a lot of hurricanes have hit the US", "Global warming is real", "An apple a day, keeps the doctors away", "Eating strawberries is healthy", "Is paleo better than keto?", "How old are you?", "what is your age?", ] similarity_input_placeholder = tf.compat.v1.placeholder(tf.string, shape=(None)) similarity_message_encodings = embed(similarity_input_placeholder) with tf.compat.v1.Session() as session: session.run(tf.compat.v1.global_variables_initializer()) session.run(tf.compat.v1.tables_initializer()) run_and_plot(session, similarity_input_placeholder, messages, similarity_message_encodings)
ヒートマップで類似度を表示させています。
3つずつで類似した文章になっているのが色で確認んできますね。
STSベンチマークというものを使うと、分の埋め込みによって計算された類似度スコアが、どの程度人間の判断と一致するかが評価できるそうです。
import pandas import scipy import math def load_sts_dataset(filename): sent_pairs = [] with tf.io.gfile.GFile(filename, "r") as f: for line in f: ts = line.strip().split("\t") sent_pairs.append((ts[5], ts[6], float(ts[4]))) return pandas.DataFrame(sent_pairs, columns=["sent_1", "sent_2", "sim"]) def download_and_load_sts_data(): sts_dataset = tf.keras.utils.get_file( fname="Stsbenchmark.tar.gz", origin="http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz", extract=True) sts_dev = load_sts_dataset( os.path.join(os.path.dirname(sts_dataset), "stsbenchmark", "sts-dev.csv")) sts_test = load_sts_dataset( os.path.join( os.path.dirname(sts_dataset), "stsbenchmark", "sts-test.csv")) return sts_dev, sts_test sts_dev, sts_test = download_and_load_sts_data()
text_a = sts_dev['sent_1'].tolist() text_b = sts_dev['sent_2'].tolist() dev_scores = sts_dev['sim'].tolist() sts_input1 = tf.compat.v1.placeholder(tf.string, shape=(None)) sts_input2 = tf.compat.v1.placeholder(tf.string, shape=(None)) sts_encode1 = tf.nn.l2_normalize(embed(sts_input1)) sts_encode2 = tf.nn.l2_normalize(embed(sts_input2)) sim_scores = tf.reduce_sum(tf.multiply(sts_encode1, sts_encode2), axis=1)
def run_sts_benchmark(session): emba, embb, scores = session.run( [sts_encode1, sts_encode2, sim_scores], feed_dict={ sts_input1: text_a, sts_input2: text_b }) return scores with tf.compat.v1.Session() as session: session.run(tf.compat.v1.global_variables_initializer()) session.run(tf.compat.v1.tables_initializer()) scores = run_sts_benchmark(session) pearson_correlation = scipy.stats.pearsonr(scores, dev_scores) print('Pearson correlation coefficient = {0}\np-value = {1}'.format( pearson_correlation[0], pearson_correlation[1]))
自然言語処理については、あと3回みたいです!
次回もお楽しみに♪
ではでは。。
- 作者:辻 真吾
- 発売日: 2018/04/12
- メディア: 大型本