TorchDDP
1. 概述
torch 的 DDP ( DistributedDataParallel 分布式数据并行) 实现的数据并行,支持单机和多机训练,同时也能与模型并行一起工作。
入门可以参考 torch 官网教程 (opens in a new tab)
2. 基本用例
参考 torch 官网教程 (opens in a new tab) 提供的一个简单的 DDP 用例,可以使用如下示例并创建一个名为 ddp.py 文件。
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic():
dist.init_process_group("nccl")
rank = dist.get_rank()
print(f"Start running basic DDP example on rank {rank}.")
# create model and move it to GPU with id rank
device_id = rank % torch.cuda.device_count()
model = ToyModel().to(device_id)
ddp_model = DDP(model, device_ids=[device_id])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
print(f"outputs:{outputs}")
labels = torch.randn(20, 5).to(device_id)
loss_fn(outputs, labels).backward()
optimizer.step()
dist.destroy_process_group()
if __name__ == "__main__":
demo_basic()
3. 单机多卡
运行 DDP,可以使用 torchrun 或 python -m torch.distributed.launch 两种方式,注意使用 torchrun 需要 torch 版本 ≥ 1.11 。
使用 --nproc_per_node 每个节点运行的进程数量,通常跟运行节点卡数一致,如下使用 2 个 GPU 运行。
torchrun --nproc_per_node=2 ddp.py
# 或
python -m torch.distributed.launch --nproc_per_node=2 ddp.py