サブロウ丸

サブロウ丸

主にプログラミングと数学

DistributedDataParallel (pytorch) の内部デザイン

https://pytorch.org/docs/stable/notes/ddp.html#internal-design の日本語訳 + おまけの脚注 pytorch はv1.12

Internal Design

ここでは、torch.nn.parallel.DistributedDataParallelがどのように動作しているかを、1つの反復処理の各ステップの詳細に踏み込んで明らかにする。

前提条件

DDPはc10d ProcessGroupに依存して通信を行います。したがって、アプリケーションはDDPを構築する前にProcessGroupのインスタンスを作成する必要があります。

構築

  • DDPコンストラクタは、ローカルモジュールへの参照を取り、ランク0のプロセスからグループ内の他のすべてのプロセスにstate_dict()をブロードキャストして、すべてのモデルレプリカが全く同じ状態から開始されるようにします*1
  • その後、各DDPプロセスはローカルのReducerを作成し、backwardパスにおける勾配同期を担います。通信効率を上げるために、Reducerはパラメータの勾配をバケットに整理し、一度に1つのバケットを縮約(reduce)します*2バケットサイズはDDPコンストラクタのbucket_cap_mb引数で設定することができます*3。パラメータ勾配からバケットへの割り当て(マッピング)は、バケットサイズの上限とパラメータサイズに基づいて構築時に決定される。モデルのパラメータは、与えられたモデルから Model.parameters() の逆順に(おおよそ)バケットに割り振られます。逆順にする理由は、DDPはバックワードパスで勾配がほぼこの順番で準備完了になることを期待しているからです。
  • 下図(Implementationの下)はその例です。grad0 と grad1 は bucket1 に、他の2つの gradients は bucket0 にあることに注意してください。もちろん、この仮定、すなわちModel.parameters()の逆順に勾配が計算されること、が常に正しいとは限りません。そうなった場合、Reducerは早いタイミングで通信を開始することができないため、DDPのbackward処理速度に悪影響を与える可能性があります。
  • また、バケット化とは別にReducerは構築時にautogradフックを登録します(1パラメータにつき1フック)。これらのフックはバックワードパスで勾配が準備できたときにトリガーされます。

フォワードパス

find_unused_parametersがTrueに設定されていれば、DDPは入力をローカルモデルに渡す際にローカルモデルからの出力を解析します。このモードでは、DDPはモデル出力からautogradグラフを走査してbackwardパスに関与するパラメータを見つけ出すことで逆説的に未使用パラメータを洗い出し、すべての未使用パラメータに縮約の準備ができたと疑似的なフラグをつけます。パラメータ勾配を準備完了とマーキングすることは、DDPがバケットをスキップする助けにはなりません*4が、バックワードパスの間、DDPが未使用のパラメタの勾配算出を永遠に待つことを防ぐことができます。autogradグラフの操作は余分なオーバーヘッドを発生させるので、アプリケーションは必要なときだけfind_unused_parametersをTrueに設定すべきことに注意してください*5

バックワードパス

backward()関数はDDPの制御外の損失テンソルに対して直接呼び出され、DDPは構築時に登録されたautogradフックを使用して勾配同期をトリガーします。ある勾配が準備完了になると、そのgradアキュムレータの対応するDDPフックが起動し、DDPはそのパラメータ勾配の縮約の準備ができたとマークします。1つのバケット内の勾配がすべて準備完了になると、Reducerはそのバケットに対して非同期allreduceを開始し*6、全プロセスの勾配の平均を計算します。すべてのバケットの準備ができたら、Reducerはすべてのallreduce操作が終了するのを待つためにブロックをします。これが終わると、平均化された勾配が全パラメータの param.grad フィールドに書き込まれます。したがって、バックワードパスの後、異なるDDPプロセス間で同じ対応するパラメータのgradフィールドは同じになるはずです*7

オプティマイザー・ステップ

オプティマイザの観点からは、ローカルモデルの最適化が行われます。すべてのDDPプロセス上のモデル複製は、同じ状態から開始され、すべての反復において同じ平均化された勾配を持つので、同期を保つことができます。

https://user-images.githubusercontent.com/16999635/72401724-d296d880-371a-11ea-90ab-737f86543df9.png

Implementation

以下は、DDPの実装コンポーネントへのポインタです。積み上げられたグラフは、コードの構造を示しています。

https://user-images.githubusercontent.com/16999635/72313120-4e7c1c80-3658-11ea-9c6d-44336b2daeac.png

ProcessGroup

ProcessGroup.hpp: すべてのプロセスグループ実装の抽象APIを含む.c10d*8ライブラリでは,ProcessGroupGloo, ProcessGroupNCCL, ProcessGroupMPIという3つの実装が提供されています.DistributedDataParallelでは、初期化時にランク0のプロセスから他のプロセスにモデル状態を送るためにProcessGroup::broadcast()を、グラディエントを合計するためにProcessGroup::allreduce()を使用しています。

Store.hpp: プロセスグループのインスタンスがお互いを見つけるためのランデブーサービスを支援します。

DistributedDataParallel

distributed.py: DDPのPythonエントリポイントです。C++のライブラリを呼び出すnn.parallel.DistributedDataParallelモジュールの初期化ステップとフォワード関数が実装されています.また,_sync_param関数は,1つのDDPプロセスが複数のデバイスで動作する場合に,プロセス内のパラメータ同期を行い,ランク0のプロセスから他のすべてのプロセスへモデルバッファをブロードキャストします.プロセス間パラメータ同期はReducer.cppで行われます。

comm.h: coalesced broadcastヘルパー関数を実装しており、初期化中にモデルの状態をブロードキャストし、フォワードパスの前にモデルバッファを同期させるために呼び出されます。

reducer.h: バックワードパスにおける勾配同期のコア実装を提供します。3つのエントリポイント関数があります。

  • Reducer: コンストラクタはdistributed.pyで呼ばれ、Reducer::autograd_hook()をgradient accumulatorsに登録します。
  • autograd_hook()関数は、勾配が準備できたときにautogradエンジンによって呼び出されます。
  • prepare_for_backward()は、distributed.pyのDDPフォワードパスの最後に呼び出されます。これは、DDPのコンストラクタでfind_unused_parametersがTrueに設定されている場合、未使用のパラメータを見つけるためにautogradグラフを走査します。

*1:全てのrankのモデルは同じパラメタ値を持つ状態で学習開始を待つ。

*2:これによりbackward計算と通信を同時に行うことができる。つまり通信レイテンシを(多少?)隠蔽することが可能。

*3:bucket_cap_mb=1で1MBのバケットサイズになる。デフォルトは25MB。https://github.com/pytorch/pytorch/blob/4618371da56c887195e2e1d16dad2b9686302800/torch/nn/parallel/distributed.py#L460-L464

*4:すなわち通信量は変わらない

*5:要するにモデルパラメタに未使用のものがなければfind_unused_parameters=Falseとする。もしforward内部で条件分岐などで使用されないパラメタがある場合はfind_unused_parameters=Trueとすることで通信処理の開始を遅らせないことができる。

*6:通信を行うスレッドがあり、通信はジョブ(Work)としてqueueに入力される。そのqueueの中のWorkを順に処理していく。同期が必要な部分はwork->wait()でそのworkが終了するのを待つ。

*7:すなわちパラメタの共有は初めの1度のみ

*8:おそらくcdistributedの略。先頭のcと末尾のdの間に10文字あるのでk8sのようにそう略しているっぽい。