目次
背景 / モチベーション
近年、深層機械学習分野では、性能向上のためにモデルのサイズやデータセットが巨大化する傾向にあり、例えば自然言語処理モデルのGPT-3では100billion(1千億)ものパラメタが使用されています。
そのような巨大なモデルの学習には複数マシン上でのモデル並列が必要不可欠ですが、その場合マシンノード間での通信が発生します。通常ノード間の通信はノード内の通信に比べて非常に低速であるため、ここがボトルネックになる場合も。
参考: 代表的な自然言語処理モデルサイズの推移
代表的な自然言語処理モデルのパラメタ数の推移を示した図です (billion = 10億)。 この巨大化の傾向は今も続いています。図は"Efficient large-scale language model training on gpu clusters using megatron-lm."のFigure1を引用。
どんなもの?
- "cross-mesh resharding"処理のスケジューリングの最適化
- 計算と通信のオーバーラップに優れたpipeline処理の提案
cross-mesh reshardingとは、あるノードのデバイス群から別のノードのデバイス群へ所望の配置でデータを転送することです。 具体例を次に示します。
cross-mesh reshardingの例
nodeBからnodeAのデバイスへデータの転送リクエストを図示。送信/受信すべきデータの種類がデバイスごとに異なることに注意してください。この図はデバイスd4, d$は0-7のデータを持っており、デバイスd0は0,1,4,5、d1は2,3,6,7のデータを受け取る必要があることを示しています。
cross-mesh reshardingの最適化
このような転送リクエストが与えられたときに (1) どのデバイスがデータを送るか? (2) 送るとしてそのタイミングはどうするか? の最適化が考えられます。図ではd0へはd4かd5のどちらかがデータを送れば良いため、ここで自由度があるし、またデバイスは一度に一つの送信/受信しか行えないので、送信/受信先が複数ある場合にどの順番で通信を行うかを考慮する必要があります。
先行研究と比べてどこがすごい?
- "cross-mesh resharding"処理のスケジューリングの最適化
- 先行研究は集団通信に関するものに相当、ただcross-mesh reshardingのような柔軟なデータ配置を直接行える集団通信はないので、本稿で紹介するような工夫が必要です。
技術や手法のキモはどこ?
- "cross-mesh resharding"処理のスケジューリングの最適化
- 計算と通信のオーバーラップに優れたpipeline処理の提案
- バッチごとに同期をとるパイプラインを採用
- 計算と通信のオーバーラップのために計算をなるべく前倒しで行う
数値実験
実装
設定
cross-mesh resharding単体の評価と、GPT-3 likeモデルとU-Net Transformerモデルの訓練時のスループット評価を行った。
モデル | 比較手法 |
---|---|
cross-mesh resharding単体 | Send/Recv, Alpa |
GPT-3 | Send/Recv, ALpa, Broadcast |
U-Net Transformer | 同上 |
計算環境はAWS p3.8xlarge instance, 4 NVIDIA V100 (16GB) GPUs, 32 vCPUs. GPUはNVLinkで、ノードは10Gpbsのcross-node bandwidthの帯域で接続されている。
cross-mesh reshardingの単体評価
(送信ノードが一つ)
有効帯域で比較、提案手法は次のケースで高いスケーラビリティを示していますね。(a) 受信デバイスが一つのノードに集まっている場合 (b) 受信デバイスが複数のノードにまたがっている場合
(送信ノードが複数)
有効帯域で比較、受信デバイスと送信デバイスが共に複数のノードにまたがっている場合の実験結果。送信リクエストを複数作成し、それぞれで評価。提案手法が高い有効帯域を示しています。
訓練時のスループット
Single Send/Recvはスループットの理論上限値です。提案手法はこの上限値に近い値を記録しています。なぜU-Transformerで特に提案手法がAlpaより改善できているかというとskip-connectionにより、より多くのcross-mesh reshardingがワークロードで発生するから、という考察。たしかに。
次に読むべき論文は?
- Alpa
- Zheng, Lianmin, et al. "Alpa: Automating Inter-and Intra-Operator Parallelism for Distributed Deep Learning." 2022
- CoCoNet
- Jangda, Abhinav, et al. "Breaking the computation and communication abstraction barrier in distributed machine learning workloads." 2022.
- U-Transformer
- Petit, Olivier, et al. "U-net transformer: Self and cross attention for medical image segmentation." 2021.
- Collective Communicaton
- Barham, Paul, et al. "Pathways: Asynchronous distributed dataflow for ML." 2022
cross-mesh reshardingの最適化の詳細
cross-mesh reshardingの例(再掲)
nodeBからnodeAのデバイスへデータの転送リクエストを図示。送信/受信すべきデータの種類がデバイスごとに異なることに注意してください。この図はデバイスd4, d$は0-7のデータを持っており、デバイスd0は0,1,4,5、d1は2,3,6,7のデータを受け取る必要があることを示しています。
多対多通信の分解
この通信は次のような一対多通信に分解できる。上は"$d_4$か$d_5$が$d_0$と$d_1$にデータを送る"、ということを意味しています。
1対多通信戦略
次に1対多通信を行うさいに、どのような通信ワークロードが考えられるかを、Node1, Node2の4つのGPUデバイスにデータを送信する場合を例に見ていきます。まずは次の二つ。(a) 1対1通信をくり返す (b) 半分ずつ送信し、ノード内でデータを共有してもらう
次はこの二つ。(c) 1/4ずつ送信し、受信ノード間で通信してもらう (d) データを細切れにし、順次それぞれのデバイスに通信する。受信デバイスは他のデバイスに送信する。
計算量見積もりでは(d)が最も優れているので、本稿では(d)の送信方法を採用。
通信の最適化
通信タスク$i$について、その送信デバイス候補集合を$n_i$とし、この通信に要する時間を$T_i$とします。1対多通信として(d)を採用することにしたので、$T_i$は定数として算出できることに注意してください。この中から一つデバイス$n_{i} \in n_i$を選択し、$n_{i}$から時刻$S_i$に送信が開始されるとします。ここで、この$n_{i*}$と$S_i$は最適化対象の変数です。最後の通信タスクが終了する時間の最小化を行います。このとき、デバイスは一度に一つのタスクの送受信のみ行えるので(3)の制約が必要です。この問題を解くのは難しかったらしく、近似手法として、枝刈り深さ優先探索、と、ランダム化貪欲法を用いたと報告されています。
1: 1対多通信として(d)を採用することにしたので、$T_i$は定数として算出できる。 2: この問題を解くのは難しかったらしく、近似手法として、枝刈り深さ優先探索、と、ランダム化貪欲法を用いた。
通信の最適化(Load balance only)
問題の簡約化版です。通信タイミング$S_i$は考えずに、送信デバイスごとのタスク時間が平準化されるように$n_{i*}$を決めるもの。これは割り当て問題を解けば良いですね。オーダー(O(n))。
数値実験(cross-mesh resharding最適化)
Naive: 送信デバイスを適当(最小デバイス番号)に設定 Load Balance Only: 送信スケジュールの最適化はせずに送信デバイスの負荷が同等になるようにする
piplelineの詳細
Eager 1F1B schedule
- 1F1Bのスケージュールを元に人の手で改良
- 計算できる部分を前倒しして実行することで、あとは送るだけの状態にしておく
- 受信側で計算中にそのデータの受信を行うことでオーバーラップの効果を狙う
数値実験(Eager 1F1B)
Broadcast: 1F1B & Overlapなし Overlap: 1F1B & Overlapあり Eager 1F1B: 提案手法(Overlapあり)
議論はあるか?
1対多通信(Broadcast-based resharding)
(d)について、データを細切れ(実験では100程度を使用)で送信するとのことだが、細切れにした分通信パケットに付与される情報の割合が高くなるため、通信量としは増加。そのため、これが最適化どうかはよくわからない。あと受信と送信を同時(パイプライン的に)に行う必要があるが、実装はどうなっているのだろう。