运维笔记

PyTorch 分布式训练多 GPU 配置教程 2026:别再踩那些 DDP 的坑了

AI & ML Infrastructure 技术可视化

这年头搞大模型训练,谁还没被分布式折磨过?

上周我们团队刚把一个 7B 参数的模型从单卡迁移到 8 卡 DDP,本以为照着官方文档走一遍就行——结果跑起来直接 OOM,排查了整整两天才发现是 batch_size 没按总 GPU 数缩放。这破事儿我估计不少人都干过。

PyTorch 的 DistributedDataParallel (DDP) 确实是目前最稳的多卡训练方案,但文档写得真的一言难尽。2026 年了,网上那些教程要么是 2023 年的老古董,要么就是只讲玩具代码。今天这篇东西,我准备把我这两年踩过的坑、实战中用到的配置全掏出来,希望能帮后来人省点时间。

为什么 DDP 比 DataParallel 强?

先简单说一下,PyTorch 里有两个多 GPU 方案:

  • DataParallel (DP):单进程,主卡负责收集梯度再广播。主卡显存爆炸,效率感人。
  • DistributedDataParallel (DDP):多进程,每个进程一张卡,各自算梯度然后 all-reduce。没有主卡瓶颈,能跑到接近线性加速。

结论很明确:2026 年如果你还在用 DP,赶紧换 DDP。 别问为什么,问就是 DP 在 4 卡以上基本看不到加速,白花钱。

单机多卡:最基础的配置

我们从一个最常规的场景开始——单机 4 卡。假设你已经装了 PyTorch 2.12(2026 年 5 月刚发布了 2.12 版本,修复了不少 gloo 后端的问题)。

1. 启动方式

有两种主流方式:

方式 A:torchrun(推荐)

torchrun --nproc_per_node=4 train.py

方式 B:手动 spawn

import torch.multiprocessing as mp
mp.spawn(train_fn, args=(args,), nprocs=args.n_gpus)

我个人强烈推荐 torchrun。它帮你处理了环境变量设置、错误重试这些破事,2024 年之后就已经是官方推荐了。

2. 核心代码模板

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    # 初始化进程组
    dist.init_process_group(
        backend='nccl',  # 单机多卡用 nccl
        init_method='env://',
        rank=rank,
        world_size=world_size
    )
    torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def train():
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    
    setup(local_rank, world_size)
    
    model = MyModel().cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])
    
    # 关键:batch_size 要除以 world_size
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    dataloader = DataLoader(train_dataset, batch_size=32//world_size, 
                           sampler=train_sampler)
    
    for epoch in range(epochs):
        train_sampler.set_epoch(epoch)  # 每个 epoch 要 shuffle
        for data, target in dataloader:
            data, target = data.cuda(local_rank), target.cuda(local_rank)
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    
    cleanup()

这里有个我踩过的坑: DistributedSampler 默认是 sequential 的,每个 epoch 必须调用 set_epoch(epoch) 才能打乱数据。忘了的话,你会在每个 epoch 看到完全一样的数据顺序——模型直接过拟合到 epoch 顺序上。

多机多卡:真正的挑战

单机 4 卡只是入门,多机才是地狱。

我们团队去年底搭了一个 4 节点、每节点 8 张 A100 的集群,32 卡跑起来。踩的坑可以写一本书。

启动命令

# 在 master 节点上
torchrun --nnodes=4 --nproc_per_node=8 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="master_ip:29500" \
    --rdzv_id=my_job \
    train.py

# 在其他 3 个节点上(命令完全一样)
torchrun --nnodes=4 --nproc_per_node=8 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="master_ip:29500" \
    --rdzv_id=my_job \
    train.py

网络配置的坑

多机训练最蛋疼的是网络。nccl 后端依赖 InfiniBand 或者 RoCE,如果你的机器之间只有千兆以太网——趁早放弃,或者老老实实调大 NCCL_TIMEOUT

我们踩过的一个真实案例:同一个配置,在 AWS 的 p4d 实例上跑得好好的,换到自家机房就各种 timeout。排查了半天发现是交换机的 MTU 没开巨帧。

2026 年的最佳实践:

  • 节点间用 InfiniBand,延迟 < 2μs
  • 没有 IB 的话,至少 100Gbps RoCE
  • NCCL_IB_DISABLE=1 可以回退到 TCP,但会慢 5-10 倍

性能对比表

配置训练速度 (steps/s)显存占用/卡加速比备注
单卡1032GB1x基准
单机 4 卡 DDP388.2GB3.8x接近线性
单机 8 卡 DDP724.3GB7.2x通信开销开始显现
2 节点 16 卡1254.3GB12.5x跨节点延迟拖累
4 节点 32 卡2204.3GB22x需要梯度压缩

结论: 单机多卡几乎线性加速,跨节点后加速比会打折扣。32 卡以上建议用 torch.distributed.optim.ZeroRedundancyOptimizer 或者 FSDP。

2026 年新特性

PyTorch 2.12 在分布式方面有几个值得关注的点:

  1. NCCL 2.22 集成:支持 NVLink 4.0 的拓扑感知,自动选择最优通信路径
  2. Fault Tolerance 改进torchrun 现在支持自动重启失败的 worker,不用手动 kill 再跑
  3. Compiled DDPtorch.compile 终于能和 DDP 配合了,之前跑编译就会出各种诡异的 bug

常见问题 FAQ

Q: 为什么我的 DDP 训练比单卡还慢? A: 大概率是你 batch_size 设太小了。每张卡上的 batch_size 太小会导致计算时间小于通信时间,GPU 一直在等数据。经验值:每张卡至少 8-16 个样本。

Q: 多机训练一直 timeout 怎么办? A: 先检查网络:ping 延迟 < 1ms,带宽 > 25Gbps。然后用 NCCL_DEBUG=INFO 启动,看是卡在哪个阶段。常见原因是防火墙没开 29500 端口。

Q: 模型太大放不进单卡显存怎么办? A: 2026 年最实际的做法是用 FSDP(Fully Sharded Data Parallel),或者 ZeRO-3。DDP 要求每张卡都能放下完整模型,FSDP 会把参数分片到各卡。

Q: 用 torchrun 启动后进程直接报错退出怎么办? A: 检查环境变量:MASTER_ADDRMASTER_PORTWORLD_SIZERANKtorchrun 会自动设置这些,但如果你在脚本里手动覆盖了,就会出问题。


这套配置我们用了大半年,跑过从 7B 到 70B 的各种模型。说实话,DDP 的坑虽然多,但一旦摸清楚规律,它就是目前最稳的多卡训练方案。不像某些框架(不点名了),每次升级都要重写一遍代码。

最后给个建议:先在小规模(2 卡)上把整个流程跑通,再上生产。别一上来就 32 卡,出了错你连日志都看不过来。