サブロウ丸

サブロウ丸

主にプログラミングと数学

サーベイ: Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

@article{shoeybi2019megatron,
  title={Megatron-lm: Training multi-billion parameter language models using model parallelism},
  author={Shoeybi, Mohammad and Patwary, Mostofa and Puri, Raul and LeGresley, Patrick and Casper, Jared and Catanzaro, Bryan},
  journal={arXiv preprint arXiv:1909.08053},
  year={2019}
}

[paper] [video] [blog] [github]

著者は皆NVIDIA

背景

  • transformerベースの自然言語処理モデルが巨大化、学習時のメモリ使用量の増加が課題に
  • メモリ使用量を削減する分散処理モデルの一つがtensor-parallel、これは、層単位のような、model並列よりもより細かい範囲の計算を分散処理するもの

どんなもの?

  • transformerベースのモデルのtensor-parallelismをpytorchをもとに、
  • pytorch.distributedの集団通信APIを用いてカスタマイズして作成した
  • model-parallelismとも組み合わせる

先行研究と比べてどこがすごい?

  • ユーザーがpytorch実装をほぼ変更しなくても良い
  • transformerベースのモデルに特化させた(並列と分割方法はシンプル)

技術や手法のキモはどこ?

  • 同期処理(all-reduce)が少なくなるように次のようにモデルを分割
    • MLP部分は垂直に分割(結果をconcatする形)
    • transformerはヘッドごとに計算(計算資源のあまりはデータ並列で) + MLPは垂直に分割
    • 最終章はvocabularyサイズの変換になるが、そこも分割管理
      • 計算の重複を許しても通信量が多くなりすぎないようにした

どうやって有効だと検証した?

  • GPU数の増加に対し、petaflopsが線形に増加できることを確認
  • GPT-2(83億パラメタ)とBERT(39億パラメタ)を学習
    • 32 DGX-2H server (totalで512 Tesla V100 SXM2 32GM GPUs)を使用
    • 8モデル並列 * 64データ並列
    • WikiText-103とLAMBADAデータセットで学習、SOTAを達成

議論はある?

  • 並列方法はシンプル
  • all-reduceの処理は高速に通信できる環境であることが、ほぼ前提条件
  • そのためtransformerのhead数が最大のテンソル並列数になる --> なのでそこまで並列数を大きくできない(32程度)& メモリの節約はそこまでできないか...
  • そのため、モデル並列 との組み合わせが必要か(実験ではデータ並列との組み合わせのみ)

関連