ころがる狸

ころがる狸のデータ解析ブログ

【生成モデル】正規化フローでMNISTの画像生成

 こんにちは。育児のため時間がとれずブログ更新を1年放置していました。今後はしっかりと勉強時間を確保して、記事執筆を頑張っていきたいと思います!さて、今回は機械学習分野でもっとも注目を集めている技術の1つである生成モデルを取り上げます。生成モデルでは画像や文章、分子構造などの訓練データの分布を学習し、それに類似したデータを自動生成します。代表的な生成モデルとして以下があります。

  1. GAN(敵対的生成ネットワーク)
  2. VAE(変分自己符号化器)
  3. 拡散モデル
  4. 正規化フロー

 これらのうち、今回は正規化フローの技術的な解説及びMNISTの画像生成例を整理したいと思います。あまり聞きなれない技術ですが、画像生成、分子構造生成、スピーチ生成等で成果を上げている重要技術です。アムステルダム大が公開している深層学習コースのチュートリアル資料を参考に、コードを動かしながら画像を生成したいと思います。

チュートリアル(UvA Deep Learning Tutorials)のリファレンスは以下の通りです。この内容を一部補足するような形で説明します。
uvadlc-notebooks.readthedocs.io

それでは以下の項目に従って解説していきます。

正規化フローの概要

 生成モデルとして知名度のあるGAN, VAEと正規化フローの概要を比較したのが下の図になります。GANやVAEと同様に、正規化フローでもランダムに生成した潜在変数zから新たなデータを生成します。これらのモデルと正規化フローの大きな違いは、

  1. データの変換(正規化フローと呼ばれる)が逆変換可能な関数fで行われる
  2. データ分布 p_{x}(\boldsymbol{x})を直接学習する(=学習時に負の対数尤度を最小化する)
  3. zとxが同じサイズのためデータの損失なしに両者の変換を行える

といった点にあります。では、正規化フローでデータ分布 p_{x}をどのように求めるべきでしょうか。

生成モデルの比較(出典[1])

まず、確率分布の定義から、実データxと潜在変数zに関して以下の関係が成り立ちます。

 \int p_{x}(\boldsymbol{x}) d\boldsymbol{x} = \int p_{z}(\boldsymbol{z}) d\boldsymbol{z} = 1

これら2つの確率分布に対して、変数変換により以下の関係が導かれます。両辺に対数関数を取ることで、実データ分布 p_{x}(x)に関する対数尤度が得られます。ここで、関数 fを用いて z = f(x)^{-1}と表されることを使っています。右辺の第二項には対数ヤコビアン行列式(log-determinant of Jacobian, LDJ)が出現しますが、これは直感的には変換前後における確率分布の体積の変化率を表しています。
 \textrm{log} p_{x}(x) = \textrm{log} p_{z} (f(x)^{-1})  - \textrm{log} \left| \textrm{det}\frac{df(x)}{dx} \right|

 正規化フローの学習では、この \textrm{log} p_{x}(x)を最大化(=負の対数尤度を最小化)するようにパラメータが最適化されます。関数 fは複雑な形状になると予想されますが、できるだけシンプルに右辺を計算するためのテクニックとして、多くの変換可能関数を順番に作用させることを考えます。数式で表現すると以下の通りです。
 \boldsymbol{x} = \boldsymbol{z}_{K} = f_{K} \circ f_{K-1} \circ \cdots f_{1}(\boldsymbol{z}_{0})
確率分布の変換式を用いると、対数尤度 \textrm{log} p_{x}(x)を得るための式は以下のようになります。すなわち、正規化フローを計算するには変数変換を行うごとにLDJの和を取ればよいことが分かります。
 \textrm{log} p_{x}(\boldsymbol{x}) = \textrm{log} q_{K}(\boldsymbol{z}_{k}) = \textrm{log} q_{0}(\boldsymbol{z}_{0}) - \sum_{i=1}^K \textrm{log} \left| \textrm{det}\frac{df_{i}}{dx} \right|

正規化フローのイメージ(出典[1])

ここで変換可能な関数 fはニューラルネットワークを用いて構成し、上式に従い対数尤度を計算します。また、右辺のLDJは行列サイズNに対し O(N^3)で計算量が増加するため、行列式の計算量を削減するための fの構成の仕方が非常に重要です。

計算手順

  続いて、具体的な計算手順を見ていきましょう(下図参照)。MNIST等の画像データをでは各ピクセルに離散値が保存されていますが、連続値として扱えるようにdequantizaiton(非量子化)します。その後LDJの和を取りながら変数変換を繰り返し、負の対数尤度 -\textrm{log} p(\boldsymbol{x})が最小となるようにニューラルネットワークの学習を行います。その後、ガウス分布から画像と同サイズの乱数配列を生成し、正規化フローの逆変換を施すことで画像が生成されます。
 なお、正規化フローにおける学習では損失関数として負の対数尤度でスケールされたbits per dimensionと呼ばれる指標を用いることが多いそうです。対数尤度をデータ次元数で除算しているため、異なる画像データの学習結果でも比較できるとのこと(出典[2])。

計算手順例

画像のdequantization

 正規化フローでは確率分布の体積が1となることを前提としているため、画像データのように各ピクセルが離散値を取るとその値に無限大の分布を想定することになり(デルタ分布)、学習には不向きです。そのため離散値を連続変数へ変換するための処置がdequantizationです。dequantizationの手続きの概要は以下の通りです。

  • 離散値に対してノイズを付与
  • 離散値の上限(画像だと256)で除することで0-1の範囲にスケール
  • 逆シグモイド関数を作用させることで、マイナス無限大からプラス無限大の範囲の値として表現

このような処理を施すことで、各離散値がガウス分布からサンプリングされた数値であるかのようにみなせます。例えば8個の離散値があってそれらに一様分布を仮定し、上記のdequantizationを作用させると以下のような結果になります。それぞれの離散値に対応したガウス分布の体積は、全ての値に対して等しくなります。

非量子化分布(出典[2])

しかし、実際の画像データでは各ピクセルに割り当てられる整数は一様ではなく、下図のように偏りがあると考えるのが自然です。そのため各ピクセルに付与するノイズの大きさを正規化フローによって学ばせ、分布を最適化することができます(variational dequantization)。ノイズをフローによって最適化することから学習時のパラメータは増加してしまいますが、学習結果に良い影響を及ぼします。

非量子化分布(出典[2])

 さて、これらの手続きによりdequantizationを行う方法が分かりました。しかし正規化フローでは \textrm{log} \left| \textrm{det}\frac{df(x)}{dx} \right|を計算しなければいけません。ここではdequantizationの処理で登場するシグモイド関数を例として、LDJの計算方法を解説します。
 シグモイド関数は y = \frac{1}{1 - \textrm{exp}(-x)}と書けますから、ヤコビアンを計算する際にはその微分式を解析的に求めておきます。微分式の対数を取ると、次のように書き下すことができます(softplus関数を用いて簡潔に記載します)。
 \textrm{log}y^{'} = -x - 2\textrm{softplus}(-x)
これを要素として持つ行列の行列式を求める必要があります。ここで考えているのは画像データとしての配列に対する処理なので、各要素(=各ピクセル)の数値は独立であることを利用します。例えば画像のあるピクセルの色調が変わったからといって、他のピクセルは連動して変化しませんよね。そのため \frac{f(x_{i}))}{x_{j}} = 0 (i\neq j)となることから、ここでのヤコビアンは対角行列になります。対角行列の行列式は対角項の積として記述されることから、行列式を簡単に計算することができます。この考え方を利用して、dequantizationの処理に関するLDJは実は簡単に実装できるのです。
 参考までに、実装は以下のようになります。(対数を取ってるので積が和に切り替わっています。)

 #sigmoid関数による変換に伴うlog determinant jacobian(ldj)の計算方法 
 # 画像のチャネル、高さ、横方向の総和を取っている
 ldj += (-z-2*F.softplus(-z)).sum(dim=[1,2,3])

変換プロセス

 画像のdequantization処理のあとは、変数変換を繰り返しながら適切なデータ分布を獲得します。変数変換の設計には以下の要件が求められるでしょう。

  • 逆変換が可能か
  • ヤコビアン行列式を簡単に計算できるか

この要件を満たす仕組みとして、ここではcoupling layerをご紹介します[3]。coupling layerでは、入力となる潜在変数 z z_{1:j} z_{j+1:d}の2系統に切り分けます。下図に示した通り、 z_{1:j}には変換を加えずそのまま処理し、もう半分の z_{j+1:d}には z_{1:j}を入力としたバイアスとスケール \mu_{\theta}, \sigma_{\theta}をニューラルネットワークを用いて構成しておき、以下のように変換します。
 z_{j+1:d}^{'} = \mu_{\theta}(z_{i:j})+\textrm{exp}(\sigma_{\theta}(z_{1:j}))\bigodot z_{j+1:d}
この逆変換は以下のようになるため、第一の要件が満たされます。
 z_{j+1:d} = (z_{j+1:d}^{'} - \mu_{\theta}(z_{i:j}))\bigodot \textrm{exp}(-\sigma_{\theta}(z_{1:j}))

また、画像データの配列は各要素が独立であることを利用すれば、LDJはスケーリングファクター \sigma_{\theta}を用いて以下のように簡単に計算できます。これで第二の要件を満たすこともできます。
 \textrm{LDJ} = \sum_{i}\sigma_{\theta}(z_{1:j})_{i}

coupling layerのイメージ(出典[2])

このように、coupling layerは正規化フローを構成する上で非常に良い素性を持っていることが分かりました。ここまでに示したdequantizationと、複数のcoupling layerを組み合わせることで一連のフローが完成します。

正規化フローのイメージ(出典[2])

画像サイズの低減

 これまでに紹介したテクニックを使えば画像を生成できますが、正規化フローは入力データと潜在変数のサイズが等しいため高品質な画像データを扱うには計算時間を要します。そこで、形状変換により縦横のデータの次元を減らしたり(squeeze)、分割する(split)ことを考えます。具体的には、squeeze処理ではH × W × Cのサイズの画像をH/2 × W/2 ×4Cとします。縦横は1個飛ばしとし、飛ばしたピクセルをチャネル方向に結合します。split処理は正規化フローの途中に挟み込み、チャネル方向に2分割するといった処理を施します。

squeeze処理のイメージ(出典[2])

squeeze-split処理を追加した正規化フローは以下の通りです。multi-scale architectureと呼ばれることもあります。最終層では分割した潜在変数同士の再結合とreshapeが行われるため、結局は元の画像と同サイズの潜在変数が生成されます。squeeze処理はつまるところ配列のreshapeであり、チャネル方向の拡大に対応するためニューラルネットワークの隠れ層の次元数は増加してしまいます。しかし計算量と計算時間は必ずしも比例しないため、multi-scale architectureを採用することでサンプル速度の向上につながるというメリットがあります。また、後述するようにより真の数字に近い画像を生成できるようになるため、画像の様々なスケール(ローカル/グローバル)の特徴を効率的に学習することに役立っているのかもしれません。

multi-scale architecture正規化フローのイメージ(出典[2])

さて、以上で学習の準備が整いました。以下ではvariational dequantizationとmulti-scale architectureに則った正規化フローを用いてMNISTの画像を生成してみましょう。

画像の生成

出典[2]の公開コードに基づき画像を生成してみた結果がこちら。左がシンプルなモデル、右がmulti-scale architectureを用いた結果です。結果は一目瞭然でより洗練されたモデルを使うことで、実際の数字に類似した画像を生成できています。数字の体を成していないサンプルも多く完全とは程遠いですが、正規化フローにおいてmulti-scale architectureを用いた場合の効果を確認することが出来ました。

画像生成例

おわりに

 正規化フローの概要と実装レベルでの具体的な計算手順を確認できました。概要の理解と実装の理解には大きな違いがありますが、今回はチュートリアルを通した学習により正規化フローをより詳細に把握することができました。正規化フロー設計の要点は掴めたかと思うので、これを足掛かりに細心の文献にもキャッチアップしていきたいです。
 個人的に興味があるのが正規化フローによる分子構造の生成です。グラフ構造をどのように正規化フローに流し込むのか、そのLDJをどのように計算するのか、どのように計算を効率化するか。疑問点は多いですが、このあたりに注目すれば効率的に論文が読めそうです。

出典

[1]Flow-based Deep Generative Models | Lil'Log
[2]Tutorial 11: Normalizing Flows for image modeling — UvA DL Notebooks v1.2 documentation
[3][1605.08803] Density estimation using Real NVP