サブロウ丸

Sabrou-mal サブロウ丸

主にプログラミングと数学

PytorchのDeviceMesh

DeviceMeshはGPUなどのリソースグループを管理するツールで、これを使えば分散学習に割り当てるGPUリソースを柔軟に割り当てられる。

分散並列手法にはいくつかの種類があり、大まかにデータ並列とモデル並列の二つがある。LLMのようなパラメータ数が多いモデルを学習する際には、これらを組み合わせて実行したいことがある。その際、サーバー内ではモデル並列、サーバー間でデータ並列を行うといった細かい指定が可能だ。

https://pytorch.org/tutorials/_images/device_mesh.png

図は以下のDeviceMeshドキュメントから引用。

Getting Started with DeviceMesh — PyTorch Tutorials 2.7.0+cu126 documentation


DeviceMeshではデバイス(計算リソース)が多次元行列の形式で並んでいることを前提としている。例えば、6つのデバイスが2行3列で並んでいると考える(2枚のGPUを持つサーバーが3つある状況を想定するとわかりやすい)。これは2次元の例だが、3次元以上で計算リソースを整理することも可能だ。

以下では、この設定、つまり2枚のGPUを持つサーバーが3つある状況を想定して説明する。次はDeviceMeshを作成するテストスクリプトだ。

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

def main():
    dist.init_process_group(backend="gloo")

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    mesh = DeviceMesh(
        device_type="cpu",
        mesh=[[0, 1], [2, 3], [4, 5]],
        mesh_dim_names=("row", "col")
    )

    import time; time.sleep(rank)
    print(f"Rank {rank}: mesh['row'] = {mesh['row']}")
    print(f"Rank {rank}: mesh['col'] = {mesh['col']}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
  • 実行コマンド
    • torchrun --nproc_per_node=6 run_device_mesh.py
      • (3つのサーバーがある環境を想定したが、簡単な説明のためこのコマンド自体は一つのホストで6つのプロセスを実行するものになっている)
  • DeviceMeshの作成について
    • 内部処理の都合上、分散処理の準備(プロセスグループの作成や適切な環境変数)が整っている必要がある
    • mesh=[[0, 1], [2, 3], [4, 5]] のようにデバイス番号を与える。もし2つのGPUを持つサーバーが3つある場合は、rankが若いサーバーからデバイス番号が認識される。
      • 0 ... rank 0 サーバーの0 番GPU
      • 1 ... rank 0 サーバーの1番GPU
      • 2 ... rank 1 サーバーの0番GPU)
      • ...
    • mesh['row']で行方向のデバイス番号を取得する。実行するデバイスごとに結果が異なる。上記のシチュエーションで行くと、自身と同じサーバーにあるデバイスの番号を取得できる。
    • mesh['col']で列方向のデバイス番号を取得。これも実行するデバイスごとに結果が異なる。上記のシチュエーションで行くと、複数サーバーをまたがって同じローカル番号のデバイスの番号を取得できる。

このスクリプトの出力結果は以下のようになる。

Rank 0: mesh['row'] = DeviceMesh('cpu', [0, 1], mesh_dim_names=('row',))
Rank 0: mesh['col'] = DeviceMesh('cpu', [0, 2, 4], mesh_dim_names=('col',))
Rank 1: mesh['row'] = DeviceMesh('cpu', [0, 1], mesh_dim_names=('row',))
Rank 1: mesh['col'] = DeviceMesh('cpu', [1, 3, 5], mesh_dim_names=('col',))
Rank 2: mesh['row'] = DeviceMesh('cpu', [2, 3], mesh_dim_names=('row',))
Rank 2: mesh['col'] = DeviceMesh('cpu', [0, 2, 4], mesh_dim_names=('col',))
Rank 3: mesh['row'] = DeviceMesh('cpu', [2, 3], mesh_dim_names=('row',))
Rank 3: mesh['col'] = DeviceMesh('cpu', [1, 3, 5], mesh_dim_names=('col',))
...

DeviceMeshを使ったデータ並列とモデル並列

モデル並列(テンソル並列; TP)

モジュールごとに分散処理の方式を指定できる。モジュールは線形層(nn.Linear)やモデル(nn.Model)などを含む抽象概念。ここでは上記のdevice meshを使って説明する。

parallelize_module

たとえば、同じサーバー内の2つのGPU間でTPを行いたい場合は、下記のようになる。

    m = torch.nn.Linear(10, 10)  # 線形層を作成
    tp_mesh = mesh["row"]  # TPにしたい次元
    m = parallelize_module(
        module=m,
        device_mesh=tp_mesh,
        parallelize_plan=ColwiseParallel()
    )

ここで

tp_mesh = mesh["col"]

とすると、RuntimeError: ('Found TP device_mesh on the 0 dimension of its parent mesh.', 'Currently we only support intranode TP and TP needs to be the innermost dimension on its parent mesh.')というエラーが出る。device_mesh の一番外側(つまり距離的には近い)デバイスしかTPが使えないとある。親切だなぁ(?)。TP自体はテンソル演算ごとに通信を必要とするため、確かに近いデバイス間でTPを行うのが一般的な戦略である。

ParallelStyle

parallelize_moduleでは並列化の方式を指定する。https://github.com/pytorch/pytorch/blob/cbcf6772231a2c216be707627b6613e8c79a86ed/torch/distributed/tensor/parallel/style.py#L45

次のような方式が使えるようだ。

class ColwiseParallel(ParallelStyle):
class RowwiseParallel(ParallelStyle):
class SequenceParallel(ParallelStyle):
class PrepareModuleInput(ParallelStyle):
class PrepareModuleOutput(ParallelStyle):
class PrepareModuleInputOutput(ParallelStyle):
  1. ColwiseParallel
    • 入力をbroadcast(コピー)
    • 独立にテンソル演算
Input → [ Linear (W_col_0) | Linear (W_col_1) ]
                   ↓                ↓
             Output_0           Output_1
  1. RowwiseParallel
[ Input_0 | Input_1 ] → [ Linear (W_row_0) | Linear (W_row_1) ]
                            ↓                   ↓
                         Partial Output   Partial Output
                               ↓              ↓
                             (Reduce)
                               ↓
                            Final Output

1層目はRowwiseParallelで、2層目はColwiseParallelで、みたいな細かい指定も可能。

    m1 = torch.nn.Linear(10, 10)
    m2 = torch.nn.Linear(10, 10)
    m = torch.nn.Sequential(m1, m2)
    tp_mesh = mesh["row"]
    pm = parallelize_module(module=m, device_mesh=tp_mesh, parallelize_plan={"m1": RowwiseParallel(), "m2": ColwiseParallel()})
    print(pm)

DP (Data Parallel) + TP

データ並列とテンソル並列を両方用いるパターン。といっても、DDPを被せるだけ。注意する点としてモデルのモジュールにTPを設定する前に、モデルをDDPモデルに変換しておく必要がある。

   # Define the model
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    )

    # Data Parallelism setup
    dp_mesh = mesh["col"]
    dp_group = dp_mesh.get_group()
    print(f"Rank {rank}: dp_group = {dp_group}")
    
    # First, apply DDP
    model = DDP(model, process_group=dp_group)

    # Then, apply Tensor Parallelism
    tp_mesh = mesh["row"]
    model.module = parallelize_module(
        module=model.module,
        device_mesh=tp_mesh,
        parallelize_plan={
            "0": ColwiseParallel(),
            "2": RowwiseParallel(),
        },
    )

あとはサンプラーの作成する。この場合はrow (=サーバー内)でTPを行っているため、同じサーバー内のデバイスには同じ入力データが入力されるようにしなければならない。以下のようにDPの個数をnum_replicasで指定し、自身のrankがDPの何番目のレプリカであるかをdp_rankとして入力してやれば良い。

    global_rank = dist.get_rank()
    dp_rank = dist.get_rank(group=dp_group)
    dp_world_size = dist.get_world_size(group=dp_group)
    
    sampler = DistributedSampler(
        dataset,
        num_replicas=dp_world_size,
        rank=dp_rank,
        shuffle=True,
        drop_last=False
    )

DistributedSamplerのコード

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel

class LinearWrapper(torch.nn.Linear):
    ...
class LinearWrapper1(LinearWrapper):
    ...
class LinearWrapper2(LinearWrapper):
    ...
class LinearWrapper3(LinearWrapper):
    ...


class SimpleModel(torch.nn.Module):
    def __init__(self, rank):
        super(SimpleModel, self).__init__()
        self.rank = rank
        self.linear1 = LinearWrapper1(10, 10)
        self.linear2 = LinearWrapper2(10, 10)
        self.linear3 = LinearWrapper3(10, 1)

        # Register backward hooks for all sequential layers
        def backward_hook(module, grad_input, grad_output):
            print(f"[rank{self.rank}] Backward pass - {module.__class__.__name__}")
            return grad_input

        def forward_hook(module, input, output):
            print(f"[rank{self.rank}] Forward pass - {module.__class__.__name__}")
            return output

        self.linear1.register_forward_hook(forward_hook)
        self.linear2.register_forward_hook(forward_hook)
        self.linear3.register_forward_hook(forward_hook)
        self.linear1.register_full_backward_hook(backward_hook)
        self.linear2.register_full_backward_hook(backward_hook)
        self.linear3.register_full_backward_hook(backward_hook)

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        x = torch.relu(self.linear3(x))
        return x


def main():
    dist.init_process_group(backend="gloo")

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    if rank == 0:
        import os
        os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

    print(f"Rank {rank}/{world_size} is running.")

    mesh = DeviceMesh(
        device_type="cpu",
        mesh=[[0, 1], [2, 3], [4, 5]],
        mesh_dim_names=("col", "row")
    )

    import time; time.sleep(rank * 0.5)
    print(f"Rank {rank}: mesh['row'] = {mesh['row']}")
    print(f"Rank {rank}: mesh['col'] = {mesh['col']}")

    # Define the model
    # model = torch.nn.Sequential(
    #     torch.nn.Linear(10, 10),
    #     torch.nn.ReLU(),
    #     torch.nn.Linear(10, 1),
    # )
    model = SimpleModel(rank)

    # Data Parallelism setup
    dp_mesh = mesh["col"]
    dp_group = dp_mesh.get_group()
    print(f"Rank {rank}: dp_group = {dp_group}")
    
    # First, apply DDP
    model = DDP(
        model,
        process_group=dp_group,
        find_unused_parameters=True
    )

    # Then, apply Tensor Parallelism
    tp_mesh = mesh["row"]
    model.module = parallelize_module(
        module=model.module,
        device_mesh=tp_mesh,
        parallelize_plan={
            "linear1": ColwiseParallel(),
            "linear2": RowwiseParallel(),
        },
    )

    # training step
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    get_local_root_rank = lambda mesh: mesh.mesh[0].item()
    if rank == get_local_root_rank(tp_mesh):
        input = torch.randn(3, 10)
    else:
        input = torch.empty(3, 10)
    dist.broadcast(input, src=get_local_root_rank(tp_mesh), group=tp_mesh.get_group())

    for _ in range(10):
        output = model(input)
        loss = output.sum()
        print(f"Rank {rank}: loss = {loss}")
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    # finalizes
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

TORCH_CPP_LOG_LEVEL=INFO TORCH_DISTRIBUTED_DEBUG=DETAILの環境変数を設定して、呼び出された集団通信のログが表示される。