射撃しつつ前転

Large Batch TrainingとCritical Batch Sizeについて

DNNの大規模学習では、大規模化に伴い、必然的にbatch sizeが上がってしまう傾向にある。しかし、batch sizeがある程度以上の規模を超えると、それ以上batch sizeを大きくしても、精度向上につながらなくなる。つまり、大量に計算機資源を投入しても、得られるリターンが(ほぼ)なくなってしまう。この臨界点をcritical batch sizeという。

以下では、いくつかの論文のお気持ちを考察しつつ、critical batch sizeの問題について考えたい。

なぜBatch Sizeが大きくなるのか

複数のプロセッサを使ってひとつのモデルの学習を行う場合、計算タスクは、おおまかには以下の3種類の並列化手法の組み合わせによって分割される。

  • data parallel
  • pipeline parallel
  • tensor parallel

この3つの中で、台数を増やす際にbatch sizeを上げないで済むのはtensor parallelだけである。(pipeline parallelは、実効性能を保つためには最低でも台数と同じ程度のbatch sizeを必要とする。例:Zero Bubble (Almost) Pipeline Parallelism。)そして、tensor parallelだけでは並列度を稼ぐのが難しい。行列積を分割実行するわけなので、あまり細かく分割するとひとつひとつの計算タスクが小さくなりすぎる。

となると、大規模に学習しようとすると、data parallelやpipeline parallelも採用せざるをえず、伝統的な同期的計算で並列化するならば、batch sizeを上げるしかないのである。

用語解説

専門用語が色々出てきてしまうので、ここで一度、解説を入れておく。

  • mini batch size: いわゆる batch size
  • micro batch size: 実際にモデルに流す際のbatch size
  • gradient accumulation: micro batchを流して得たgradientをそのままモデル更新に使うのではなく、いくつか累積して後からまとめてパラメーターを更新すること
    • 次式のような理解でよいはず:mini batch size = gradient accumulation steps * micro batch size * number of data parallel

Critical Batch Sizeとは

上述の通り「このバッチサイズを超えると学習の効率が落ちるよ」という限界があり、critical batch sizeと呼ばれる。An Empirical Model of Large-Batch Training においてこの概念が提唱された。

どうしてこのような現象が発生するか。同一の計算資源、計算時間のもとでbatch sizeを大きくすると、勾配のノイズが減少することによる精度向上と更新回数の減少による精度低下が同時に発生する。バッチサイズが小さい領域では前者が勝つ。しかし、前者による改善はある程度のバッチサイズで限界を迎え、そこを超えると更新回数の減少による精度低下だけが残る。これがcriitcal batch sizeという現象が出てくる仕組みである。というような説明が書いてあったように思う。

critical batch sizeは、ミニバッチ内の勾配分散と、全体の勾配から大まかに推定可能である。以下のように考えればよいだろう。

  • mini batchの勾配の分散が大きい ⇒ その勾配情報にはノイズがたくさん含まれている
  • mini batchの勾配の分散が、データ全体の勾配と同じくらい ⇒ノイズが十分にキャンセルされている

当然だが、具体的なデータセットとモデルによってcritical batch sizeは変化する。

実用的なBatch Sizeはどれくらいか

ImageNetの学習においては、数千〜数万程度のbatch sizeが当然のように使われており、データセットやモデルの規模から考えると割と十分なbatch sizeであった。

一方で、LLM学習におけるbatch sizeは、実はあまり大きくない。調査した範囲では、すべての事例において、batch sizeは512〜2048の間であった。これを超えると精度が下がる、と考えられているようである。batch sizeが2048までしか許されないとなると、大規模学習においては問題となり得る。

例として、1万枚のGPUを使って学習する場合を考えてみよう。このとき、data parallelとpipeline parallelの2種類を組み合わせて学習できるだろうか? data parallelとpipeline parallelは目安としては並列度と同程度のバッチサイズが必要になるため、どちらを使うにせよ、バッチサイズが最低でも1万になってしまう。つまり、batch sizeの制約によって、1万台のサーバーは活用できない。そこで、tensor parallelでtensor parallelで並列度を8くらい取ることにすると、tensor parallelはバッチサイズが10000/8=1250で抑えられる。

1250は2048よりは十分に小さいのでこれでよし、というわけには行かず、実際には、micro batch sizeをメモリが許す範囲で大きくしたい(そうしないと行列積がだったはずの計算が行列-ベクトル積になってしまい、GPUの利用効率が著しく落ちる)という事情がある。そのため、512〜2048というbatch sizeは、実際には、1万台規模で学習するにあたっては極めて厳しい制約である。

Critical Batch Sizeと関係の深い数字はなにか

critical batch sizeはmodelやdatasetに依存する。では、modelとdatasetのどちらに対してより強い依存があるのか。もしくは、両者に大きく依存するのか。

How Does Critical Batch Size Scale in Pre-training?では、大量の実験から、critical batch sizeはdataset sizeと相関が強いことを発見した。つまり、より大規模なdatasetを用いれば、critical batch sizeを上げられる。データの枯渇が叫ばれている現在、「より大規模なdatasetを使おうね」と言われても困る気もするが…。こういうところも、LLMによる学習データの生成の研究が着目されている理由の一端になっている、と言ってもよいのかもしれない。Phi-4 Technical Reportによると、最近公開されたPhi-4というモデルでは、学習データのうち40%は合成データらしい。

Federated LearningによるCritical Batch Sizeの回避の可能性

これまでの話は、伝統的な、完全に同期的な学習(1GPUで計算しても、1万GPUで計算しても、計算順序の違いによる浮動小数点数演算の誤差以外には違いがない)を前提としてきた。この条件を諦めることで、スケールを上げられる余地はまだありそうに見える。

計算機をいくつかのクラスタに分け、クラスタ間でたまに情報を共有しつつ並列に学習を行う方法をfederated learning(連合学習)と呼ぶ。federated learningにはプライバシー保護、通信量削減などのメリットが挙げられているが、それぞれのクラスタが学習する際のbatch sizeを伝統的手法より小さくできるという点もメリットになり得る。

Don’t Use Large Mini-Batches, Use Local SGD では、critical batch sizeという言葉は出てこないが、batch sizeを大きくすると精度が低下する場合があり、local SGD(各クラスタは独立に学習を行い、H step毎にパラメーターを共有し、その平均を新しいモデルパラメーターとする)を用いることで、単にlarge batch sizeで学習するよりも高い汎化性能を実現できると報告している。ただし、データセットとしてはCIFAR-10/100, ImageNetが使われており、LLMについての研究ではない。

Why (and When) does Local SGD Generalize Better than SGD? では、local SGDについての研究の中では比較的大規模な実験が行われている。α = η H が大事なハイパーパラメーターであると述べられている。上述のDon’t use Large Mini-Batchesの論文にも示されている通り、local SGDは単なるSGDよりもtest accuracyが良い場合がある。この論文ではその条件をある程度明らかにしているが、learning rateは下げる、Hの頻度は割と低い方がいい、程度のことしかまだわかっていないように見える。

DiLoCo: Distributed Low-Communication Training of Language Modelsについては先日紹介したが、これも一種のfederated learningである。クラスタ内部で使うinner optimizerと、クラスタ間情報を学習祭に使うouter optimizerの2種類のoptimizerを用意する。H step学習した後、学習前後でのパラメーターの差分をgradientだとみなしてクラスタ間で共有し、outer optimizerを使ってパラメーターを更新する。ただし、DiLoCoではfederated learningの主眼をクラスタ間の通信量の削減に置いており、critical batch sizeについては特に言及はない。

感想

critical batch sizeについて、そのお気持ちを解説した。どうやって巨大なモデルを学習するのか、そのために必要なネットワーク帯域はどれくらいになるのか、というのは、ハードウェアを設計する側や購入する側からすると非常に重要な情報である。もしfederated learningやDeMoのような手法の組み合わせで大規模分散学習が可能となるなら、ネットワーク帯域はかなり節約でき、結果、投資をかなり節約できることになる。

しかし、調査した範囲では、federated learningの分野においてはcritical batch sizeの問題は今の所重視されていないように見える。それよりも汎化性能の方が重視されているようだが、問題設定が複雑な分、全貌がまだ見渡せない。例えば、federated learningにおけるクライアントの各ラウンドでの参加率は高いほうがよいように思われるが、実際には参加率が高すぎると汎化性能が落ちる、というような話があったりする。おそらくは、そこにbatch sizeの問題も絡んでくるため、どこかに唯一の大きな問題が潜んでいるのか、それとも、複数の問題が混ざっていてよくわからない状況になっているのか、その区別すらつけられない状況であるように見える。

簡単だったはずの問題のスケールを上げようとするととたんに難しい問題になる、というのは様々なところで頻出する問題であるが、機械学習は小さいスケールでもそもそもややこしい問題なので、スケールを上げようとするともう訳がわからないくらいに難しい問題になってしまっているように思う。スケールを上げると汎化性能に影響があったりもするわけで、問題の抽象化が難しい。

計算機環境のスケールを上げたくなるのはモデルを大きくしたいからであり、一方で、モデルのサイズはある程度でよい、他の工夫で問題を解いていくのだ、というアプローチを取っているグループもあり、例えば、上述のMicrosoftのPhiシリーズもそうである。

スケーラビリティを追い求めるのは莫大なコストがかかる一方で、scaling lawによって性能の向上が予言されている。スケーラビリティ以外の工夫はいつ何が起こるのか予想しづらいが、非連続的なすごいブレークスルーの可能性が期待できる。どちらかだけを追求するのが正解というものでもなく、状況に応じて両方をバランスよくやっていくことが必要なのではないか。