ころがる狸

ころがる狸のデータ解析ブログ

【PyTorch×転移学習】学習済みモデルライブラリTIMMのご紹介

こんにちは、dajiroです。今回は高精度な画像分類を行うのに便利なライブラリTIMMをご紹介します。PyTorchでは画像分類用の学習済みモデルが公式で提供されていますが、使われているモデルがやや古く栄枯盛衰の激しい機械学習の世界では現代最高レベルの予測精度を発揮することは困難です。そこで、画像分類タスクに取り組む際にはTIMMのような外部ライブラリを用いることで数多くの、優れた、最新のモデルにアクセスし、より高精度の精度のモデルを構築するという戦略が考えられます。実際データ分析コンペKaggleでもTIMMを用いた転移学習による画像分類も行われています。以下では、

  • 転移学習とは
  • TIMMについて
  • 実装例

の3点を順に説明していきます。

なお、学習済みモデルのライブラリにはTIMMの他にも以下のようなレポジトリがあるので、ご参考まで。
GitHub - Cadene/pretrained-models.pytorch: Pretrained ConvNets for pytorch: NASNet, ResNeXt, ResNet, InceptionV4, InceptionResnetV2, Xception, DPN, etc.

転移学習とは?

転移学習とは、大規模データセットによって訓練された学習済みモデルを用いて自分の解析したいデータ(画像・文章)に対して学習を行うことです。学習済みモデルには既に画像や文章などの特徴が埋め込まれているため、手元のデータが仮に小規模であっても微調整(ファインチューニング)をすることで高い予測精度を発揮できるようになります。そのため、機械学習において転移学習は最も重要な技術の一つと言えるでしょう。

TIMMの使い方は?

インストール方法とTIMMの使用法

提供されているモデルはGitHubで公開されており、パッケージはpipでインストールすることができます。

pip intall timm

GitHubのレポジトリは以下のリンクからご確認下さい。このレポジトリの作者が作成したモデルに加え、TensorflowやMXNetなどの別のフレームワークで訓練されたモデルもPyTorchで使えるよう変換されています。数が多いためここでは紹介できませんが、例えば2019年に登場した強力なモデルEfficient NetやKaggleでよく使われるモデルであるResNeXtベースのモデルが多数公開されており簡単に試すことができます。
github.com

実装例

それでは、TIMMを使った転移学習の例をご紹介します。コードはGoogle Colaboratoryで作成しました。以下の私のレポジトリで公開しているので省略しますが、使ったデータセットとtimm特有のエッセンシャルな部分だけ掲載します。使った学習済みモデルはEfficientNet_b0です。
github.com

データセット

機械学習向けの大規模な画像データセットImageNetで構築されたモデルの精度を測る、ImageNetV2というテスト用データセットを用います。もとのImageNetの構築から10年後に作成された独立なデータセットとなるためテストにはもってこいです。本来はテストのみ行うと思いますが、ここでは転移学習をテーマとしているので再学習も試しにやってみました。
GitHub - modestyachts/ImageNetV2: A new test set for ImageNet
こんな感じの画像データから構成されています。

f:id:Dajiro:20200724155826p:plain
ImageNetV2

コード

ライブラリの読み込見込みは下記のようにします。

import timm

データの標準化に必要となる平均値と分散は以下のように取り出します。ソースコードを読むとargsにはinput_size, mean, stdなど渡せるようですがここでは省略しすべてデフォルト設定を活用。

args = {}
model_name = 'efficientnet_b0'
data_config = timm.data.resolve_data_config({}, model=model_name, verbose=True)
print(conf["mean"], conf["std"])

核心部であるモデルはこのように構築します。最後の全結合層は必要に応じて出力ノード数を変更できるようにしておきます。

class EfficientNet_b0(nn.Module):
    def __init__(self, n_out):
        super(EfficientNet_b0, self).__init__()
        #モデルの定義
        self.effnet = timm.create_model(model_name, pretrained=True)
        #最終層の再定義
        self.effnet.classifier = nn.Linear(1280, n_out)

    def forward(self, x):
        return self.effnet(x)

モデルを実体化し、全結合層のみを学習対象とします。

model = EfficientNet_b0(n_class)
params_to_update = []
update_param_names = ["effnet.classifier.weight", "effnet.classifier.bias"]
for name, param in model.named_parameters():
    if name in update_param_names:
      param.requires_grad = True
      params_to_update.append(param)
    else:
      param.requires_grad = False

5エポックだけ回してみましたが、学習は進んでいるように見えます。

phase: val LR: 0.001 epoch: 0 loss: 6.9350 accuracy: 0.0010
phase: train LR: 0.001 epoch: 1 loss: 6.0797 accuracy: 0.1490
phase: val LR: 0.001 epoch: 1 loss: 4.4728 accuracy: 0.3150
phase: train LR: 0.001 epoch: 2 loss: 2.1574 accuracy: 0.7431
phase: val LR: 0.001 epoch: 2 loss: 3.1850 accuracy: 0.4150
phase: train LR: 0.001 epoch: 3 loss: 0.8378 accuracy: 0.9224
phase: val LR: 0.001 epoch: 3 loss: 2.8108 accuracy: 0.4515
phase: train LR: 0.001 epoch: 4 loss: 0.3998 accuracy: 0.9771
phase: val LR: 0.001 epoch: 4 loss: 2.6884 accuracy: 0.4525
Finished Training

終わりに

TIMMには優れたモデルが数多く実装されているため、Kaggleなどのコンペでは複数モデル構築⇒アンサンブル化(多数決)するという手順を取るのが定石の一つと言えると思います。TIMMに限らず使える優れたモデルは惜しみなく使っていきたいところです。
また、余談になりますが2020年7月現在最強のモデルの候補と言えるのは、2019年末にGoogleが公開したBiTだと思われます。重みパラメータが公開されているため試しに使ったことがありますが、使用経験上TIMMのモデルよりも数%高い分類予測精度をたたき出すことがあります。
ai.googleblog.com

恐るべしGoogle、BiTもKaggleもColabもすべて彼の手に・・・。AIがこれほど世にもて囃されるのはGoogleの卓越した経営・マーケティング戦略の結果なのかもしれません。AIが実際役に立つことに疑いの余地はありませんが、盛り上がるほど彼らを利することにもなるのもまた事実。