[paper] [video] [blog] [github]
著者は皆NVIDIA
概要
背景
- transformerベースの自然言語処理モデルが巨大化、学習時のメモリ使用量の増加が課題に
- メモリ使用量を削減する分散処理モデルの一つがtensor-parallel、これは、層単位のような、model並列よりもより細かい範囲の計算を分散処理するもの
どんなもの?
- transformerやfeedforward計算の部分をtensor-parallelismにより並列に計算する
- また複数のlayerをまとまりとした model-parallelismとも組み合わせる
- pytorch.distributedの集団通信APIを用いてカスタマイズして実装
先行研究と比べてどこがすごい?
- transformerベースのモデルに特化した並列手法を提案
- pytorchベースの実装を公開
技術や手法のキモはどこ?
- 同期処理(all-reduce)が少なくなるように次のようにモデルを分割
- MLP部分は垂直に分割
- 入力X、 出力Zに対して Z = Dropout(GeLU(XA)B) の計算を Droptout(GeLU(X[A_1, A_2])[B_1, B_2])とモデルパラメタA、Bを分割すると
Y_1 = GeLU(XA_1) → Z_1 = Y_1B_1
とY_2 = GeLU(XA_2) → Z_2 = Y_2B_2
の計算を独立に行える- Z = [Z_1, Z_2]と各々の出力をgatherし、Dropoutを実行
- 次の層も同じように並列に計算するなら出力Zをbroadcastする
- transformerはヘッドごとに並列に計算(計算資源のあまりはデータ並列を行える) + MLP部分は上記のように垂直に分割
- MLP部分は垂直に分割
- 最終層はvocabularyサイズのアフィン変換があるが、そこも垂直に管理
- vocabularyはかなり大きくする必要があるので、実は結構ここでパラメタ数が必要になる
- 実際V = 51,200語彙で論文中では実験しているので1層のaffineでもh(word 埋め込みサイズ)に対し、パラメタ数はh * V個
- h = 2304 → h * V は約1.1億, h = 25600 → h * V は約13億パラメタ
どうやって有効だと検証した?
- GPU数の増加に対し、petaflopsが線形に増加できることを確認
- GPT-2(83億パラメタ)とBERT(39億パラメタ)を学習
- 32 DGX-2H server (totalで512 Tesla V100 SXM2 32GM GPUs)を使用
- 8モデル並列 * 64データ並列
- WikiText-103とLAMBADAデータセットで学習、SOTAを達成
議論はある?
- tensor-parallelismではbroadcast, gather, all-reduceなどの集団通信を多用するので、その処理が高速な環境で実行することがほぼ前提条件
- そのためtransformerのhead数が最大のテンソル並列数になり、ここが制約になり並列数をそれほど大きくできない(32程度)
関連
演算量と通信量
ここではモデル内における計算量と通信量について論文中の記載をもとにまとめます。まずは使用する記号をまとめます
記号 | 説明 | 数値例(Table 1から抜粋) |
---|---|---|
B | バッチサイズ | 512 - 3072 |
l | transformer層数 | 24 - 128 |
h | wordの埋め込み次元数 | 2304 - 25600 |
s | sentenseの長さ | 2048 |
V | 語彙数 | 51200 |
以下ではカギかっこでテンソル行列のサイズを表します。例えば[B, s, h]は3次元でそれぞれの大きさがB, s, hのテンソルを表しています。
また次で使いますが [m, k] 行列と[k, n]行列の積の演算数は2mknです(要素の積と和の総数)
演算量
1層のtransformerでの計算量(演算数)を見ていきます。
forward
単純なattentionの計算では入力Xに対して
- Q = X W_Q, K = X W_K, V = X W_V
- softmax(QK/d) V
の計算が行われます(ここでdは定数) さて、自然言語モデルの訓練の場合はそれぞれの行列のサイズは下記のようになります。
行列 | テンソルサイズ |
---|---|
X | [B, s, h] |
W_Q, W_K, W_V | [h, h] |
Q, K, V | [B, s, h] |
softmax(QK/d) | [B, s, s] |
バッチサイズにより3次元のテンソル演算が含まれますが、[B, m, k] 行列と[k, n]行列の積の場合は[m, k]行列と[k, n]行列の積をB回行うことに相当するので演算数は2Bmknになります。これらより、入力Xに対するattension部分のforwardの際の演算量はQ, K, Vの演算量は合わせて6Bsh2、softmax(QK/d) Vは4Bs2 hとなるため
になります。同様にattensionの次のfeedforward層の部分の演算量については、中間層のサイズを4h(論文ではこうしている)にすると、入力X [B, s, h]に対し、W_1 [h, 4h], W_2 [4h, h] 行列を(非線形変換を挟んで)作用させるため、その演算量は
ということで全体の演算量としては、
になります。
図にまとめるとこんな感じ。括弧の中に演算量を記載しています。
backward
次はbackwardではどれほどの演算量になるかを考えますが、結論からいうと
になります。内訳を見ていくと、activation checkpointingと呼ばれる演算に1forward分、勾配算出に2forward分の演算量が必要になります。勾配計算から見ていくと、forwardでは
という演算に対してCの勾配dCを用いてAの勾配をdA = dC * XT、Xの勾配を dX = AT * dC で求めます。ということで1つの行列演算からなるforwardの場合はbackward時に2回行列演算が行われるんですね。よく計算すると、backwardの勾配算出の際には2forward分の演算量が必要であることが分かります。
またactivation checkpointingについてですが、これはメモリ節約のために行われるテクニックです。backwardの勾配算出の際にはforward時に行われた行列演算の結果(activation)を用います。しかし、forward時の全てのactivationを保存しようとしても、ハードウェアのメモリ容量の制限によりその全てを保存できない場合があります。その対処として、forward時には全てのactivationの保存は諦めて、backward時にもう一度局所的にforwardを行なって必要なactivationを復元する、という戦略が用いられます。これがactivation checkpointingと呼ばれるテクニックです。これにより局所局所のチェックポイントの行列演算結果さえ覚えておけば良くなるので、メモリ使用量を大きく節約できます。このactivationをcheckpointing用いるとbackward時にforwardがモデル全体でもう一度行われることになるので演算量も1forward分だけ追加で必要になります。
以上をまとめるとbackward時には3forward分の演算量を要する、というわけです。
最後に+alphaですがこれはパラメタ更新に要する演算量です。パラメタの勾配を求めた後にその値を用いて最適化アルゴリズム(SGDやAdamなど)に従ってパラメタを更新します。この際の更新量を計算する際には32bitで行われる場合が多いようです(それ以外の計算は16bitで行う)。ということで、パラメタ数 * 2(精度が2倍)程度の演算量が追加で必要になります。
通信量
forwardとbackwardを行う際のテンソル並列の通信量について記載します。
forward
1layerにおける通信の流れを書くと図のようになります。
[B, s, h]サイズの行列のbroadcast1回と、[B, s, h]サイズのテンソルを集めるgatherが1回行われます。(要するに[B, s, h]をそれぞれのプロセスが保持するallgatherが実行される)
backward
backward時は勾配算出の際の流れはこんな感じ。矢印は逆方向を向くと置き換えてください。
[B, s, h]サイズの行列のbroadcast1回と、[B, s, h]サイズのテンソルを集めるreduceが1回行われます。(要するに[B, s, h]サイズのallreduceが実行される)
またactivationを行う場合はforwardが行われるのでforwardと同じ量の通信が必要になります。
更新履歴
- (2023/5/3) Shion Hondaさんのコメントを元に演算量と通信量 > backwardのパラグラフを修正しました。具体的にはactivation→activation checkpointingの用語の適正化です。コメント大変ありがとうございました。