ころがる狸

ころがる狸のデータ解析ブログ

【ミニバッチ学習と学習率】低バッチサイズから始めよ

こんにちは。私は業務で3年ほど深層学習を扱っておりますが、まだまだ学ぶことが多いと感じています。深層学習ではGANのような魅力的な飛び道具も多くありますが、今日は学習率とバッチサイズという深層学習の基本的かつ重要なノウハウついて書いてみたいと思います。

そもそも学習率・バッチサイズとは

深層学習では、ニューラルネットワークによる出力と教師データの差分値を含む損失関数を定義し、この重みパラメータによる微分を最小化する方向に学習を行います。これによりネットワークの出力と教師データが一致する方向に学習が進み、ネットワークが入力データの特徴を自動的に学習します。損失関数を重みパラメータの関数と見たときの形状は、谷と山があるデコボコが高次元の空間に広がっていると考えてください。この中で、もっとも深い谷に効率的に到達するために学習率・ミニバッチ学習という概念が必要になります。

学習率とは一回の学習で重みパラメータをどれくらい変化させるか、という指標になります。これが大きいと一気にパラメータが更新され、逆に小さいとチビチビと進んでいくイメージです。なので学習率が大きいと損失関数が最小となる谷を一気に通り過ぎるリスクがありますし、小さいと局所解にトラップされて学習が進まなくなる可能性があります。

ミニバッチ学習は、複数のデータの塊(ミニバッチ)に対する損失関数の微分を計算し重みを更新します。ミニバッチのサイズが大きいとデータの特徴が平均化されてしまい、データの個性が失われるリスクがあります。一方で局所解にトラップされることが少なくなりますし学習も早くなります。

バッチサイズと学習率の決め方

バッチサイズと学習率は結局、良いバランスで選べということになるわけですが、どのように決めるべきかは悩ましい問題です。しかしとっかかりとなる指針がないわけではありません。エンジニアによって指針は異なるかもしれませんが、以下に従うのが良いように思います。

  1. 小さなバッチサイズ(1、2、4、8、16・・・)から始める。学習時間が気になる場合や手っ取り早く精度が知りたい場合はデータ数を小さくするなどの工夫をする。
  2. バッチサイズが小さいとノイズの影響を受け局所解にトラップされやすい。この場合は学習率を上げて対応する。
  3. バッチサイズが大きくなると入力パラメータの特徴が平均化されるため、数万規模のバッチサイズではデータの個々の特徴が失われる可能性がある。そのためバッチサイズの上限は1024、2048、4096程度の範囲でとどめておくのがベター
  4. 一般にバッチサイズが小さいほど精度が良いとされる。逆に低バッチサイズで精度が悪く、大きいほど精度が高いという場合には、入力データの表現力がそもそも高くない可能性を考慮すべきである。特徴量エンジニアリングから検討し直そう。
  5. ハイパーパラメータの決定方法として、ベイズ最適化(BO)に基づいたハイパーパラメータ探索法がある。自動でパラメータを決めてくれるので便利だが、探索空間を適当に決めてはいけない。BOで最適化する場合、探索空間を決める根拠をしっかり持つこと。1-4の上の指針に従うのが良いと思う(ハイパーパラメータ自動探索ライブラリの有名どころではOptunaなどがある)。

preferred.jp

といったところでしょうか。データサイズが数十万・数百万あるところでバッチサイズ1,2などを設定すると、頻繁に重みの更新を行う必要があるため学習にかかる時間が長くなりじれったいです。かといって大きいバッチサイズを使っていれば良いというわけではないので、注意が必要です。