
DistributedDataParallel (PyTorch) sample code

This article presents sample code for DistributedDataParallel (DDP) and examines which communication calls are issued while it runs.

References:

Sample code

GLOO backend

The degree of parallelism is set by editing the world_size variable directly in the program. Run it with python (program).py.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, input_size, output_size, length):
        self.len = length
        self.data = torch.randn(length, input_size)
        self.label = torch.randn(length, output_size)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return self.len


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    setup(rank, world_size)
    print(f"Running basic DDP example on rank {rank}.")

    # create the model and wrap it with DDP (CPU tensors here, so no device placement)
    model = ToyModel()
    ddp_model = DDP(model)

    dataset = ToyDataset(input_size=10, output_size=5, length=100)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=5, shuffle=sampler is None, sampler=sampler
    )

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for data, labels in data_loader:
        print("rank", rank, "data", data, "labels", labels)
        optimizer.zero_grad()
        outputs = ddp_model(data)
        loss_fn(outputs, labels).backward()
        optimizer.step()

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    world_size = 2
    run_demo(demo_basic, world_size)

MPI backend

The degree of parallelism is specified at launch time via the mpirun option, i.e. run with mpirun -n (number of processes) python (program).py.

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP


def setup():
    # initialize the process group
    dist.init_process_group("mpi")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return rank, world_size


def cleanup():
    dist.destroy_process_group()


class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, input_size, output_size, length):
        self.len = length
        self.data = torch.randn(length, input_size)
        self.label = torch.randn(length, output_size)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return self.len


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic():
    rank, world_size = setup()
    print(f"Running basic DDP example on rank {rank}.")

    # create the model and wrap it with DDP (CPU tensors here, so no device placement)
    model = ToyModel()
    ddp_model = DDP(model)

    dataset = ToyDataset(input_size=10, output_size=5, length=10)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=sampler is None, sampler=sampler
    )

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for data, labels in data_loader:
        print("rank", rank, "data", data, "labels", labels)
        optimizer.zero_grad()
        outputs = ddp_model(data)
        loss_fn(outputs, labels).backward()
        optimizer.step()

    cleanup()


def run_demo(demo_fn):
    demo_fn()


if __name__ == "__main__":
    run_demo(demo_basic)




Checking the communication

To see which communication calls happen at which point during execution, modify torch/csrc/distributed/c10d/ProcessGroupMPI.cpp as follows and recompile PyTorch; the added line prints each operation as it is enqueued.

diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
index 556ab13887..e9a84e6f87 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp
@@ -374,6 +376,7 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupMPI::enqueue(
     const c10::optional<std::vector<at::Tensor>>& inputTensors) {
   auto work = c10::make_intrusive<WorkMPI>(entry->dst, profilingTitle, inputTensors);
   std::unique_lock<std::mutex> lock(pgMutex_);
+  std::cout << "PYTORCH:ProcessGroupMPI -- rank " << getRank() << ": enqueu " << profilingTitle << std::endl;
   queue_.push_back(std::make_tuple(std::move(entry), work));
   lock.unlock();
   queueProduceCV_.notify_one();

For how to build, see the reference below.

Code used for the run

ddp_mpi.py

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP


class ToyDataset(torch.utils.data.Dataset):
    def __init__(self, input_size, output_size, length):
        self.len = length
        self.data = torch.randn(length, input_size)
        self.label = torch.randn(length, output_size)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return self.len


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10000)
        self.relu1 = nn.ReLU()
        self.net2 = nn.Linear(10000, 10000)
        self.relu2 = nn.ReLU()
        self.net3 = nn.Linear(10000, 5)

    def forward(self, x):
        x = self.relu1(self.net1(x))
        x = self.relu2(self.net2(x))
        x = self.net3(x)
        return x


def run():
    # initialize the process group
    dist.init_process_group("mpi")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    print(f"Running basic DDP example on rank {rank}.")

    # create the model and wrap it with DDP (CPU tensors here, so no device placement)
    model = ToyModel()
    ddp_model = DDP(
        model,
        find_unused_parameters=False,
        bucket_cap_mb=1,  # gradient bucket size of 1 MB
    )

    dataset = ToyDataset(input_size=10, output_size=5, length=100)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=2, shuffle=sampler is None, sampler=sampler
    )

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for data, labels in data_loader:
        # print("rank", rank, "data", data, "labels", labels)
        print("rank", rank, ": zero_grad")
        optimizer.zero_grad()
        print("rank", rank, ": forward")
        outputs = ddp_model(data)
        print("rank", rank, ": calculate loss")
        loss_fn(outputs, labels).backward()
        print("rank", rank, ": backward")
        optimizer.step()
        # dist.barrier()

    dist.destroy_process_group()


if __name__ == "__main__":
    run()

Execution results

mpirun -n 2 python ddp_mpi.py
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:barrier
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:barrier
Running basic DDP example on rank 0.
Running basic DDP example on rank 1.
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:all_gather
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:all_gather
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:broadcast
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:broadcast
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:broadcast
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:broadcast
rank 0 : zero_grad
rank 0 : forward
rank 1 : zero_grad
rank 1 : forward
rank 1 : calculate loss
rank 0 : calculate loss
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:all_reduce
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:all_reduce
rank 1 : backward
rank 0 : backward
rank 1 : zero_grad
rank 0 : zero_grad
rank 1 : forward
rank 0 : forward
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:broadcast
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:broadcast
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:broadcast
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:broadcast
rank 1 : calculate loss
rank 0 : calculate loss
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:all_reduce
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:all_reduce
rank 1 : backward
rank 0 : backward
rank 1 : zero_grad
rank 0 : zero_grad
rank 1 : forward
rank 0 : forward
rank 1 : calculate loss
rank 0 : calculate loss
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:all_reduce
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:all_reduce
rank 1 : backward
rank 0 : backward
rank 1 : zero_grad
rank 0 : zero_grad
rank 1 : forward
rank 0 : forward
rank 1 : calculate loss
rank 0 : calculate loss
PYTORCH:ProcessGroupMPI -- rank 1: enqueu mpi:all_reduce
PYTORCH:ProcessGroupMPI -- rank 0: enqueu mpi:all_reduce
rank 1 : backward
rank 0 : backward

Reading the results

  1. First, a barrier synchronizes the processes.
  2. An all_gather runs when DDP is constructed (presumably the check that parameter shapes and sizes match across processes).
  3. The model parameters are then distributed from rank 0 by broadcast.
  4. After that, the gradients are shared with one all_reduce per forward-backward cycle (a hand-written sketch of this pattern follows the list).
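
For comparison, the same pattern can be reproduced by hand with plain torch.distributed collectives instead of DDP: broadcast the parameters from rank 0 once at start-up, then all_reduce (and average) the gradients after every backward pass. This is only a conceptual sketch of what DDP automates, not its actual implementation; it ignores gradient bucketing and the overlap of communication with the backward pass, and it reuses the ToyModel and ToyDataset classes and the MPI setup from the sample above.

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim


def average_gradients(model, world_size):
    # roughly what the mpi:all_reduce entries in the log correspond to:
    # sum the local gradients across ranks, then divide to get the average
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size


def manual_run():
    dist.init_process_group("mpi")
    rank, world_size = dist.get_rank(), dist.get_world_size()

    model = ToyModel()  # reuses ToyModel defined in the sample above
    # roughly what the initial mpi:broadcast corresponds to: every rank
    # starts from rank 0's parameter values
    for p in model.parameters():
        dist.broadcast(p.data, src=0)

    dataset = ToyDataset(input_size=10, output_size=5, length=100)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=5, sampler=sampler)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    for data, labels in loader:
        optimizer.zero_grad()
        loss_fn(model(data), labels).backward()
        average_gradients(model, world_size)
        optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    manual_run()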

In addition, when the model uses layers such as BatchNorm, their statistics (running mean and variance) are broadcast from rank 0 to the other ranks, so every rank can perform an equivalent forward pass without synchronizing the full set of parameters.

Parameters are never broadcast between processes. The module performs an all-reduce step on gradients and assumes that they will be modified by the optimizer in all processes in the same way. Buffers (e.g. BatchNorm stats) are broadcast from the module in process of rank 0, to all other replicas in the system in every iteration. https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
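
As a rough illustration of that buffer synchronization (a sketch using only the public dist.broadcast call on a hypothetical model with a BatchNorm layer, not DDP's actual implementation):

import torch.distributed as dist
import torch.nn as nn


def broadcast_buffers_from_rank0(model: nn.Module):
    # BatchNorm keeps its statistics in buffers: running_mean, running_var,
    # num_batches_tracked. Broadcasting them from rank 0 keeps every rank's
    # forward pass consistent even though the buffers are never all-reduced.
    for buf in model.buffers():
        dist.broadcast(buf, src=0)


# usage (hypothetical model), once per iteration before the forward pass:
#   model = nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10))
#   broadcast_buffers_from_rank0(model)
#   outputs = model(data)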