ころがる狸

ころがる狸のデータ解析ブログ

【CNN+Grad-CAM】仕組みの解説と画像の予測根拠可視化

こんばんは、Dajiroです。本ブログでは既に画像を予測する方法を学びましたが、今回はCNNによる画像予測の根拠についてご紹介します。その代表的な技術である(Guided) Grad-CAMについての仕組み解説と、実際に得られた予測根拠を見ていきます。

画像認識についてのワークフローはこちらの記事をご覧ください。
dajiro.hatenablog.com

また、本技術の元論文はこちらです。
arxiv.org

【目次】

Grad-CAMの仕組み

Grad-CAMとは【Gradient-weighted Class Activation Mapping】の略で、一言で要約すると予測値に対する勾配を重み付けすることで、重要なピクセルを可視化する技術です。勾配が大きいピクセルは予測値に大きな影響を与える=重要という発想です。この技術を使うことで以下のような画像が得られます。こちらは私が大昔に撮影したピザの写真ですが、これをresnet50というCNNの学習済みモデルに通すと正しく『ピザ』と予測できます。この時の分類根拠を可視化したのが真ん中と右の図で、ピザに相当するピクセルを重点的に重み付し、更にGuided Grad-CAMの結果からはピザの輪郭や具材的なものを捉えられていることが分かります。

f:id:Dajiro:20200626212404p:plain
Gard-CAM, Guided Grad-CAMによるピザ画像の特徴量可視化

ワークフロー

Grad-CAMの計算のワークフローは以下のようになります。

  1. 画像を順方向に入力し、畳み込み層分類結果を取り出す
  2. 分類結果を使って誤差逆伝搬し、畳み込み層の勾配を計算する
  3. 畳み込み層の勾配のグローバルアベレージプーリング(GAP)を取る
  4. GAPを重みとした畳み込み層の重み付き和を取り、元の画像サイズにリサイズする

それでは、この計算の流れを図を使って説明していきます。

順方向計算と勾配計算

まずは1, 2の順方向計算と勾配計算です。これに関しては一般的なCNNを理解していればそれほど難しい操作ではありません。まずは画像を入力し、CNNを通してクラス分類を行います。その際に得られる畳み込み層の出力と、クラス分類の出力を取り出しておきます。また、この順方向計算の後に誤差逆伝搬を行い、畳み込み層の各要素に対するクラス分類出力の勾配を計算します。

f:id:Dajiro:20200626222751p:plain
【ステップ①, ②】順方向計算と畳み込み層の勾配の計算。

グローバルアベレージプーリングと重み付き和

勾配が得られたら、それぞれの畳み込み層に対してグローバルアベレージプーリング(GAP)を取ります。これは非常にシンプルな操作で、各畳み込み層の勾配の平均値を取るだけです。得られる値はスカラー値なので、これを元の順方向で得られた畳み込み層の値(ステップ②で取り出していました)に重み付し、全ての畳み込み層で加算することで1枚の画像が得られます。ここに活性化関数Reluを挟み、小さな畳み込み層の出力をリサイズすることによって最初の図に示したGrad-CAMの画像が得られます。
Grad-CAMは可視化根拠を得るために非常に強力な技術ですが、考え方はシンプルです。素晴らしいですね。

f:id:Dajiro:20200626225612p:plain
【ステップ③, ④】グローバルアベレージプーリングと重み付き和。

Grad-CAMの問題点と解決策

さて、Grad-CAMを用いることでピザを予測するCNNは正しくピザを見てることが分かりました。しかしこの方法では小さな画像を大きな画像にリサイズしているため、解像度が低く予測根拠を大まかにしか把握できません。CNNはピザの形を見て判断しているのでしょうか、それとも具材?これに答えるためにはより高解像な可視化が必要となります。

可視化根拠を高解像度化するためにはどうするべきでしょう。そう、入力画像のピクセル値に関する予測値の勾配を取ればいいんです。予測において重要な勾配を効率的に取り出すために、guided backpropagationという方法を用います。これは、一般的なReLU関数による勾配計算処理に、前の層の勾配が負の値を取るものをゼロに置き換えるという処理を追加で行っています。以下の図が参考になるでしょう。

f:id:Dajiro:20200626233328p:plain
guided backpropagationのイメージ図。活性化関数としてReLUを使っている。
https://arxiv.org/pdf/1412.6806.pdf

こうして得られた入力層の勾配を上記のGrad-CAMのヒートマップの値に掛けることにより、ピクセルレベルの高い分解能を持ちつつ、重要部分を可視化できるようになります。このようにguided backpropagationとGrad-CAMの手法をミックスしたものがGuided Grad-CAMです。

ちなみに、guided backpropagationだけを使って可視化するとクラス分類を行う根拠とした部分をうまく抽出することができません。参考までに両方の結果を見てみましょう。Grad-CAMを使わないと、画像のどこを注視しているのかイマイチ分かりません。

f:id:Dajiro:20200626233620p:plain
guided backpropagation VS Guided Grad-CAM

おわりに

Grad-CAMの考え方は意外にシンプルですが、画像の予測根拠を可視化する上で非常に有益なツールです。そのため今後も使う場面は多いと思います。他にもLIMEやintegrated gradientなどAIの説明可能性を上げるための手法が報告されているため、XAI (Explanable AI)に関してもっと勉強していきたいと思います。

最後になりましたが、Grad-CAMの動作に使ったコードはこちらになります。コードを解読することでGrad-CAMの仕組みと実装方法の両方を学べるので是非ご覧ください。
github.com