福岡人データサイエンティストの部屋

データサイエンスを極めるため、日々の学習を綴っています。

【医療画像AI】U-Netを使ってX線画像の領域分割やってみた【Python】#006



こんにちは!こーたろーです。


今回はU-Netというモデルを使った領域分割を行ってみます。


画像の中に何があるかを区別するために、画像の対象を区切って利用することがあります。
その時に必要となってくるのが、この領域分割です。





こちらの図のように、身体・肺・心臓・背景のように分割したい場合に、U-Netを使って学習させたモデルがどのぐらい領域分割できるか試してみます。


それでは早速始めていきます。


データの準備



データは毎度おなじみの日本放射線技術学会の画像部会が公開しているデータを使用します。
miniJSRT_database | 日本放射線技術学会 画像部会


リンクをスクロールして、「Segmentation01」をダウンロードして、展開後、ホームディレクトリに設置します。


今回利用しているデータは、分割後の画像はpng、現画像はbmpとなっていることに注意します。


なお、分割後の画像についてですが、画素数
   肺 :255
  心臓:85
  身体:50
  背景:0
となっています。


それでは、データを読み込んでいきたいと思います。

import os
import cv2
import numpy as np
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization
from keras.layers.merge import concatenate
from keras.models import Model
import matplotlib.pyplot as plt



まずはライブラリをインポートしておきます。


次に、今回の処理で指定する画像のサイズと学習のエポック数だけ定義しておきます。

IMAGE_SIZE = 64
EPOCHS = 10


データの読み込み



まずは、データの前処理を行っていきます。


続いて、大量のデータが複数のフォルダに分かれていることから、データを読み込むための関数を準備します。


関数では、ホームディレクトリにおいたファイルのパスと、変更後の画像サイズを指定していきます。

def load_images(inputpath, imagesize):
    imglist = []

    for root, dirs, files in os.walk(inputpath):
        files[:] = [file for file in files if not file.startswith(('__', '.'))]

        for fn in sorted(files):
            bn, ext = os.path.splitext(fn)                
            filename = os.path.join(root, fn)
            testimage = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
            height, width = testimage.shape[:2]
            testimage = cv2.resize(testimage, (imagesize, imagesize), interpolation = cv2.INTER_AREA)  
            testimage = np.asarray([testimage], dtype=np.float64)
            testimage = np.asarray(testimage, dtype=np.float64).reshape((1, imagesize, imagesize))
            testimage = testimage.transpose(1, 2, 0)

            imglist.append(testimage)
    imgsdata = np.asarray(imglist, dtype=np.float32)

    return imgsdata, sorted(files)



そしてデータを読み込みます。

image_train, image_train_filenames = load_images("./segmentation/org_train/", IMAGE_SIZE)
label_train, label_train_filenames = load_images("./segmentation/label_train/", IMAGE_SIZE)

image_test, image_test_filenames = load_images("./segmentation/org_test/", IMAGE_SIZE)
label_test, label_test_filenames = load_images("./segmentation/label_test/", IMAGE_SIZE)


データの前処理



続いて、データの前処理です。今回、領域は肺とそれ以外を区別するものを作ろうと考えています。
そのため、前処理段階で、画素数が255のものはそのまま残し、それ以外を0にする処理を加えます。
そして、最後に正規化を行います。

for organ in range(0, len(label_train)):
    for i in range(0, IMAGE_SIZE):
        for j in range(0, IMAGE_SIZE):
            if(label_train[organ][i][j][0] != 255):
                label_train[organ][i][j][0] = 0
            else:
                label_train[organ][i][j][0] = 255
                    
for organ in range(0, len(label_test)):
    for i in range(0, IMAGE_SIZE):
        for j in range(0, IMAGE_SIZE):
            if(label_test[organ][i][j][0] != 255):
                label_test[organ][i][j][0] = 0
            else:
                label_test[organ][i][j][0] = 255

image_train /= np.max(image_train)
label_train /= np.max(label_train)
image_test /= np.max(image_test)
label_test /= np.max(label_test)


U-Netのネットワークの定義



U-Netのライブラリ(パッケージ)は一応公開されているようでしたが、今回は手組みしてみました。


ちなみにいろいろな参考文献・書籍などをみながら、次のネットワークが一番わかりやすかったので参考にしてみました。

def network_unet():
    input_data = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 1))

    encode1 = Conv2D(64, kernel_size=3, strides=1, activation="relu", padding="same")(input_data)
    encode1 = BatchNormalization()(encode1)
    encode1 = Conv2D(64, kernel_size=3, strides=1, activation="relu", padding="same")(encode1)
    encode1 = BatchNormalization()(encode1)
    downsample1 = MaxPooling2D(pool_size=2, strides=2)(encode1)
    
    encode2 = Conv2D(128, kernel_size=3, strides=1, activation="relu", padding="same")(downsample1)
    encode2 = BatchNormalization()(encode2)
    encode2 = Conv2D(128, kernel_size=3, strides=1, activation="relu", padding="same")(encode2)
    encode2 = BatchNormalization()(encode2)
    downsampling2 = MaxPooling2D(pool_size=2, strides=2)(encode2)

    encode3 = Conv2D(256, kernel_size=3, strides=1, activation="relu", padding="same")(downsampling2)
    encode3 = BatchNormalization()(encode3)
    encode3 = Conv2D(256, kernel_size=3, strides=1, activation="relu", padding="same")(encode3)
    encode3 = BatchNormalization()(encode3)
    downsampling3 = MaxPooling2D(pool_size=2, strides=2)(encode3)
    
    encode4 = Conv2D(512, kernel_size=3, strides=1, activation="relu", padding="same")(downsampling3)
    encode4 = BatchNormalization()(encode4)
    encode4 = Conv2D(512, kernel_size=3, strides=1, activation="relu", padding="same")(encode4)
    encode4 = BatchNormalization()(encode4)
    downsampling4 = MaxPooling2D(pool_size=2, strides=2)(encode4)
    
    encode5 = Conv2D(1024, kernel_size=3, strides=1, activation="relu", padding="same")(downsampling4)
    encode5 = BatchNormalization()(encode5)
    encode5 = Conv2D(1024, kernel_size=3, strides=1, activation="relu", padding="same")(encode5)
    encode5 = BatchNormalization()(encode5)

    upsampling4 = UpSampling2D(size=2)(encode5)
    decode4 = concatenate([upsampling4, encode4], axis=-1)
    decode4 = Conv2D(512, kernel_size=3, strides=1, activation="relu", padding="same")(decode4)
    decode4 = BatchNormalization()(decode4)
    decode4 = Conv2D(512, kernel_size=3, strides=1, activation="relu", padding="same")(decode4)
    decode4 = BatchNormalization()(decode4)
    
    upsampling3 = UpSampling2D(size=2)(decode4)
    decode3 = concatenate([upsampling3, encode3], axis=-1)
    decode3 = Conv2D(256, kernel_size=3, strides=1, activation="relu", padding="same")(decode3)
    decode3 = BatchNormalization()(decode3)
    decode3 = Conv2D(256, kernel_size=3, strides=1, activation="relu", padding="same")(decode3)
    decode3 = BatchNormalization()(decode3)

    upsampling2 = UpSampling2D(size=2)(decode3)
    decode2 = concatenate([upsampling2, encode2], axis=-1)
    decode2 = Conv2D(128, kernel_size=3, strides=1, activation="relu", padding="same")(decode2)
    decode2 = BatchNormalization()(decode2)
    decode2 = Conv2D(128, kernel_size=3, strides=1, activation="relu", padding="same")(decode2)
    decode2 = BatchNormalization()(decode2)
    
    upsampling1 = UpSampling2D(size=2)(decode2)
    decode1 = concatenate([upsampling1, encode1], axis=-1)
    decode1 = Conv2D(64, kernel_size=3, strides=1, activation="relu", padding="same")(decode1)
    decode1 = BatchNormalization()(decode1)
    decode1 = Conv2D(64, kernel_size=3, strides=1, activation="relu", padding="same")(decode1)
    decode1 = BatchNormalization()(decode1)
    
    decode1 = Conv2D(1, kernel_size=1, strides=1, activation="sigmoid", padding="same")(decode1)
    
    model = Model(input=input_data, output=decode1)
    
    return model



このU-Netのモデル概要は、以下のサイトを参考にしてください。


U-Net: Convolutional Networks for Biomedical Image Segmentation





バイオメディカルのセグメンテーションのためのモデルのようです。

モデルの作成と学習・評価



ネットワークオブジェクトを定義して、学習を行っていきます。

model = network_unet()
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])

training = model.fit(image_train, label_train, epochs=EPOCHS, batch_size=12,
               shuffle=True, validation_data=(image_test, label_test), verbose=1)



今回は、0と1(255を正規化した値)の2値分類に相当しますので、'binary_crossentropy'を選択しています。


計算の状況を確認します。accurasyとlossをグラフ化してみました。

def plot_history(history):
    plt.plot(history.history['acc'])
    plt.plot(history.history['val_acc'])
    plt.title('model accuracy')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.legend(['acc', 'val_acc'], loc='lower right')
    plt.show()
    
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['loss', 'val_loss'], loc='lower right')
    plt.show()
    
plot_history(training)






少しオーバーフィッティング気味かなという印象を受けます。
何はともあれ、出力画像が分割できているか確認してみましょう。

results = model.predict(image_train, verbose=1)

n = 20
plt.figure(figsize=(40, 4))
for i in range(n):
   ax = plt.subplot(3, n, i+1)
   plt.imshow(image_test[i].reshape(IMAGE_SIZE, IMAGE_SIZE))
   plt.gray()
   ax.get_xaxis().set_visible(False)
   ax.get_yaxis().set_visible(False)
   
   ax = plt.subplot(3, n, i+1+n)
   plt.imshow(results[i].reshape(IMAGE_SIZE, IMAGE_SIZE))
   plt.gray()
   ax.get_xaxis().set_visible(False)
   ax.get_yaxis().set_visible(False)

   ax = plt.subplot(3, n, i+1+2*n)
   plt.imshow(label_test[i].reshape(IMAGE_SIZE, IMAGE_SIZE))
   plt.gray()
   ax.get_xaxis().set_visible(False)
   ax.get_yaxis().set_visible(False)
   
plt.show()



原画像と分割画像と正解のラベル画像を並べてみました。





これままだまだ改良の余地がありそうですね。


おそらく前処理のどこかに問題があるか、255×255ピクセルの画像を64×64に圧縮しているのが悪いかは不明です。


また、違う結果がでたら報告できればと思います。


ではでは。


Python機械学習プログラミング Pythonで学ぶコード進行解析講座 Tensorflowでディープラーニングに挑戦!: コード進行データの加工方法からコード進行自動生成AIの作成まで詳細解説 Keras, Tensorflowでコーディング GoogleColaboratolyを開発環境に使用