こんにちは。今日は敵対的生成ネットワーク(Generative Adversarial Network, GAN)を取り上げます。GANというと、適当な乱数から本物そっくりの画像を生成する技術として既にご存じかもしれません。画像以外にも物理モデルの生成や、化合物の構造生成などに適用されておりここ5年ほど多方面で盛り上がっています。本記事ではGANの基本について解説し、実際にMNISTの画像生成までを行いたいと思います。
MNISTと言えばこのような数字を表す画像によるデータセットです。これに似た画像を生成するのがGANのお仕事です。
【目次】
GANの直感的理解
GANのイメージ図がこちらです。GANの最も基本的な特徴は、生成器・識別器と呼ばれる2つのニューラルネットワークから構成されているということでしょう。これら2つのネットワークが協業することで学習データに似た画像を生成することができます。これらの具体的な働きを見てみます。
生成器(Generator, G)
画像の生成器に対応します。固定長のランダムな要素を持つベクトルをニューラルネットワークの入力として、画像の2次元形状に成形して出力するのが生成器の主な仕事です。無意味なベクトルから意味のある画像を生成できるなんて、ビックリしませんか?比喩を用いるなら、この固定長ベクトルは画像の「種」と見做すことができます。出力されたフェイク画像は識別器によって真偽判定がなされます。生成器としては、自身が生成した画像を真と認識してほしい訳ですから、訓練する際にはフェイク画像に正解ラベルを与えるのが生成器の訓練における重要なテクニックになります。
生成器の学習時には、以下の関数を最小化します。
はバッチサイズ、
は識別器の出力を表し、真の画像なら1、偽の画像なら0を返します。また
はランダムなベクトル
から生成される画像を表します。生成器は判別器を騙したいので、偽の画像に対して1を返返すように訓練されます。そのため以上の関数を最小化することで生成器が更新されます。
識別器(Discriminator, D)
画像の真偽判定を行う識別器に対応します。Dの入力は、Gによって出力されたフェイク画像と学習データに含まれる真の画像データです。Dではそれぞれが真の画像か、偽の画像かを見分けるように訓練されます。最終的な出力は、真を表す1、偽を表す0となります。
識別器の目的関数はこちらです。これを最大化することで識別器としての機能が生まれます。
は真の画像
が入力されたときの
の出力、
は偽の画像
が入力されたときの出力です。それぞれ
が1、0を返すことを期待されるので、上式を最大化することが判別機の目的です。
図解
おそらく識別器の方が理解しやすいのではないでしょうか。こちらは一般的な画像の分類問題の考え方ですね。一方、生成器の学習に関してはあまりイメージがつかないかもしれません。生成器の学習には識別器の出力も使うことになりますが、このとき識別器の重みパラメータは更新せず、あくまで生成器の重みだけを更新する点に注意してください。ここでは識別器は単に予測のみを行っています。以下の図が以上の要約です。
GANの実装
それでは、さっそく実装してみましょう。Pytorchを用いてMNISTの画像生成を行いました。また計算はAWSのGPUインスタンスを用いて行いました(GPUのローカルな計算環境がないんです汗 AWSだとマシンによりますが1時間100円くらいでGPUが使えるので、まぁいいか)。実装には以下のDCGANのチュートリアルを大いに参考にしました。
DCGAN Tutorial — PyTorch Tutorials 1.5.1 documentation
まずは、必要なモジュールやライブラリのインポート
import torch import torch.nn as nn from torch.utils.data import Dataset from torchvision import datasets, transforms import torch.nn.functional as F import torch.optim as optim import torchvision.utils as vutils import torchvision import numpy as np import matplotlib.pyplot as plt
続いて、インプットの設定やデータのロードです。
#####input##### b_size = 128 ngpu = 1 nz = 100 lr = 0.0001 num_epochs = 100 ############### #生成器でtanhを使っているため入力の範囲を[-1, 1]にする。 class Mydataset(Dataset): def __init__(self, mnist_data): self.mnist = mnist_data def __len__(self): return len(self.mnist) def __getitem__(self, idx): X = self.mnist[idx][0] X = (X * 2) - 1 y = self.mnist[idx][1] return X, y #MNISTのデータロード mnist_data = datasets.MNIST('../data/MNIST', transform = transforms.Compose([ transforms.ToTensor() ]), download=True) mnist_data_norm = Mydataset(mnist_data) dataloader = torch.utils.data.DataLoader(mnist_data_norm, batch_size=b_size, shuffle=True) device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
続いて、生成器を定義しましょう。ポイントとしては、ベクトルを受け取り多層パーセプトロンを通した後に画像サイズに形を変えています。またバッチノーマライゼーションを使い各層の出力分布を調整します。
#generatorの定義 class Generator(nn.Module): def __init__(self, ngpu, input_size = 100, hid1_size = 256, hid2_size = 512, hid3_size = 1024, batch_size = 4): super(Generator, self).__init__() self.ngpu = ngpu self.b_size = batch_size self.fc1 = nn.Linear(input_size, hid1_size) self.fc2 = nn.Linear(hid1_size, hid2_size) self.fc3 = nn.Linear(hid2_size, hid3_size) self.fc4 = nn.Linear(hid3_size, 28 * 28 * 1) self.bn1 = nn.BatchNorm1d(hid1_size) self.bn2 = nn.BatchNorm1d(hid2_size) self.bn3 = nn.BatchNorm1d(hid3_size) self.LeakyReLU = nn.LeakyReLU(0.2) def forward(self, input): x = input.view(-1, 100) x = self.LeakyReLU(self.fc1(x)) x = self.bn1(x) x = self.LeakyReLU(self.fc2(x)) x = self.bn2(x) x = self.LeakyReLU(self.fc3(x)) x = self.bn3(x) x = torch.tanh(self.fc4(x)) x = x.view(-1, 28, 28) return x
判別器です。画像を受け取り、それをベクトルに変換した後にスカラー値を出力しています。
#discriminatorの定義 class Discriminator(nn.Module): def __init__(self, ngpu, hid1_size = 1024, hid2_size = 512, hid3_size = 256, batch_size = 4): super(Discriminator, self).__init__() self.ngpu = ngpu self.b_size = batch_size self.fc1 = nn.Linear(784, hid1_size) self.fc2 = nn.Linear(hid1_size, hid2_size) self.fc3 = nn.Linear(hid2_size, hid3_size) self.fc4 = nn.Linear(hid3_size, 1) self.LeakyReLU = nn.LeakyReLU(0.2) def forward(self, input): x = input.view(-1, 784) x = self.LeakyReLU(self.fc1(x)) x = self.LeakyReLU(self.fc2(x)) x = self.LeakyReLU(self.fc3(x)) x = torch.sigmoid(self.fc4(x)) return x
損失関数とoptimizerを定義しましょう。損失関数はバイナリクロスエントロピー、更新方法にはAdamを使っています。本来は学習率などのハイパーパラメータを細かく調整するところなのでしょうが、そのような最適化は行っていません。
#損失関数の定義 criterion = nn.BCELoss() #テスト用に使うランダムなベクトル fixed_noise = torch.randn(5, nz, device=device) #教師データのラベル real_label = 1 fake_label = 0 #更新手法 optimizerD = optim.Adam(netD.parameters(), lr=lr) optimizerG = optim.Adam(netG.parameters(), lr=lr)
そして学習部分の定義です。
# トレーニングループ # 進捗を記録するためのリスト img_list = [] G_losses = [] D_losses = [] iters = 0 print("Starting Training Loop...") # For each epoch for epoch in range(num_epochs): # For each batch in the dataloader for i, data in enumerate(dataloader, 0): ############################ # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) ########################### ## Train with all-real batch netD.zero_grad() # Format batch real_cpu = data[0].to(device) b_size = real_cpu.size(0) label = torch.full((b_size,), real_label, device=device) # Forward pass real batch through D output = netD(real_cpu).view(-1) # Calculate loss on all-real batch errD_real = criterion(output, label) # Calculate gradients for D in backward pass errD_real.backward() D_x = output.mean().item() ## Train with all-fake batch # Generate batch of latent vectors noise = torch.randn(b_size, nz, device=device) # Generate fake image batch with G fake = netG(noise) label.fill_(fake_label) # Classify all fake batch with D output = netD(fake.detach()).view(-1) # Calculate D's loss on the all-fake batch errD_fake = criterion(output, label) # Calculate the gradients for this batch errD_fake.backward() D_G_z1 = output.mean().item() # Add the gradients from the all-real and all-fake batches errD = errD_real + errD_fake # Update D optimizerD.step() ############################ # (2) Update G network: maximize log(D(G(z))) ########################### netG.zero_grad() label.fill_(real_label) # fake labels are real for generator cost # Since we just updated D, perform another forward pass of all-fake batch through D output = netD(fake).view(-1) # Calculate G's loss based on this output errG = criterion(output, label) # Calculate gradients for G errG.backward() D_G_z2 = output.mean().item() # Update G optimizerG.step() # Output training stats if i % 100 == 0: print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) # Save Losses for plotting later G_losses.append(errG.item()) D_losses.append(errD.item()) # Check how the generator is doing by saving G's output on fixed_noise if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)): with torch.no_grad(): fake = netG(fixed_noise).detach().cpu() img_list.append(fake) iters += 1
結果
100エポック回したときの結果を見てみましょう。こちらは生成器と判別器の損失関数です。生成器の損失関数は減少していっているので、狙い通りと言えますね。判別器の損失はあまり下がっていませんが、更新するたびに判別の難易度は上がっているはずなので、むしろ上昇していると解釈できるでしょうか。
そしてようやく、実際に生成した画像です。0エポックでは当然ランダムノイズです。
ここまで読んで下さってありがとうございます。今回は多層パーセプトロンモデルでやりましたが、CNN(畳み込みニューラルネットワーク)を使ったDCGANの結果がこちらにあるので、結果を見比べてみて下さい!
dajiro.hatenablog.com