Stochastic Average Gradient法を解説する

 Stochastic Average Gradient (SAG)はNIPS 2012で提案された新しい最適化手法である。目的関数がstrongly convexである場合、という条件付きではあるが、線形収束が保証されている。要するに、速い。

 論文の解説についてはOiwa神の記事を参照すると良いと思う。以下では、SAGの考え方について、一般的なSGDとの差異を中心に説明したい。

SGDの復習

 SAGを説明する前に、SGD(Stochastic Gradient Descent、確率的勾配降下法)とはどんな手法だったかを確認しておこう。

 SGDはランダムに1つのデータを取ってきてgradientを計算し、そのgradientでパラメーターを更新する手法であると言える。(Full) Gradient Descentだと、現在のパラメーターを使ってすべてのデータに対してgradientを計算してその平均を取ったもの(ここではfull gradientと呼ぶことにしよう)を使ってパラメーターを更新するわけだが、いちいちデータ全部に対するgradientを計算するのは処理が重たい。そこで、1つのデータに対するgradientをfull gradientの近似とみなしてパラメーターを更新してしまう、というのがSGDである。他の解釈もあるかもしれないが、SAGと対比するには、こう解釈するとわかりやすい。

SAGとは

 SAGの名前はSGDに似ている。しかし、その考え方はSGDとはちょっと違う。

 SAGも、full gradientの計算が重いので一部の計算をサボって近似しよう、という出発点はSGDと変わらない。しかし、SAGの場合はfull gradientを少しずつ更新していく、という形を取る。つまり、ランダムに1つのデータを取ってきて、現在のパラメーターに対するgradientを計算し、それを使ってfull gradientを更新する。そのfull gradientを使ってパラメーターを更新する。

 SGDではfull gradientの代わりにデータ点1つに対するgradientを使っていたが、SAGだとfull gradientの近似を明示的に保存しておく必要がある。これはパラメーターと同じ次元のベクトルになる。

 ランダムにデータを1つ取ってきてfull gradientを更新するという処理は、full gradient = 各データに対するgradientの和であるので、ナイーブに実装するならば各データに対するgradientを全部保存しておく必要がある。しかし、実際には、各データに対するgradientとは、例えばヒンジロスならば、if y w ・ x > 1 then 0, otherwise y x という感じなので、この場合であれば、gradientが0ベクトルなのか、それともy xなのかだけを覚えておけば用は足りる。他の損失関数でも、超平面からの距離だけ保存しておけば、そこからgradientは再現できる。

 まとめると、SAGはSGDと比べて計算途中で必要な情報が増える。具体的には、

  • 現在のfull gradient
  • 各データのgradientを再現するための情報

を保持しておく必要がある。

また、自然言語処理のように高次元かつスパースなデータを扱う場合、パラメーターの遅延更新を行いたくなるが、そうするともう一本ベクトルが増える。(これも、パラメーターと同じ次元のベクトルになる。)

この解説だけを読んでもよくわからんと思うけど、上記を踏まえた上でlcsgd.ccを読むと、たぶんよくわかると思う。

むー、例がないので、SGDを理解している人にしかわからない説明になってしまった。SGDがよくわからんという人は、preferred research blogの記事を読むか、もしくは日本語入力を支える技術を読んでね!

このエントリーをはてなブックマークに追加

Latest articles