ころがる狸

ころがる狸のデータ解析ブログ

【半教師ありGAN】GANによるデータ拡張とMNISTの画像分類

こんにちは。2020年の上半期も終わりそうです。時間が経つのは本当に早いですね。個人的には、ブログをとにかく書きまくった半年でした。
それでは、本日の記事の紹介です!

はじめに

今日は、GANを使った半教師あり学習(半教師ありGAN)による画像分類に取り組んでみたいと思います。半教師あり学習というのは、データセットのうち少数にのみラベル付けがされており、残りのデータはラベル付けされておらず教師データが存在しないデータセットに対する学習のことを指します。データ数が少ないと機械学習を用いても一データセットの部の特徴しか捉えることしかできず、汎化性能の向上があまり見込めません。そこで、少ない教師ありの学習データから高い予測精度を実現するための1つの方策が、GANによるデータ拡張になります。

f:id:Dajiro:20200606085442p:plain
学習方法の3分類。

画像認識分野では、データ拡張というテクニックが一般的に用いられています。学習用データに回転・拡大・反転、またはコントラストを変えるといった処理を施すことで、1枚の画像から複数枚の画像を複製することができます。これにより画像の予測精度の向上が期待できます。

ただし、あくまでこれはオリジナル画像からの複製のため本当の意味で学習データの多様性が向上したとは言えません。可能であれば、様々な画像そのものを大量に認識させたいところです。そこで、敵対的生成ネットワーク(GAN)を使って本物そっくりの画像を生成すればいいじゃないか!という発想が生まれてきます。これにより識別器の分類精度向上を図ったのがGANによるデータ拡張であったり、半教師ありGANであったりします。GANの仕組みは過去の記事で解説していますのでこちらもご覧下さい。
dajiro.hatenablog.com

半教師ありGANの仕組み

ここでは概念的な説明を行います。半教師ありGANでは、以下の3つの学習データを用いて分類器を訓練します。

  1. ラベルありデータ
  2. ラベルなしデータ
  3. 人工的に生成されたフェイクデータ

分類器の訓練

これらを分類器に渡した後、入力されたデータの種類によって分類器には異なる学習がなされます。まずラベルありデータに対しては、それがどのラベルに分類されるものかという分類を行います。これは一般的な画像などの分類問題と同じなので、難しくはありませんね。

次にラベルなしデータフェイクデータですが、これらが入力されるとそのデータが真のデータか、GANによって生成された偽のデータかの2値分類が行われます。つまり分類器はラベル分類と2値分類の2つのタスクによって賢くなっていくというイメージです。

生成器の訓練

そして最後にフェイクデータを生成する生成器の訓練です。生成器としては分類器が誤って真と判定するような本物そっくりの画像を生成したいので、フェイクデータのラベルを「真」とすることで学習を行います。分類器から吐き出される結果が「真」なら損失関数は小さく(正解)、「偽」なら大きくなる(不正解)となるので真のデータと判定されるように生成器の学習が進みます。この辺りは基本的なGANの仕組みと同様です。

イメージ図はこんな感じです。こちらの説明と図を対応させながら眺めてみてください。

f:id:Dajiro:20200605003456p:plain
半教師ありGAN(SGAN)のイメージ図

実装(省略)

実際にMNISTで半教師ありGANを動かしてみましたが、実装はここでは省略します。以下の本に実装やGANの仕組みの詳細が解説されているので、ぜひ買ってみてください!

比較用に、SGANを使わない普通の教師あり学習も計算しました。そのコードだけ記載します。

supervised_losses = []
imgs_list = []
accuracy_lst = []
iteration_checkpoints = []

def train(iterations, batch_size, sample_interval):
    for iteration in range(iterations):
        #----------
        # 画像準備
        #----------
        #本物、ラベル付き
        imgs, labels = dataset.batch_labeled(batch_size)
        #ラベルのワンホット表現
        labels = to_categorical(labels, num_classes = num_classes)
        
        #----------
        # 訓練開始
        #----------
        #教師ありデータに対する分類器
        d_loss_supervised, accuracy = mnist_classifier.train_on_batch(imgs, labels)
        
        if (iteration + 1) % sample_interval == 0:
            supervised_losses.append(d_loss_supervised)
            accuracy_lst.append(accuracy)
            iteration_checkpoints.append(iteration + 1)
            
            print(
                "%d [D loss supervised: %.4f, acc.: %.2f%%]"
                 % (iteration + 1, d_loss_supervised, 100 * accuracy))

実験結果

MNISTを用いて実験を行った結果を見てみましょう。もちろんMNISTの画像データにはすべてラベルがついてるので、ラベルを与える枚数を変えながら学習を行い、テスト用データに対する分類精度を確認しました。

下の棒グラフではラベルあり画像の枚数を50, 100, 1000と変えたときのSGANと、比較用として教師ありデータのみを用いた分類器の予測精度をプロットしました。たった50枚の画像を使った学習でも予測精度88%を実現できています!GANを使わない普通の画像分類では精度66%なので、その差は歴然としています。また、100枚では90%, そして1000枚では98%となりそのほとんどを分類出来ていることが分かります。なお教師あり画像1000枚の場合はSGANを使わない場合でも94%の精度が出ているので、教師あり画像枚数が少ないほどSGANは威力を発揮すると言えそうですね。

f:id:Dajiro:20200606102628p:plain
教師あり画像の枚数を変えたときのSGANの予測精度。比較用に教師ありデータのみを用いた精度もプロット。

最後に、GANがどのような偽画像を生成しているのかを見てみましょう。最初はランダムな画像ですが、徐々に「MNISTしぐさ」を覚えていっているようです。また完全ではありませんが数字のような文字が生成されています。このような偽画像と本物の教師なし画像を見比べて、分類器は本物の画像の特徴を学習しているのだと思います。

f:id:Dajiro:20200606102113p:plain
SGANによって生成された各エポックでの偽画像。

少量データほど威力が高いということが分かったので、SGANをガンガン使っていきたいと思います!・・・と言いたいところですが、なかなか「ほとんどラベル付けされていない大量データを分類する」タスクに巡り合えないんですよね。診療画像が大量にある医療現場などでは、もしかすると親和性が高いかもしれませんね。