[FSDP2] Fixed 2D mismatched grad placements (pytorch#136237)
```
CUDA_VISIBLE_DEVICES=2,3,6,7 pytest test/distributed/_composable/test_composability/test_2d_composability.py -k test_train_parity_2d_transformer
```

Differential Revision: [D62964658](https://our.internmc.facebook.com/intern/diff/D62964658)
Pull Request resolved: pytorch#136237
Approved by: https://github.com/weifengpy
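
For context: the "mismatched grad placements" in the title occur when a TP-sharded parameter's gradient comes back from autograd with a different TP placement than the parameter itself (for example `Replicate` instead of `Shard`, as happens for `pos_embeddings.weight` under sequence parallelism), and FSDP2 must reconcile the two before its reduce-scatter. The snippet below is not part of this commit; it is a minimal sketch of that mismatch and the `redistribute` call that resolves it, using a hypothetical CPU/Gloo mesh and made-up shapes.

```python
# Illustrative only -- not from this commit. Run with:
#   torchrun --nproc_per_node=2 placement_mismatch_sketch.py
# The DTensor import path may be `torch.distributed.tensor` on newer builds.
import torch
import torch.distributed as dist
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh


def main() -> None:
    # 1D "tp" mesh on CPU/Gloo so the sketch runs without GPUs.
    tp_mesh = init_device_mesh("cpu", (2,), mesh_dim_names=("tp",))

    # Parameter sharded over TP (e.g. an embedding sharded on dim 0).
    weight = distribute_tensor(torch.randn(16, 8), tp_mesh, [Shard(0)])

    # Suppose autograd produced a gradient that is replicated over TP,
    # as can happen for positional embeddings under sequence parallelism.
    grad = distribute_tensor(torch.randn(16, 8), tp_mesh, [Replicate()])
    assert grad.placements != weight.placements  # the placement mismatch

    # FSDP2's reduce-scatter expects the gradient's TP placement to match
    # the parameter's TP spec, so the gradient is redistributed first.
    grad = grad.redistribute(placements=weight.placements)
    assert grad.placements == weight.placements

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```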
awgu authored and pytorchmergebot committed Sep 19, 2024
1 parent 4ea741d commit 65df26f
Showing 2 changed files with 67 additions and 4 deletions.
62 changes: 62 additions & 0 deletions test/distributed/_composable/test_composability/test_2d_composability.py
```diff
@@ -171,6 +171,68 @@ def _test_train_parity_2d_mlp(
                 _optim.step()
             self.assertEqual(losses[0], losses[1])

+    @skip_if_lt_x_gpu(2)
+    @skipIfRocm
+    def test_train_parity_2d_transformer(self):
+        torch.manual_seed(42)
+        model_args = ModelArgs(n_layers=3, dropout_p=0.0)
+        model = Transformer(model_args)
+        ref_model = copy.deepcopy(model).cuda()
+        ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
+
+        dp_size, tp_size = self.world_size // 2, 2
+        global_mesh = init_device_mesh(
+            "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
+        )
+        model = Transformer.parallelize(model, global_mesh["tp"], use_seq_parallel=True)
+
+        for layer in model.layers:
+            fully_shard(layer, mesh=global_mesh["dp"])
+        fully_shard(model, mesh=global_mesh["dp"])
+        optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
+
+        for param, ref_param in zip(model.parameters(), ref_model.parameters()):
+            full_param = param.full_tensor()
+            self.assertEqual(full_param, ref_param)
+
+        torch.manual_seed(42 + global_mesh.get_local_rank("dp"))
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        for iter_idx in range(5):
+            ref_loss = ref_model(inp).sum()
+            loss = model(inp).sum()
+            self.assertEqual(ref_loss, loss)
+            ref_loss.backward()
+            loss.backward()
+            for param in ref_model.parameters():
+                if param.grad is not None:
+                    dist.all_reduce(
+                        param.grad,
+                        group=global_mesh.get_group("dp"),
+                        op=dist.ReduceOp.AVG,
+                    )
+
+            # Specially check the TP placement for `pos_embeddings.weight` and
+            # its gradient, since the gradient naturally has a replicate
+            # placement, requiring FSDP to redistribute it to a shard placement
+            # before FSDP runs its reduce-scatter
+            self.assertIsInstance(model.pos_embeddings.weight.placements[1], Shard)
+            self.assertIsInstance(model.pos_embeddings.weight.grad.placements[1], Shard)
+            for ref_param, (param_name, param) in zip(
+                ref_model.parameters(), model.named_parameters()
+            ):
+                full_grad = param.grad.full_tensor()
+                ref_grad = ref_param.grad
+                self.assertEqual(ref_param.grad, full_grad)
+
+            ref_optim.step()
+            optim.step()
+            ref_optim.zero_grad()
+            optim.zero_grad()
+
+        for param, ref_param in zip(model.parameters(), ref_model.parameters()):
+            full_param = param.full_tensor()
+            self.assertEqual(full_param, ref_param)
+
     @skip_if_lt_x_gpu(2)
     @skipIfRocm
     def test_tp_with_fsdp_offloading(self):
```
9 changes: 5 additions & 4 deletions torch/distributed/_composable/fsdp/_fsdp_param.py
```diff
@@ -695,10 +695,11 @@ def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor:
         if isinstance(grad, AsyncCollectiveTensor):
             grad = grad.wait()
         assert isinstance(grad, DTensor), f"{type(grad)}"
-        if any(pl.is_partial() for pl in grad.placements):
-            placements = [
-                Replicate() if pl.is_partial() else pl for pl in grad.placements
-            ]
+        placements = self._tp_spec.placements
+        if placements != grad.placements:
+            assert len(self._tp_spec.placements) == len(
+                grad.placements
+            ), f"{self._tp_spec=} {grad.placements=}"
             grad = grad.redistribute(placements=placements)
         grad = grad._local_tensor
         return grad
```
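
Restating the change outside FSDP2's internals: the old code only rewrote `Partial` placements to `Replicate`, so a gradient whose placement was, say, `Replicate` where the parameter's TP spec says `Shard` passed through unchanged and later tripped up the reduce-scatter. The new code redistributes whenever the gradient's placements differ from the TP spec. A minimal sketch of that rule follows, assuming DTensor from `torch.distributed._tensor`; `tp_placements` is a stand-in for `self._tp_spec.placements`, not an FSDP2 API.

```python
from typing import Sequence

import torch
from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import Placement


def grad_local_tensor(grad: DTensor, tp_placements: Sequence[Placement]) -> torch.Tensor:
    # Old behavior (before this commit): only convert Partial -> Replicate,
    # leaving any other mismatch with the parameter's TP spec in place:
    #     placements = [
    #         Replicate() if pl.is_partial() else pl for pl in grad.placements
    #     ]
    # New behavior: match the parameter's TP spec exactly, which also covers
    # a Replicate gradient for a Shard parameter (the mismatch fixed here).
    placements = tuple(tp_placements)
    if placements != tuple(grad.placements):
        assert len(placements) == len(grad.placements), (placements, grad.placements)
        grad = grad.redistribute(placements=placements)
    return grad.to_local()
```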
