原文:https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
作者:Shen Li
编辑:Joe Zhu
先决条件:
DistributedDataParallel
(DDP)在模块级别实现可在多台计算机上运行的数据并行性。 使用 DDP 的应用应产生多个进程,并为每个进程创建一个 DDP 实例。 DDP 在torch.distributed
包中使用集体通信来同步梯度和缓冲区。 更具体地说,DDP 为model.parameters()
给定的每个参数注册一个 Autograd 挂钩,当在后向传递中计算相应的梯度时,挂钩将触发。 然后,DDP 使用该信号触发跨进程的梯度同步。 有关更多详细信息,请参考 DDP 设计说明。
推荐的使用 DDP 的方法是为每个模型副本生成一个进程,其中一个模型副本可以跨越多个设备。 DDP 进程可以放在同一台计算机上,也可以在多台计算机上,但是 GPU 设备不能在多个进程之间共享。 本教程从一个基本的 DDP 用例开始,然后演示了更高级的用例,包括检查点模型以及将 DDP 与模型并行结合。
注意
本教程中的代码在 8-GPU 服务器上运行,但可以轻松地推广到其他环境。
在深入探讨之前,让我们澄清一下为什么尽管增加了复杂性,但还是考虑使用DistributedDataParallel
而不是DataParallel
:
- 首先,
DataParallel
是单进程,多线程,并且只能在单台机器上运行,而DistributedDataParallel
是多进程,并且适用于单机和多机训练。 即使在单台机器上,DataParallel
通常也比DistributedDataParallel
慢,这是因为跨线程的 GIL 争用,每次迭代复制的模型以及分散输入和收集输出所带来的额外开销。 - 回顾先前的教程,如果模型太大而无法容纳在单个 GPU 上,则必须使用模型并行将其拆分到多个 GPU 中。
DistributedDataParallel
与模型并行一起使用;DataParallel
目前没有。 当 DDP 与模型并行组合时,每个 DDP 进程将并行使用模型,而所有进程共同将并行使用数据。 - 如果您的模型需要跨越多台机器,或者您的用例不适合数据并行性范式,请参阅 RPC API ,以获得更多通用的分布式训练支持。
要创建 DDP 模块,请首先正确设置过程组。 更多细节可以在用 PyTorch 编写分布式应用中找到。
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
if sys.platform == 'win32':
# Distributed package only covers collective communications with Gloo
# backend and FileStore on Windows platform. Set init_method parameter
# in init_process_group to a local file.
# Example init_method="file:///f:/libtmp/some_file"
init_method="file:///{your local file path}"
# initialize the process group
dist.init_process_group(
"gloo",
init_method=init_method,
rank=rank,
world_size=world_size
)
else:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
现在,让我们创建一个玩具模块,将其与 DDP 封装在一起,并提供一些虚拟输入数据。 请注意,由于 DDP 会将模型状态从等级 0 进程广播到 DDP 构造器中的所有其他进程,因此您不必担心不同的 DDP 进程从不同的模型参数初始值开始。
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(10, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 5)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def demo_basic(rank, world_size):
print(f"Running basic DDP example on rank {rank}.")
setup(rank, world_size)
# create model and move it to GPU with id rank
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(rank)
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
def run_demo(demo_fn, world_size):
mp.spawn(demo_fn,
args=(world_size,),
nprocs=world_size,
join=True)
如您所见,DDP 包装了较低级别的分布式通信详细信息,并提供了干净的 API,就好像它是本地模型一样。 梯度同步通信发生在反向传递过程中,并且与反向计算重叠。 当backward()
返回时,param.grad
已经包含同步梯度张量。 对于基本用例,DDP 仅需要几个 LoC 即可设置流程组。 在将 DDP 应用到更高级的用例时,需要注意一些警告。
在 DDP 中,构造器,正向传播和反向传递都是分布式同步点。 预期不同的进程将启动相同数量的同步,并以相同的顺序到达这些同步点,并在大致相同的时间进入每个同步点。 否则,快速流程可能会提早到达,并在等待流浪者时超时。 因此,用户负责平衡流程之间的工作负载分配。 有时,由于例如网络延迟,资源争夺,不可预测的工作量峰值,不可避免地会出现处理速度偏差。 为了避免在这种情况下超时,请在调用init_process_group
时传递足够大的timeout
值。
在训练过程中通常使用torch.save
和torch.load
来检查点模块并从检查点中恢复。 有关更多详细信息,请参见保存和加载模型。 使用 DDP 时,一种优化方法是仅在一个进程中保存模型,然后将其加载到所有进程中,从而减少写开销。 这是正确的,因为所有过程都从相同的参数开始,并且梯度在反向传播中同步,因此优化程序应将参数设置为相同的值。 如果使用此优化,请确保在保存完成之前不要启动所有进程。 此外,在加载模块时,您需要提供适当的map_location
参数,以防止进程进入其他设备。 如果缺少map_location
,则torch.load
将首先将模块加载到 CPU,然后将每个参数复制到保存位置,这将导致同一台机器上的所有进程使用相同的设备集。 有关更高级的故障恢复和弹性支持,请参考这里。
def demo_checkpoint(rank, world_size):
print(f"Running DDP checkpoint example on rank {rank}.")
setup(rank, world_size)
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
if rank == 0:
# All processes should see same parameters as they all start from same
# random parameters and gradients are synchronized in backward passes.
# Therefore, saving it in one process is sufficient.
torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
# Use a barrier() to make sure that process 1 loads the model after process
# 0 saves it.
dist.barrier()
# configure map_location properly
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
ddp_model.load_state_dict(
torch.load(CHECKPOINT_PATH, map_location=map_location))
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(rank)
loss_fn = nn.MSELoss()
loss_fn(outputs, labels).backward()
optimizer.step()
# Not necessary to use a dist.barrier() to guard the file deletion below
# as the AllReduce ops in the backward pass of DDP already served as
# a synchronization.
if rank == 0:
os.remove(CHECKPOINT_PATH)
cleanup()
DDP 还可以与多 GPU 模型一起使用。 当训练具有大量数据的大型模型时,DDP 包装多 GPU 模型特别有用。
class ToyMpModel(nn.Module):
def __init__(self, dev0, dev1):
super(ToyMpModel, self).__init__()
self.dev0 = dev0
self.dev1 = dev1
self.net1 = torch.nn.Linear(10, 10).to(dev0)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(10, 5).to(dev1)
def forward(self, x):
x = x.to(self.dev0)
x = self.relu(self.net1(x))
x = x.to(self.dev1)
return self.net2(x)
将多 GPU 模型传递给 DDP 时,不得设置device_ids
和output_device
。 输入和输出数据将通过应用或模型forward()
方法放置在适当的设备中。
def demo_model_parallel(rank, world_size):
print(f"Running DDP with model parallel example on rank {rank}.")
setup(rank, world_size)
# setup mp_model and devices for this process
dev0 = rank * 2
dev1 = rank * 2 + 1
mp_model = ToyMpModel(dev0, dev1)
ddp_mp_model = DDP(mp_model)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)
optimizer.zero_grad()
# outputs will be on dev1
outputs = ddp_mp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(dev1)
loss_fn(outputs, labels).backward()
optimizer.step()
cleanup()
if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
if n_gpus < 8:
print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
else:
run_demo(demo_basic, 8)
run_demo(demo_checkpoint, 8)
run_demo(demo_model_parallel, 4)