ころがる狸

ころがる狸のデータ解析ブログ

【GAN + PyTorch】仕組みの解説とMNISTで画像生成

こんにちは。今日は敵対的生成ネットワーク(Generative Adversarial Network, GAN)を取り上げます。GANというと、適当な乱数から本物そっくりの画像を生成する技術として既にご存じかもしれません。画像以外にも物理モデルの生成や、化合物の構造生成などに適用されておりここ5年ほど多方面で盛り上がっています。本記事ではGANの基本について解説し、実際にMNISTの画像生成までを行いたいと思います。

MNISTと言えばこのような数字を表す画像によるデータセットです。これに似た画像を生成するのがGANのお仕事です。

f:id:Dajiro:20200523154356p:plain
MNISTデータセットの一部

【目次】

GANの直感的理解

GANのイメージ図がこちらです。GANの最も基本的な特徴は、生成器・識別器と呼ばれる2つのニューラルネットワークから構成されているということでしょう。これら2つのネットワークが協業することで学習データに似た画像を生成することができます。これらの具体的な働きを見てみます。

生成器(Generator, G)

画像の生成器に対応します。固定長のランダムな要素を持つベクトルをニューラルネットワークの入力として、画像の2次元形状に成形して出力するのが生成器の主な仕事です。無意味なベクトルから意味のある画像を生成できるなんて、ビックリしませんか?比喩を用いるなら、この固定長ベクトルは画像の「」と見做すことができます。出力されたフェイク画像は識別器によって真偽判定がなされます。生成器としては、自身が生成した画像を真と認識してほしい訳ですから、訓練する際にはフェイク画像に正解ラベルを与えるのが生成器の訓練における重要なテクニックになります。

生成器の学習時には、以下の関数を最小化します。

\frac{1}{m}\sum_{i=1}^{m}{\rm log}(1-D(G(\textbf{z}^{i})))

mはバッチサイズ、D(・)は識別器の出力を表し、真の画像なら1、偽の画像なら0を返します。またG(\textbf{z}^{i})はランダムなベクトル\textbf{z}^{i}から生成される画像を表します。生成器は判別器を騙したいので、偽の画像に対して1を返返すように訓練されます。そのため以上の関数を最小化することで生成器が更新されます。

識別器(Discriminator, D)

画像の真偽判定を行う識別器に対応します。Dの入力は、Gによって出力されたフェイク画像と学習データに含まれる真の画像データです。Dではそれぞれが真の画像か、偽の画像かを見分けるように訓練されます。最終的な出力は、真を表す1、偽を表す0となります。

識別器の目的関数はこちらです。これを最大化することで識別器としての機能が生まれます。

\frac{1}{m}\sum_{i=1}^{m}({\rm log}D(\textbf{x}^{i})+{\rm log}(1-D(G(\textbf{z}^{i}))))

D(\textbf{x}^{i})は真の画像\textbf{x}^{i}が入力されたときのDの出力、D(G(\textbf{z}^{i}))は偽の画像G(\textbf{z}^{i})が入力されたときの出力です。それぞれDが1、0を返すことを期待されるので、上式を最大化することが判別機の目的です。

図解

おそらく識別器の方が理解しやすいのではないでしょうか。こちらは一般的な画像の分類問題の考え方ですね。一方、生成器の学習に関してはあまりイメージがつかないかもしれません。生成器の学習には識別器の出力も使うことになりますが、このとき識別器の重みパラメータは更新せず、あくまで生成器の重みだけを更新する点に注意してください。ここでは識別器は単に予測のみを行っています。以下の図が以上の要約です。

f:id:Dajiro:20200521210800p:plain
GANのイメージ図。生成器の方が理解するのが難しい。

GANの実装

それでは、さっそく実装してみましょう。Pytorchを用いてMNISTの画像生成を行いました。また計算はAWSのGPUインスタンスを用いて行いました(GPUのローカルな計算環境がないんです汗 AWSだとマシンによりますが1時間100円くらいでGPUが使えるので、まぁいいか)。実装には以下のDCGANのチュートリアルを大いに参考にしました。
DCGAN Tutorial — PyTorch Tutorials 1.5.1 documentation

まずは、必要なモジュールやライブラリのインポート

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch.optim as optim
import torchvision.utils as vutils
import torchvision

import numpy as np
import matplotlib.pyplot as plt

続いて、インプットの設定やデータのロードです。

#####input#####
b_size = 128
ngpu = 1
nz = 100
lr = 0.0001
num_epochs = 100
###############

#生成器でtanhを使っているため入力の範囲を[-1, 1]にする。
class Mydataset(Dataset):
    def __init__(self, mnist_data):
        self.mnist = mnist_data
    def __len__(self):
        return len(self.mnist)
    def __getitem__(self, idx):
        X = self.mnist[idx][0]
        X = (X * 2) - 1
        y = self.mnist[idx][1]
        
        return X, y

#MNISTのデータロード
mnist_data = datasets.MNIST('../data/MNIST', 
                            transform = transforms.Compose([
                                transforms.ToTensor()
                                ]),
                            download=True)
mnist_data_norm = Mydataset(mnist_data)
dataloader = torch.utils.data.DataLoader(mnist_data_norm,
                                          batch_size=b_size,
                                          shuffle=True)
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

続いて、生成器を定義しましょう。ポイントとしては、ベクトルを受け取り多層パーセプトロンを通した後に画像サイズに形を変えています。またバッチノーマライゼーションを使い各層の出力分布を調整します。

#generatorの定義
class Generator(nn.Module):
    def __init__(self, ngpu, input_size = 100, hid1_size = 256,
                 hid2_size = 512, hid3_size = 1024, batch_size = 4):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.b_size = batch_size
        self.fc1 = nn.Linear(input_size, hid1_size)
        self.fc2 = nn.Linear(hid1_size, hid2_size)
        self.fc3 = nn.Linear(hid2_size, hid3_size)
        self.fc4 = nn.Linear(hid3_size, 28 * 28 * 1)
    
        self.bn1 = nn.BatchNorm1d(hid1_size)
        self.bn2 = nn.BatchNorm1d(hid2_size)
        self.bn3 = nn.BatchNorm1d(hid3_size)
        
        self.LeakyReLU = nn.LeakyReLU(0.2)
    
    def forward(self, input):
        x = input.view(-1, 100)
        x = self.LeakyReLU(self.fc1(x))
        x = self.bn1(x)
        x = self.LeakyReLU(self.fc2(x))
        x = self.bn2(x)
        x = self.LeakyReLU(self.fc3(x))
        x = self.bn3(x)
        x = torch.tanh(self.fc4(x))
        x = x.view(-1, 28, 28)
        return x

判別器です。画像を受け取り、それをベクトルに変換した後にスカラー値を出力しています。

#discriminatorの定義
class Discriminator(nn.Module):
    def __init__(self, ngpu, hid1_size = 1024, hid2_size = 512,
                 hid3_size = 256, batch_size = 4):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.b_size = batch_size
        self.fc1 = nn.Linear(784, hid1_size)
        self.fc2 = nn.Linear(hid1_size, hid2_size)
        self.fc3 = nn.Linear(hid2_size, hid3_size)
        self.fc4 = nn.Linear(hid3_size, 1)
    
        self.LeakyReLU = nn.LeakyReLU(0.2)
    
    def forward(self, input):
        x = input.view(-1, 784)
        x = self.LeakyReLU(self.fc1(x))
        x = self.LeakyReLU(self.fc2(x))
        x = self.LeakyReLU(self.fc3(x))
        x = torch.sigmoid(self.fc4(x))
        return x

損失関数とoptimizerを定義しましょう。損失関数はバイナリクロスエントロピー、更新方法にはAdamを使っています。本来は学習率などのハイパーパラメータを細かく調整するところなのでしょうが、そのような最適化は行っていません。

#損失関数の定義
criterion = nn.BCELoss()

#テスト用に使うランダムなベクトル
fixed_noise = torch.randn(5, nz, device=device)

#教師データのラベル
real_label = 1
fake_label = 0

#更新手法
optimizerD = optim.Adam(netD.parameters(), lr=lr)
optimizerG = optim.Adam(netG.parameters(), lr=lr)

そして学習部分の定義です。

# トレーニングループ

# 進捗を記録するためのリスト
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 100 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(fake)

        iters += 1

結果

100エポック回したときの結果を見てみましょう。こちらは生成器と判別器の損失関数です。生成器の損失関数は減少していっているので、狙い通りと言えますね。判別器の損失はあまり下がっていませんが、更新するたびに判別の難易度は上がっているはずなので、むしろ上昇していると解釈できるでしょうか。

f:id:Dajiro:20200523133027p:plain
生成器と識別器の損失関数。

そしてようやく、実際に生成した画像です。0エポックでは当然ランダムノイズです。

f:id:Dajiro:20200523153632p:plain
学習前の生成画像
そして10エポック経過時の結果。もうすでに数字に近くなっていますね。意外と早い。
f:id:Dajiro:20200523153654p:plain
10エポック段階の生成画像
こちらが50エポック段階。だいぶノイズが低減してきましたね。
f:id:Dajiro:20200523153737p:plain
50エポック段階の生成画像
そして最後に、学習後の生成画像。左は真の画像ですが、右の生成器の画像とほとんど見分けがつきません。多層パーセプトロンモデルでも、こんなにもMNISTらしい画像が生成できるのですね。MNISTは字の崩れにも寛容なので人の顔画像生成などでは難易度は上がるのでしょうが、少なくとも今回のタスクでは満足いく結果と言えると思います。
f:id:Dajiro:20200524145659p:plain
学習後の生成画像(左:真の画像、右:生成器の画像)

ここまで読んで下さってありがとうございます。今回は多層パーセプトロンモデルでやりましたが、CNN(畳み込みニューラルネットワーク)を使ったDCGANの結果がこちらにあるので、結果を見比べてみて下さい!
dajiro.hatenablog.com