ころがる狸

ころがる狸のデータ解析ブログ

【DCGAN vs GAN】MNISTの生成画像比較と実装のコツ

こんにちは。先日、多層パーセプトロンモデルを使ったGAN(敵対的生成ネットワーク)で画像生成を行いました。しかし機械学習で画像と言えば、畳み込みニューラルネットワーク(CNN)ですよね。とうわけで生成器・識別器にCNNと転置CNNを使って効率的に画像を学習できるようにしたDCGAN(Deep Convolutional GAN)による画像生成を検証してみました。これにはPyTorch公式でセレブの顔を生成する分かりやすいチュートリアルがありますが、本記事ではMNISTの数字画像生成という比較的簡単なタスクをやってみました。

【参考】DCGANのPyTorchのチュートリアルです。本記事の実装の9割はこちらに依っています。
DCGAN Tutorial — PyTorch Tutorials 1.5.1 documentation
【参考】GAN入門にはこちらの記事をご覧ください!
dajiro.hatenablog.com

DCGANの概念図

GANの基本的な概念については上の記事をご覧ください。DCGANの学習の流れは一般的なGANとまったく変わりありませんが、生成器・識別器に畳み込み・転置畳み込みニューラルネットワークを用いているのが最大の特徴です。前回は多層パーセプトロンで画像を生成しましたが、それでは画像1ピクセル周りの特徴などを効率的に取り込むことができません。それを可能にしたのが(転置)畳み込みニューラルネットワークです。

識別器~畳み込みニューラルネットワーク

画像の真偽を判定する識別器には、畳み込みニューラルネットワークが使われます。これは画像を入力として、フィルタ(カーネルとも呼ばれる)によって画像の細かい特徴を自動抽出します。CNNのイメージ図は以下の通りです。

f:id:Dajiro:20200524104554p:plain
CNNのイメージ図
CNNの各段階では、画像の特徴の取り込みが行われます。畳み込み処理に関してはこちらのリンクで非常に分かりやすいため是非ご覧ください。そこで紹介されている畳み込み処理を表すgifがこちらです。ここでは説明を省きますが、元画像のサイズ、カーネルのサイズ、画像の外枠の埋め込み(パディング)、カーネルの移動幅(ストライド)というパラメータによって畳み込み処理は特徴づけられます。
GitHub - vdumoulin/conv_arithmetic: A technical report on convolution arithmetic in the context of deep learning

f:id:Dajiro:20200524004633g:plain
畳み込みのgif画像。paddingあり、strideあり。
畳み込み処理を行うと、画像のサイズが圧縮されます。圧縮後の画像サイズは以下の式で表されます。

H_{out}=\frac{H_{in}+2P-FH}{S}+1
W_{out}=\frac{W_{in}+2P-FW}{S}+1

ここで、H, Wは高さ・幅を表すパラメータです。またH_{out}, W_{out}は出力画像サイズ、H_{in}, W_{in}は入力画像サイズ、FH, FWはカーネルサイズ、S, Pはそれぞれストライドとパディングのパラメータです。実際にCNNを組む際には、設定したパラメータでどのようなサイズの画像が返ってくるのか計算しながら設計しましょう

生成器~転置畳み込みニューラルネットワーク

画像を生成するためには、転置畳み込みニューラルネットワークが使われます。これはベクトルを入力として、それを拡大するようにして画像を生成する仕組みのことを指します。

f:id:Dajiro:20200524110932p:plain
転置畳み込みニューラルネットワークによる画像生成

転置畳み込みの様子をもう少し詳しく見てみましょう。元の小さな画像の外枠・間に空のピクセルを入れて、そこにカーネル処理を施すことによって画像を拡大します。パディング・ストライド・カーネルサイズといったパラメータはここでも登場します。

f:id:Dajiro:20200524005023g:plain
転置畳み込みのgif画像。paddingあり、strideあり。
そして出力画像のサイズを計算するための公式がこちらです。今回のタスクでは28×28の出力を得たいので、それを目的としてパラメータを設定します
H_{out}=(H_{in}-1)S-2P+FH
W_{out}=(W_{in}-1)S-2P+FW
ちなみにこれは、畳み込み処理における画像サイズを求める公式をH_{in}, W_{in}について解いたものと一致しています。また、他にdilationやout_paddingというパラメータがありそれらを組み込んだ完全な公式が存在しますが、私自身が理解できてないので(汗)、ここではいじらないことにしました。以下のリンクが参考になるかもしれません。
GitHub - vdumoulin/conv_arithmetic: A technical report on convolution arithmetic in the context of deep learning
ConvTranspose2d — PyTorch master documentation

PyTorch実装

PyTorchのDCGANチュートリアルの実装にのっとっているため、そちらをご覧ください。注意点としては、GANを構成する生成器と更新器の画像サイズの設計です。例えば画像を逆畳み込みで拡大する場合、元画像サイズ6, カーネルサイズ2, ストライド2, パディング1の場合出力される画像は10×10の画像です。このようにパラメータと画像サイズの関係に注意しながらGANを構成します。識別器と生成器以外はチュートリアルの実装とほぼ同じです。

生成器

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 3 x 3
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 6 x 6
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 2, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 10 x 10
            nn.ConvTranspose2d( ngf * 2, ngf, 2, 2, 2, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 16 x 16
            nn.ConvTranspose2d( ngf, nc, 2, 2, 2, bias=False),
            nn.Tanh()
            # state size. (nc) x 28 x 28
        )

    def forward(self, input):
        return self.main(input)

識別器

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 28 x 28
            nn.Conv2d(nc, ndf, 2, 2, 2, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 16 x 16
            nn.Conv2d(ndf, ndf * 2, 2, 2, 2, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 10 x 10
            nn.Conv2d(ndf * 2, ndf * 4, 2, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 6 x 6
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 3 x 3
            nn.Conv2d(ndf * 8, 1, 3, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

結果

ようやく結果です。損失関数の振る舞いを見ると、生成器の損失は下がるよりむしろ上がっていますが、学習できているのか不安になりますね。識別器の方は損失が減少しています。

f:id:Dajiro:20200524010602p:plain
生成器(G)と判別器(D)の損失関数

そして実際に生成された画像がこちら。これは1エポックの学習を行って得られた画像ですが、ニョロニョロとした画像が生成されておりまだ数字には成長していません。なんだが、植物の萌芽のようで神秘的な印象を受けました

f:id:Dajiro:20200524112424p:plain
最初の学習時の画像(左:真の画像、右:生成器の画像)
続いて、5エポックの学習完了時の画像です。既に数字のような画像が生成されていますね。人の顔画像生成に比べると簡単なタスクなので、わずかな学習ステップでもそれらしい結果が得られているのかもしれません。
f:id:Dajiro:20200524112706p:plain
5エポック時の生成画像(左:真の画像、右:生成器の画像)
そして、学習終了段階の画像がこちら。これはほとんど見分けがつかないレベルですね。よく見ると数字とは言えない画像が混じっていますが、数字の崩れた感じというか、MNISTらしさはしっかり学習できているように見えます。
f:id:Dajiro:20200524112857p:plain
学習終了時の生成画像(左:真の画像、右:生成器の画像)

それでは最後に、多層パーセプトロン(MLP)で得られた結果と見比べてみましょう。MLPのGANの方がややノイズが乗っているようには見え、数字が太い印象があります。DCGANの方がよりはっきりとした画像な気がしますが、当初予想していたよりもはっきりとした違いが見えません。もっと難しい画像の生成で比較すると、恐らくDCGANの方が良い精度になると思われますが、MNISTではあまり差がないことが分かりました。

f:id:Dajiro:20200524150822p:plain
GANとDCGANの生成画像の比較(左:MLPによるGAN, 右:DCGAN)