This article presents sample code for DistributedDataParallel (DDP) and then examines what communication actually takes place while it runs.
References:
- Getting Started with Distributed Data Parallel — PyTorch Tutorials 1.13.0+cu117 documentation
- pytorch DistributedDataParallel 事始め - Qiita
- PyTorchでの分散学習時にはDistributedSamplerを指定することを忘れない! - Qiita
Sample code
GLOO backend
The number of parallel processes is set by editing the world_size variable inside the program directly. Run with python (program).py.
```python
import os
import sys
import tempfile
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, input_size, output_size, length):
        self.len = length
        self.data = torch.randn(length, input_size)
        self.label = torch.randn(length, output_size)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return self.len


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    setup(rank, world_size)
    print(f"Running basic DDP example on rank {rank}.")

    # create the model and wrap it with DDP (CPU tensors with the GLOO backend here)
    model = ToyModel()
    ddp_model = DDP(model)

    dataset = ToyDataset(input_size=10, output_size=5, length=100)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=5, shuffle=sampler is None, sampler=sampler
    )

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for data, labels in data_loader:
        print("rank", rank, "data", data, "labels", labels)
        optimizer.zero_grad()
        outputs = ddp_model(data)
        loss_fn(outputs, labels).backward()
        optimizer.step()

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    world_size = 2
    run_demo(demo_basic, world_size)
```
MPI backend
The number of parallel processes is specified at run time via the mpirun option, as in mpirun -n (number of processes) python (program).py.
```python
import os
import sys
import tempfile
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP


def setup():
    # initialize the process group
    dist.init_process_group("mpi")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return rank, world_size


def cleanup():
    dist.destroy_process_group()


class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, input_size, output_size, length):
        self.len = length
        self.data = torch.randn(length, input_size)
        self.label = torch.randn(length, output_size)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return self.len


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic():
    rank, world_size = setup()
    print(f"Running basic DDP example on rank {rank}.")

    # create the model and wrap it with DDP (CPU tensors with the MPI backend here)
    model = ToyModel()
    ddp_model = DDP(model)

    dataset = ToyDataset(input_size=10, output_size=5, length=10)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=sampler is None, sampler=sampler
    )

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for data, labels in data_loader:
        print("rank", rank, "data", data, "labels", labels)
        optimizer.zero_grad()
        outputs = ddp_model(data)
        loss_fn(outputs, labels).backward()
        optimizer.step()

    cleanup()


def run_demo(demo_fn):
    demo_fn()


if __name__ == "__main__":
    run_demo(demo_basic)
```
Checking the communication
To check which communication operation is issued at which point during execution, modify torch/csrc/distributed/c10d/ProcessGroupMPI.cpp as shown below and rebuild PyTorch. The added line prints every collective as it is enqueued.
```diff
diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
index 556ab13887..e9a84e6f87 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
@@ -374,6 +376,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupMPI::enqueue(
     const c10::optional<std::vector<at::Tensor>>& inputTensors) {
   auto work = c10::make_intrusive<WorkMPI>(entry->dst, profilingTitle, inputTensors);
   std::unique_lock<std::mutex> lock(pgMutex_);
+  std::cout << "PYTORCH:ProcessGroupMPI -- rank " << getRank() << ": enqueu " << profilingTitle << std::endl;
   queue_.push_back(std::make_tuple(std::move(entry), work));
   lock.unlock();
   queueProduceCV_.notify_one();
```
For the build procedure, see the reference below.
Code used
ddp_mpi.py
```python
import os
import sys
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP


class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, input_size, output_size, length):
        self.len = length
        self.data = torch.randn(length, input_size)
        self.label = torch.randn(length, output_size)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return self.len


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10000)
        self.relu1 = nn.ReLU()
        self.net2 = nn.Linear(10000, 10000)
        self.relu2 = nn.ReLU()
        self.net3 = nn.Linear(10000, 5)

    def forward(self, x):
        x = self.relu1(self.net1(x))
        x = self.relu2(self.net2(x))
        x = self.net3(x)
        return x


def run():
    # initialize the process group
    dist.init_process_group("mpi")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    print(f"Running basic DDP example on rank {rank}.")

    # create the model and wrap it with DDP (CPU tensors with the MPI backend here)
    model = ToyModel()
    ddp_model = DDP(
        model,
        find_unused_parameters=False,
        bucket_cap_mb=1,  # 1MB
    )

    dataset = ToyDataset(input_size=10, output_size=5, length=100)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=sampler is None, sampler=sampler
    )

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for data, labels in data_loader:
        # print("rank", rank, "data", data, "labels", labels)
        print("rank", rank, ": zero_grad")
        optimizer.zero_grad()
        print("rank", rank, ": forward")
        outputs = ddp_model(data)
        print("rank", rank, ": calculate loss")
        loss_fn(outputs, labels).backward()
        print("rank", rank, ": backward")
        optimizer.step()
        # dist.barrier()

    dist.destroy_process_group()


if __name__ == "__main__":
    run()
```
Execution result
mpirun -n 2 python ddp_mpi.py
```
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:barrier
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:barrier
Running basic DDP example on rank 0.
Running basic DDP example on rank 1.
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:all_gather
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:all_gather
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:broadcast
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:broadcast
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:broadcast
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:broadcast
rank 0 : zero_grad
rank 0 : forward
rank 1 : zero_grad
rank 1 : forward
rank 1 : calculate loss
rank 0 : calculate loss
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:all_reduce
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:all_reduce
rank 1 : backward
rank 0 : backward
rank 1 : zero_grad
rank 0 : zero_grad
rank 1 : forward
rank 0 : forward
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:broadcast
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:broadcast
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:broadcast
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:broadcast
rank 1 : calculate loss
rank 0 : calculate loss
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:all_reduce
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:all_reduce
rank 1 : backward
rank 0 : backward
rank 1 : zero_grad
rank 0 : zero_grad
rank 1 : forward
rank 0 : forward
rank 1 : calculate loss
rank 0 : calculate loss
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:all_reduce
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:all_reduce
rank 1 : backward
rank 0 : backward
rank 1 : zero_grad
rank 0 : zero_grad
rank 1 : forward
rank 0 : forward
rank 1 : calculate loss
rank 0 : calculate loss
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:all_reduce
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:all_reduce
rank 1 : backward
rank 0 : backward
```
Looking at the results
- First, the processes synchronize with a barrier,
- then an all_gather is executed (for what purpose? presumably DDP's initial check that the model parameters are consistent across processes),
- then the model parameter values are distributed with a broadcast,
- and after that, the gradients are shared with an all_reduce once per forward-backward cycle (see the sketch below).
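One way to convince yourself that the all_reduce issued during backward() really is the gradient synchronization is DDP's no_sync() context manager, which suppresses it. The following is only a minimal sketch, not code from this article; it assumes the ddp_model, data_loader, loss_fn, and optimizer defined in ddp_mpi.py above. Run under the instrumented build, the iterations executed inside no_sync() should produce no mpi:all_reduce lines in the log.

```python
# Minimal sketch (assumes ddp_model, data_loader, loss_fn, optimizer from ddp_mpi.py):
# suppress DDP's gradient all_reduce with no_sync() to confirm that the
# mpi:all_reduce seen during backward() is the gradient synchronization.
for step, (data, labels) in enumerate(data_loader):
    optimizer.zero_grad()
    if step % 2 == 0:
        with ddp_model.no_sync():
            # gradients stay local; no all_reduce should be logged for this backward
            loss_fn(ddp_model(data), labels).backward()
    else:
        # normal iteration: gradients are averaged across ranks via all_reduce
        loss_fn(ddp_model(data), labels).backward()
    optimizer.step()
```

Note that this is only for observing the communication pattern; for real gradient accumulation you would skip optimizer.step() during the no_sync() iterations and step only after a synchronized backward.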
Also, when the model contains layers such as BatchNormalization, their statistics (running mean and variance) are broadcast from rank 0 to the other ranks every iteration, so that all ranks perform an equivalent forward computation without synchronizing the full set of parameters (a minimal sketch follows the quoted documentation below).
Parameters are never broadcast between processes. The module performs an all-reduce step on gradients and assumes that they will be modified by the optimizer in all processes in the same way. Buffers (e.g. BatchNorm stats) are broadcast from the module in process of rank 0, to all other replicas in the system in every iteration. https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
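To observe this buffer broadcast directly, one could add a BatchNorm layer to the toy model and rerun the experiment with the instrumented build. The sketch below is an illustration under assumptions, not code from this article: ToyModelWithBN is a hypothetical variant of ToyModel, and broadcast_buffers is the DDP argument that controls this behavior (it defaults to True).

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModelWithBN(nn.Module):
    """Hypothetical ToyModel variant whose BatchNorm layer registers
    running_mean / running_var as buffers (not parameters)."""

    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.bn = nn.BatchNorm1d(10)  # adds running_mean / running_var buffers
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.bn(self.net1(x))))


# With broadcast_buffers=True (the default), DDP broadcasts rank 0's buffers to the
# other ranks at each forward pass, so every rank uses the same BatchNorm statistics;
# setting it to False should make the extra mpi:broadcast lines disappear from the log.
ddp_model = DDP(ToyModelWithBN(), broadcast_buffers=True)
```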