PytorchのDeviceMesh
DeviceMeshはGPUなどのリソースグループを管理するツールで、これを使えば分散学習に割り当てるGPUリソースを柔軟に割り当てられる。
分散並列手法にはいくつかの種類があり、大まかにデータ並列とモデル並列の二つがある。LLMのようなパラメータ数が多いモデルを学習する際には、これらを組み合わせて実行したいことがある。その際、サーバー内ではモデル並列、サーバー間でデータ並列を行うといった細かい指定が可能だ。
図は以下のDeviceMeshドキュメントから引用。
Getting Started with DeviceMesh — PyTorch Tutorials 2.7.0+cu126 documentation
DeviceMeshではデバイス(計算リソース)が多次元行列の形式で並んでいることを前提としている。例えば、6つのデバイスが2行3列で並んでいると考える(2枚のGPUを持つサーバーが3つある状況を想定するとわかりやすい)。これは2次元の例だが、3次元以上で計算リソースを整理することも可能だ。
以下では、この設定、つまり2枚のGPUを持つサーバーが3つある状況を想定して説明する。次はDeviceMeshを作成するテストスクリプトだ。
import torch import torch.distributed as dist from torch.distributed.device_mesh import DeviceMesh def main(): dist.init_process_group(backend="gloo") rank = dist.get_rank() world_size = dist.get_world_size() mesh = DeviceMesh( device_type="cpu", mesh=[[0, 1], [2, 3], [4, 5]], mesh_dim_names=("row", "col") ) import time; time.sleep(rank) print(f"Rank {rank}: mesh['row'] = {mesh['row']}") print(f"Rank {rank}: mesh['col'] = {mesh['col']}") dist.destroy_process_group() if __name__ == "__main__": main()
- 実行コマンド
torchrun --nproc_per_node=6 run_device_mesh.py- (3つのサーバーがある環境を想定したが、簡単な説明のためこのコマンド自体は一つのホストで6つのプロセスを実行するものになっている)
- DeviceMeshの作成について
- 内部処理の都合上、分散処理の準備(プロセスグループの作成や適切な環境変数)が整っている必要がある
- mesh=[[0, 1], [2, 3], [4, 5]] のようにデバイス番号を与える。もし2つのGPUを持つサーバーが3つある場合は、rankが若いサーバーからデバイス番号が認識される。
- mesh['row']で行方向のデバイス番号を取得する。実行するデバイスごとに結果が異なる。上記のシチュエーションで行くと、自身と同じサーバーにあるデバイスの番号を取得できる。
- mesh['col']で列方向のデバイス番号を取得。これも実行するデバイスごとに結果が異なる。上記のシチュエーションで行くと、複数サーバーをまたがって同じローカル番号のデバイスの番号を取得できる。
このスクリプトの出力結果は以下のようになる。
Rank 0: mesh['row'] = DeviceMesh('cpu', [0, 1], mesh_dim_names=('row',))
Rank 0: mesh['col'] = DeviceMesh('cpu', [0, 2, 4], mesh_dim_names=('col',))
Rank 1: mesh['row'] = DeviceMesh('cpu', [0, 1], mesh_dim_names=('row',))
Rank 1: mesh['col'] = DeviceMesh('cpu', [1, 3, 5], mesh_dim_names=('col',))
Rank 2: mesh['row'] = DeviceMesh('cpu', [2, 3], mesh_dim_names=('row',))
Rank 2: mesh['col'] = DeviceMesh('cpu', [0, 2, 4], mesh_dim_names=('col',))
Rank 3: mesh['row'] = DeviceMesh('cpu', [2, 3], mesh_dim_names=('row',))
Rank 3: mesh['col'] = DeviceMesh('cpu', [1, 3, 5], mesh_dim_names=('col',))
...
DeviceMeshを使ったデータ並列とモデル並列
モデル並列(テンソル並列; TP)
モジュールごとに分散処理の方式を指定できる。モジュールは線形層(nn.Linear)やモデル(nn.Model)などを含む抽象概念。ここでは上記のdevice meshを使って説明する。
parallelize_module
たとえば、同じサーバー内の2つのGPU間でTPを行いたい場合は、下記のようになる。
m = torch.nn.Linear(10, 10) # 線形層を作成 tp_mesh = mesh["row"] # TPにしたい次元 m = parallelize_module( module=m, device_mesh=tp_mesh, parallelize_plan=ColwiseParallel() )
ここで
tp_mesh = mesh["col"]
とすると、RuntimeError: ('Found TP device_mesh on the 0 dimension of its parent mesh.', 'Currently we only support intranode TP and TP needs to be the innermost dimension on its parent mesh.')というエラーが出る。device_mesh の一番外側(つまり距離的には近い)デバイスしかTPが使えないとある。親切だなぁ(?)。TP自体はテンソル演算ごとに通信を必要とするため、確かに近いデバイス間でTPを行うのが一般的な戦略である。
ParallelStyle
parallelize_moduleでは並列化の方式を指定する。https://github.com/pytorch/pytorch/blob/cbcf6772231a2c216be707627b6613e8c79a86ed/torch/distributed/tensor/parallel/style.py#L45
次のような方式が使えるようだ。
class ColwiseParallel(ParallelStyle): class RowwiseParallel(ParallelStyle): class SequenceParallel(ParallelStyle): class PrepareModuleInput(ParallelStyle): class PrepareModuleOutput(ParallelStyle): class PrepareModuleInputOutput(ParallelStyle):
- ColwiseParallel
- 入力をbroadcast(コピー)
- 独立にテンソル演算
Input → [ Linear (W_col_0) | Linear (W_col_1) ]
↓ ↓
Output_0 Output_1
[ Input_0 | Input_1 ] → [ Linear (W_row_0) | Linear (W_row_1) ]
↓ ↓
Partial Output Partial Output
↓ ↓
(Reduce)
↓
Final Output
1層目はRowwiseParallelで、2層目はColwiseParallelで、みたいな細かい指定も可能。
m1 = torch.nn.Linear(10, 10) m2 = torch.nn.Linear(10, 10) m = torch.nn.Sequential(m1, m2) tp_mesh = mesh["row"] pm = parallelize_module(module=m, device_mesh=tp_mesh, parallelize_plan={"m1": RowwiseParallel(), "m2": ColwiseParallel()}) print(pm)
DP (Data Parallel) + TP
データ並列とテンソル並列を両方用いるパターン。といっても、DDPを被せるだけ。注意する点としてモデルのモジュールにTPを設定する前に、モデルをDDPモデルに変換しておく必要がある。
# Define the model model = torch.nn.Sequential( torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1), ) # Data Parallelism setup dp_mesh = mesh["col"] dp_group = dp_mesh.get_group() print(f"Rank {rank}: dp_group = {dp_group}") # First, apply DDP model = DDP(model, process_group=dp_group) # Then, apply Tensor Parallelism tp_mesh = mesh["row"] model.module = parallelize_module( module=model.module, device_mesh=tp_mesh, parallelize_plan={ "0": ColwiseParallel(), "2": RowwiseParallel(), }, )
あとはサンプラーの作成する。この場合はrow (=サーバー内)でTPを行っているため、同じサーバー内のデバイスには同じ入力データが入力されるようにしなければならない。以下のようにDPの個数をnum_replicasで指定し、自身のrankがDPの何番目のレプリカであるかをdp_rankとして入力してやれば良い。
global_rank = dist.get_rank()
dp_rank = dist.get_rank(group=dp_group)
dp_world_size = dist.get_world_size(group=dp_group)
sampler = DistributedSampler(
dataset,
num_replicas=dp_world_size,
rank=dp_rank,
shuffle=True,
drop_last=False
)
DistributedSamplerのコード
import torch from torch.nn.parallel import DistributedDataParallel as DDP import torch.distributed as dist from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel class LinearWrapper(torch.nn.Linear): ... class LinearWrapper1(LinearWrapper): ... class LinearWrapper2(LinearWrapper): ... class LinearWrapper3(LinearWrapper): ... class SimpleModel(torch.nn.Module): def __init__(self, rank): super(SimpleModel, self).__init__() self.rank = rank self.linear1 = LinearWrapper1(10, 10) self.linear2 = LinearWrapper2(10, 10) self.linear3 = LinearWrapper3(10, 1) # Register backward hooks for all sequential layers def backward_hook(module, grad_input, grad_output): print(f"[rank{self.rank}] Backward pass - {module.__class__.__name__}") return grad_input def forward_hook(module, input, output): print(f"[rank{self.rank}] Forward pass - {module.__class__.__name__}") return output self.linear1.register_forward_hook(forward_hook) self.linear2.register_forward_hook(forward_hook) self.linear3.register_forward_hook(forward_hook) self.linear1.register_full_backward_hook(backward_hook) self.linear2.register_full_backward_hook(backward_hook) self.linear3.register_full_backward_hook(backward_hook) def forward(self, x): x = torch.relu(self.linear1(x)) x = torch.relu(self.linear2(x)) x = torch.relu(self.linear3(x)) return x def main(): dist.init_process_group(backend="gloo") rank = dist.get_rank() world_size = dist.get_world_size() if rank == 0: import os os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO" os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" print(f"Rank {rank}/{world_size} is running.") mesh = DeviceMesh( device_type="cpu", mesh=[[0, 1], [2, 3], [4, 5]], mesh_dim_names=("col", "row") ) import time; time.sleep(rank * 0.5) print(f"Rank {rank}: mesh['row'] = {mesh['row']}") print(f"Rank {rank}: mesh['col'] = {mesh['col']}") # Define the model # model = torch.nn.Sequential( # torch.nn.Linear(10, 10), # torch.nn.ReLU(), # torch.nn.Linear(10, 1), # ) model = SimpleModel(rank) # Data Parallelism setup dp_mesh = mesh["col"] dp_group = dp_mesh.get_group() print(f"Rank {rank}: dp_group = {dp_group}") # First, apply DDP model = DDP( model, process_group=dp_group, find_unused_parameters=True ) # Then, apply Tensor Parallelism tp_mesh = mesh["row"] model.module = parallelize_module( module=model.module, device_mesh=tp_mesh, parallelize_plan={ "linear1": ColwiseParallel(), "linear2": RowwiseParallel(), }, ) # training step model.train() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) get_local_root_rank = lambda mesh: mesh.mesh[0].item() if rank == get_local_root_rank(tp_mesh): input = torch.randn(3, 10) else: input = torch.empty(3, 10) dist.broadcast(input, src=get_local_root_rank(tp_mesh), group=tp_mesh.get_group()) for _ in range(10): output = model(input) loss = output.sum() print(f"Rank {rank}: loss = {loss}") loss.backward() optimizer.step() optimizer.zero_grad() # finalizes dist.destroy_process_group() if __name__ == "__main__": main()
TORCH_CPP_LOG_LEVEL=INFO TORCH_DISTRIBUTED_DEBUG=DETAILの環境変数を設定して、呼び出された集団通信のログが表示される。
