ころがる狸

ころがる狸のデータ解析ブログ

【pytorch-pfn-extras+Ignite】画像分類のワークフロー解説

こんにちは、dajiroです。今日はPyTorchによる画像分類(CNN)に取り組んでみたいと思います。CNNの仕組み・実装方法に関してはウェブ上に十分な資料があると思うので、ここではPyTorchの学習部分を簡単に実装できるIgnitepytorch-pfn-extrasいうライブラリに焦点を当てて画像分類のワークフローをご紹介します。

本記事のまとめ
pytorch-pfn-extrasとIgniteを使ったEarlyStoppingが気持ちいい。設定すべき項目は意外とある。
こんな感じの出力が得られます。

epoch       iteration   train/loss  lr          model/fc2.bias/grad/min  val/loss    val/acc   
1           9           0.109115    0.001       -0.0068558               3.6902      0.4         
2           18          0.0799794   0.001       -0.00446897              3.31913     0.7         
3           27          0.210086    0.001       -0.00804866              8.57938     0.2         
4           36          0.770095    0.001       -0.0150844               4.7263      0.2         
5           45          0.174224    0.001       -0.0103558               2.99467     0.3         
6           54          0.0947392   0.001       -0.00595115              3.58225     0.2         
7           63          0.0278888   0.001       -0.00177731              4.41173     0.2         
8           72          0.00690222  0.001       -0.000468847             5.11954     0.2         
9           81          0.00298255  0.001       -0.000227255             5.54856     0.3         
10          90          0.00148836  0.001       -0.000129092             5.81546     0.3         
State:
	iteration: 90
	epoch: 10
	epoch_length: 9
	max_epochs: 100
	output: 0.00029667618218809366
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: 12

参考資料
Igniteの公式ドキュメント
pytorch.org
pytorch-pfn-extrasのGithub。日本を代表するテック企業Preffered Networksさんが開発したライブラリです。機械学習フレームワークChainerの機能をPyTorchに移植するために開発されたようです。公開されたばかりなのでドキュメントはそこまで充実していません。コードを読みながら使い方を覚えましょう。
github.com

なぜIgniteか、extrasか?

実装が簡単になる

Igniteとpytorch-pfn-extras(以下、extrasと省略)を使う理由をコードを見ながら説明します。PyTorchによる学習を行う際には、十中八九以下のようなエポックとミニバッチに対するループが走ります。その内部で勾配の初期化や微分の計算、重みの更新が行われています。

#PyTorchでの学習の一般的な実装
#エポックのループ
for epoch in range(epochs):
    #ミニバッチを取り出すループ
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
    #validation等の実装(省略)
    #for ~ 

このような共通処理をモジュール化しEarlyStoppingなどの追加機能を充実させたのがIgniteです。上記のコードと同じことが2行でできます。もちろん学習の進捗等が見れないのでこのまま使うことは無いのですが。

#Igniteによる最小限の実装
trainer = create_supervised_trainer(model, optimizer, criterion)
trainer.run(train_loader, max_epochs=epochs)

しかし、Igniteで学習のレポートを出力しようとするとイベントハンドラーと呼ばれる関数を自前で準備する必要がありやや面倒です。出力周りも何らかのライブラリを使ってスマートにしたいという願望がありますが、これを可能にしたのがextrasです。以下のようなコードをシャシャっと書くだけで出力が超絶スマートになります。また学習曲線も自動でプロットできるので結果を簡単に見たいときに非常に便利です。

#コアとなるpytorch-pfn-extraの実装部分
my_extensions = [
    extensions.LogReport(),
    extensions.ProgressBar(),
    extensions.observe_lr(optimizer=optimizer),
    extensions.ParameterStatistics(model, prefix='model'),
    extensions.VariableStatisticsPlot(model),
    extensions.IgniteEvaluator(
        evaluator, val_loader, model, progress_bar=False),
    extensions.PlotReport(['train/loss'], 'epoch', filename='loss.png'),
    extensions.PrintReport([
        'epoch', 'iteration', 'train/loss', 'lr',
        'model/fc2.bias/grad/min', 'val/loss', 'val/acc',
    ]),
]

models = {'main': model}
optimizers = {'main': optimizer}
manager = ppe.training.IgniteExtensionsManager(
    trainer, models, optimizers, epochs,
    extensions=my_extensions)

@trainer.on(Events.ITERATION_COMPLETED)
def report_loss(engine):
    ppe.reporting.report({'train/loss': engine.state.output
                          })

このようなコードを見ると、書くこと多いなぁ、と思われるかもしれませんが一度覚えればあとはコピペして使い回したり、微調整するだけです。

extrasとIgniteは親和性が高い

trainerモジュールによって実装を容易にするライブラリは他にもありますが(Catalyst, Lightning)、extrasはIgniteとの併用を想定した機能が搭載されており親和性が非常に高いです。私はもともと業務でChainerを使っていたので、その機能を色濃く反映しているextrasの機能をふんだんに活用するためにIgniteを必然的に選びました。

画像分類のワークフロー

さて、さっそく画像分類タスクにおける実装のワークフローを見ていきましょう。コードを全部載せたので長いです。必要な部分を都度お読みください。あと、ここは大事なポイントですがGPU使ってませんのでご了承下さい(持ってないんです・・・)。

0. パッケージのインストール

はじめにPyTorch, Ignite, extrasがインストールできていることを確認しましょう。PyTorchのインストール方法は公式ドキュメントをご参照下さい。extrasとIgniteはpipで入れます。

pip install pytorch-pfn-extras
pip install pytorch-ignite torchvision

:pipでIgniteをインストールすると、Igniteのバージョン0.3.0が入るはずです(2020/6/9現在)。これなら問題ありませんが、condaを使うと最新版の0.4rc0.post1がインストールされ、これとextrasを併用するとInvalid version formatというエラーが出てextrasをうまく使うことができませんでした。バージョン確認はしておきましょう。

1. 問題設定

本ブログでは大体MNISTを使ってきましたが、画像データの読み込みに使うImageFolderの使い方を理解するため、今回は自分で撮影した写真を学習対象にしました。とはいえ枚数が少ないので、あくまで学習の流れを追うための練習用のタスクです。7分類、計70枚の超小規模なデータを使ったラベル予測を行いました。(これでも訓練データに対してはラベルを100%近く予測できるようになります)。

f:id:Dajiro:20200607155343p:plain
今回用いる画像データセット。サイズはまちまち。

2. ライブラリの読み込み

ライブラリ一式を読み込みます。PyTorchとIgnite関係です。

#PyTorch関係
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils as utils
import torchvision.utils as vutils
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.optim as optim

#Ignite関係(今回はEarlyStoppingも使ってみる)
from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import EarlyStopping

#extra関係
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions

3. 入力パラメータの読み込み

パラメータ類を一式読み込んじゃいましょう。

###############
# 入力パラメータ
###############
#画像データのパス
data_path = "/home/user/image_data"
#バッチサイズ
batch_size = 4
#エポック数
epochs = 100
#画像サイズ
image_size = 64
#ワーカー数
workers = 0
#訓練用データ比率
train_ratio = 0.7
#検証用データ比率
val_ratio = 0.2
#学習率
lr = 0.001
#チャネル数
nc = 3
ngpu = 0
#使うデバイスの設定も行っておきます
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

4. データセットの読み込みとデータローダーの定義

データセットの読み込みにはImageFolderを使いますが、パスの設定や画像の配置には注意が必要です。設定したパス(/home/user/image_dataなど)の下に/ラベル名/*jpgという形で画像を配置する必要があります。例えば今回のケースではimage_data配下にcathedral, mountain, flower, sea, lake, fool, shrineの7つのディレクトリを作成しその下に画像をコピーしました。取り出したdatasetオブジェクトには(画像, ラベル(0 ~ 6))のタプル形式で学習用データが入っています。これをDataLoaderに渡すとミニバッチごとのまとまりとして取り出せます。

def get_data_loaders(batch_size, train_ratio, val_ratio):
 #データセット取り出し。画像のリサイズ、標準化、テンソル形式への変換等はtransformsで実行。
 #データ拡張したい場合もtransformsを利用する。
 dataset = dset.ImageFolder(root=data_path,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.CenterCrop(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))

    #訓練・検証・テスト用データに分割
 num_data = len(dataset)
    num_list = [int(train_ratio * num_data),
                int(val_ratio * num_data),
                num_data -
                int(train_ratio * num_data) -
                int(val_ratio * num_data)]
    train_data, val_data, test_data = utils.data.random_split(dataset,
                                                              [num_list[0],
                                                               num_list[1],
                                                               num_list[2]])
 
 #ミニバッチごとにデータを取り出せるDataLoaderの定義
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                             shuffle=True, num_workers=workers,
                                             collate_fn = None)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size,
                                             shuffle=False, num_workers=workers,
                                             collate_fn = None)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                             shuffle=False, num_workers=workers,
                                             collate_fn = None)
    
    return train_loader, val_loader, test_loader

5. ネットワークの定義

今回は小規模なデータセットなので、VGGのような深いネットワークを組むと逆に学習が進みません。薄いネットワークを使います。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 13 * 13, 512)
        self.fc2 = nn.Linear(512,  512)
        self.fc3 = nn.Linear(512, 7)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 13 * 13)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

今回は全データを64×64にリサイズして学習させました。これを畳み込み層に通すとカーネル、ストライド、パディング等のパラメータに依存して出力画像サイズが決まります。ネットワークを組む際には出力画像のサイズを都度計算して確かめましょう。

#出力画像サイズの計算メソッド。
def calcSize(input_size, kernel_size, stride, padding):
    """Calculate output image size for given parameters."""
    output_size = (input_size + 2 * padding - kernel_size) / stride + 1
    
    return output_size

6. 実体化

さて、ここまでで学習モデルの概念設計ができたので、れぞれ実体化させましょう。

#データローダー
train_loader, val_loader, test_loader = get_data_loaders(batch_size,
                                                         train_ratio,
                                                         val_ratio)
# 損失関数(closs entropy)
criterion = nn.CrossEntropyLoss()
#モデル
model = Net()
#Optimizer(Adamを使用)
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
optimizer.step()

7. トレーナー・EarlyStoppingの定義、学習実行(extras, Ignite使用部分)

ここから、extrasとIgnite実装の核心部分です。細かい説明をする前に、実装を見てみましょう。Igniteの機能であるEarlyStoppingを用いていますが、それ以外はextrasのGithubにある例題とほぼ同じです。

#trainer(学習を行う本体)とevaluator(学習進捗の確認に使用)の定義
#GPUを使う場合は、deviceをここで指定すると入力データがcuda対応になる
trainer = create_supervised_trainer(model, optimizer, criterion, device = device)
evaluator = create_supervised_evaluator(model,
                                        metrics={
                                            'acc': Accuracy(),
                                            'loss': Loss(criterion),
                                            }, device = device)
#extrasの部分。出力するメトリクス、出力画像、プログレスバーの出力等を調整。
my_extensions = [
    extensions.LogReport(),
    extensions.ProgressBar(),
    extensions.observe_lr(optimizer=optimizer),
    extensions.ParameterStatistics(model, prefix='model'),
    extensions.VariableStatisticsPlot(model),
    extensions.IgniteEvaluator(
        evaluator, val_loader, model, progress_bar=False),
    extensions.PlotReport(['train/loss'], 'epoch', filename='loss.png'),
    extensions.PrintReport([
        'epoch', 'iteration', 'train/loss', 'lr',
        'model/fc2.bias/grad/min', 'val/loss', 'val/acc',
    ]),
]

models = {'main': model}
optimizers = {'main': optimizer}
manager = ppe.training.IgniteExtensionsManager(
    trainer, models, optimizers, epochs,
    extensions=my_extensions)

@trainer.on(Events.ITERATION_COMPLETED)
def report_loss(engine):
    ppe.reporting.report({'train/loss': engine.state.output
                          })
 
#EarlyStoppingに使うメトリクスの定義
def score_function(engine):
    loss = engine.state.metrics['loss']
    return - loss
#ハンドラーをevaluatorに追加
handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)    
#学習の実行
trainer.run(train_loader, max_epochs=epochs)
詳細解説(のようなもの)

では、順を追ってみていきましょう。エポック-ミニバッチのループと学習を自動化してくれるtrainerはcreate_supervised_trainer、学習の進捗確認をするevaluatorはcreate_supervised_evaluatorを用いて設定します。Loss, Accuracyなどの良く使うメトリクスはevaluatorのmetricsにデフォルトで用意されているようです。引数に既に設定したmodel等を入れるだけなので簡単です。

trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(model,
                                        metrics={
                                            'acc': Accuracy(),
                                            'loss': Loss(criterion),
                                            })

続いてextrasによる学習進捗を管理するマネージャーの準備です。my_extensionsでどのメトリクスを表示するか、プログレスバーを出力するか等の細かい調整をします。これをIgniteと連動したマネージャーであるppe.training.IgniteExtensionsManagerに渡してやります。
デコレートされた関数がありますが、これはIgnite独特の記法で指定したタイミングで関数が指定した処理を実行します。訓練用データの損失であるengine.state.outputを'train/loss'に辞書型で渡します。一方でval/loss, val/accはここで指定しなくても自動で出力されるので、内部処理はあまり理解できていません。

my_extensions = [
    extensions.LogReport(),
    extensions.ProgressBar(),
    extensions.observe_lr(optimizer=optimizer),
    extensions.ParameterStatistics(model, prefix='model'),
    extensions.VariableStatisticsPlot(model),
    extensions.IgniteEvaluator(
        evaluator, val_loader, model, progress_bar=False),
    extensions.PlotReport(['train/loss'], 'epoch', filename='loss.png'),
    extensions.PrintReport([
        'epoch', 'iteration', 'train/loss', 'lr',
        'model/fc2.bias/grad/min', 'val/loss', 'val/acc',
    ]),
]

models = {'main': model}
optimizers = {'main': optimizer}
manager = ppe.training.IgniteExtensionsManager(
    trainer, models, optimizers, epochs,
    extensions=my_extensions)

@trainer.on(Events.ITERATION_COMPLETED)
def report_loss(engine):
    ppe.reporting.report({'train/loss': engine.state.output
                          })

そしてようやくEarlyStoppingです。学習終了を判断する指標を返すscore_functionを作成します。これをもとにハンドラーを定義し、evaluatorに追加します。

def score_function(engine):
    loss = engine.state.metrics['loss']
    return - loss
#ハンドラーをevaluatorに追加
handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)    

そして最後にぽちっとな。

trainer.run(train_loader, max_epochs=epochs)

8. 結果

上記の学習を実行すると、以下のようなログが出力され学習の進捗を確認できます。Chainerらしい分かりやすいログです!訓練用データのlossは大きく減少しており学習はうまく行えているようです。一方で検証用データに対するaccuracyは30%とてんでダメですね。まぁデータ数が少なすぎるから仕方ないですね(おとなしくMNISTでやっておけば良かった・・・)。
そしてなんといっても、EarlyStopping。検証用データに対するval/lossは5エポックで底を打ち、そこから5連続で値が上昇し続けたためEarlyStoppingが発動し10エポック目で学習が自動停止してくれました。思い通りの挙動になりました。スマートで気持ちいい!

また、出力結果はresultというディレクトリに自動で格納されます。

epoch       iteration   train/loss  lr          model/fc2.bias/grad/min  val/loss    val/acc   
1           9           0.109115    0.001       -0.0068558               3.6902      0.4         
2           18          0.0799794   0.001       -0.00446897              3.31913     0.7         
3           27          0.210086    0.001       -0.00804866              8.57938     0.2         
4           36          0.770095    0.001       -0.0150844               4.7263      0.2         
5           45          0.174224    0.001       -0.0103558               2.99467     0.3         
6           54          0.0947392   0.001       -0.00595115              3.58225     0.2         
7           63          0.0278888   0.001       -0.00177731              4.41173     0.2         
8           72          0.00690222  0.001       -0.000468847             5.11954     0.2         
9           81          0.00298255  0.001       -0.000227255             5.54856     0.3         
10          90          0.00148836  0.001       -0.000129092             5.81546     0.3         
State:
	iteration: 90
	epoch: 10
	epoch_length: 9
	max_epochs: 100
	output: 0.00029667618218809366
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: 12

終わりに

Igniteとextrasの双方の記法を覚えなければならない学習コストはありますが、EarlyStoppingと学習の出力周りは最高です。 Igniteのハンドラーの定義はクセのある書き方でしたが、extrasを使うことで一気にスマートになりました。ほかのラッパー(Catalyst, Lightning)は使ったことがないので分かりませんが、extrasとIgniteはかなり強力なコンビだとおもいます。今後日常的に使うツールになる強い予感がしました。今日書いたコードはコピペして使いまわすことにします。