ころがる狸

ころがる狸のデータ解析ブログ

【SAM】最新オプティマイザーで画像分類の精度検証!

f:id:Dajiro:20210613155321p:plain
みなさんご無沙汰しております、Dajiroです。久しぶりのブログ投稿です。ここ半年ほど、データベースやAPI、AWSの勉強で忙しかったのですが、ようやく機械学習に帰ってこれました。今回の記事では、最新のオプティマイザであるSAM(sharpness aware minimization)を使った画像分類の精度検証を行ってみたいと思います!

本技術は、昨年の12月にarxiv上でGoogleチームによって報告された手法で、画像分類のタスクにおいて従来手法を上回る予測精度を示したことで大きな注目を集めています。従来のオプティマイザとの大きな違いは、モデルのパラメータを探索する際にパラメータ空間における「フラットさ」を見ている点にあります。これにより、高い予測精度を発揮するとともにノイズに対する頑健性も強化することに成功したと報告しています。

それでは、詳細な内容と精度検証に移りましょう。と、その前にSAMの原著論文は以下になります。
[2010.01412] Sharpness-Aware Minimization for Efficiently Improving Generalization
また、動作検証にはGithubで公開されているPyTorch実装を活用させて頂きました。
GitHub - davda54/sam: SAM: Sharpness-Aware Minimization (PyTorch)
今回使ったGoogleColabのノートは以下で公開しています。
github.com

SAMの概要

従来手法、及びSAMを用いたパラメータ空間探索の結果は以下の図によって上手く可視化されています。従来手法(SGD、左図)は損失関数を最小化する1点を探す手法なので、損失関数は小さくはなりますが最適化パラメータ周囲の空間はデコボコしています。これは教師データに対する過学習や汎化性能の低減につながります。一方でSAMを用いた場合は、周囲が平坦なパラメータを探索するように損失関数が設計されているため、選ばれたパラメータの周辺の値も一様に低い損失関数を示します。これにより高い汎化性能とノイズに対する頑健性(robustness)の向上が期待されますが、実際にCIFAR-10などの著名なデータセットでstate-of-the-artな性能を示すことが報告されています。

f:id:Dajiro:20210314093345p:plain
SGD(左)とSAM(右)で探索したパラメータ空間

それでは、どのようにしてSAMは周囲の平坦さを見ることに成功しているのでしょうか?SAMの損失関数の定義を見るとそのヒントが掴めます。

\underset{w}{{\rm min}}\,L_{S}^{SAM}(w) + \lambda \|w\|^{2}_{2}

ここで、

L_{S}^{SAM}(w) \,\underset{=}{\Delta}\, \underset{\|\epsilon\|_{p}\,\leq\,\rho}{{\rm max}}L_{S}(w + \epsilon)

です。また、L_{S}は学習データセットSに対する損失関数、wはモデルパラメータ、\epsilonは周辺のパラメータ、pはノルムの次数、\rhoはハイパーパラメータです。

この損失関数のコンセプトは、パラメータ空間におけるモデルパラメータw周辺で最も急峻な崖が最小であるようなwを見つけよう、ということです。このようなwの周辺は、最も急峻である方向\epsilonに進んでみても勾配が緩やかなはずです。このようなwを探すことで、当初の目標である周辺もフラットなパラメータを発見することが可能となります。

それでは実際にどうやって実装するの?という話になりますが、それは他の解説記事や原著論文に譲りましょう(すみません・・・)。主要点のみ説明すると、急峻な方向\epsilonを計算する必要がありますがそのために勾配の計算を行います。更に、得られた\epsilonを用いてL_{S}^{SAM}(w)の勾配を計算し、得られた勾配を用いてパラメータを更新します。すなわち、SAMでは勾配の計算が2回必要となるため、従来手法に比べ計算時間が約2倍かかるアルゴリズムとなります。そのため原著論文でもSAMの学習にはエポック数は半分にして計算しています。

精度検証(SGD, Adam, SAM)

それでは精度検証をしてみましょう。検証の概要は以下の通りです。(注:原著論文のモデルWideResNetsは踏襲していませんのでご注意を)

  • 学習データ:CIFAR-10(学習データ50k, テストデータ10k)
  • 学習モデル:二層の畳み込み層とマックスプーリング層を持つCNN
  • オプティマイザ:SGD + momentum(検証1)、Adam(検証2)、SAM + SGD + momentum(検証3)
  • 損失関数:クロスエントロピー
  • その他:学習率0.001、バッチサイズ32、画像サイズ32、エポック数50

3つのオプティマイザーを用いた損失関数の精度比較は以下の通りです。学習用データに関する振る舞いを見ると、SAMの減少が最も緩やかであることが分かります。一方でSGDやAdamでは比較的はやくロスが減少していきます。テスト用データに関する損失関数は非常に面白く、SAMは50エポックでは過学習の傾向が見られない(減少を続け反発しない)、SGDやAdamでは早い段階で既に減少は停止しています。このことから、SAMを用いることで学習用データに対する過学習を抑制できていることが示唆されます。

f:id:Dajiro:20210314120938p:plain
学習用データに対する損失関数。
f:id:Dajiro:20210314122121p:plain
テスト用データに関する損失関数。

続いて、正解率(accuracy)を見てみましょう。こちらもSAMの特性が現れた結果になっていると思います。学習用データに関する正解率は、Adamは立ち上がりが早いですがSAMは緩やかです。テスト用データに対する結果を見ると、Adamの正解率は10エポックでサチっていますがSAMは50エポックまでほぼ単調に増加しています。SGDはその中間といったところです。

f:id:Dajiro:20210314121555p:plain
学習用データに対する正解率。
f:id:Dajiro:20210314121615p:plain
テスト用データに対する正解率。

テスト用データに対する正解率の最大値を見ると、

  1. SGD + momentum:0.669(25エポック)
  2. Adam:0.654(11エポック)
  3. SAM + SGD + momentum:0.657(50エポック)

という結果に。SGDが最も大きいですが、最大正答率自体には大きな変化はありませんでした。原著文献ではWideResNetを用いてSOTAな性能を示していたので、この辺りはモデルやハイパラ依存性が多分にありそうです。

まとめ

素朴なCNNモデルを用いて、話題のオプティマイザーSAMによる精度検証を行いました。得られた知見は以下の通りです。

  • 従来手法に比べ過学習が抑制できる可能性がある。
  • 今回の実験及び原著論文の結果から、従来手法と同等もしくはそれ以上の予測精度が期待できる。

今回の実験では正解率は全オプティマイザーで同等だったので、適切にアーリーストッピング法などを導入すれば強いてSAMを使う必要はなかったかもしれません。また、勾配を2回求めるため計算時間を要するのもネックになるかもしれません。ですが、過学習の抑制や(今回の実験では検証してませんが)エラーデータ混入の場合の対応においては優位性のある手法であると感じました。現在はオプティマイザーといえばAdamなどが大きな地位を占めていますが、SAMもそれと同レベルの地位を占めることになるかもしれません。

おまけ

今回はSAMとSGDを組み合わせて実験をしました。これを、Adamを組み合わせて実験するとほぼAdamと同じ傾向が見られました。原因は、、、なんでしょう。このあたりの理解も進めていきたいです。

f:id:Dajiro:20210314123550p:plain
SAM + Adamを用いた場合の損失関数と学習率。