From 65df26f61559ec580b8e455c8833b19358298cc2 Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Tue, 17 Sep 2024 23:00:11 -0700
Subject: [PATCH] [FSDP2] Fixed 2D mismatched grad placements (#136237)

```
CUDA_VISIBLE_DEVICES=2,3,6,7 pytest test/distributed/_composable/test_composability/test_2d_composability.py -k test_train_parity_2d_transformer
```

Differential Revision: [D62964658](https://our.internmc.facebook.com/intern/diff/D62964658)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136237
Approved by: https://github.com/weifengpy
---
 .../test_2d_composability.py                  | 62 +++++++++++++++++++
 .../_composable/fsdp/_fsdp_param.py           |  9 +--
 2 files changed, 67 insertions(+), 4 deletions(-)

diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py
index c8d2ac4d4dc25..83b0f8f2b5ac6 100644
--- a/test/distributed/_composable/test_composability/test_2d_composability.py
+++ b/test/distributed/_composable/test_composability/test_2d_composability.py
@@ -171,6 +171,68 @@ def _test_train_parity_2d_mlp(
                 _optim.step()
             self.assertEqual(losses[0], losses[1])
 
+    @skip_if_lt_x_gpu(2)
+    @skipIfRocm
+    def test_train_parity_2d_transformer(self):
+        torch.manual_seed(42)
+        model_args = ModelArgs(n_layers=3, dropout_p=0.0)
+        model = Transformer(model_args)
+        ref_model = copy.deepcopy(model).cuda()
+        ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
+
+        dp_size, tp_size = self.world_size // 2, 2
+        global_mesh = init_device_mesh(
+            "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
+        )
+        model = Transformer.parallelize(model, global_mesh["tp"], use_seq_parallel=True)
+
+        for layer in model.layers:
+            fully_shard(layer, mesh=global_mesh["dp"])
+        fully_shard(model, mesh=global_mesh["dp"])
+        optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
+
+        for param, ref_param in zip(model.parameters(), ref_model.parameters()):
+            full_param = param.full_tensor()
+            self.assertEqual(full_param, ref_param)
+
+        torch.manual_seed(42 + global_mesh.get_local_rank("dp"))
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        for iter_idx in range(5):
+            ref_loss = ref_model(inp).sum()
+            loss = model(inp).sum()
+            self.assertEqual(ref_loss, loss)
+            ref_loss.backward()
+            loss.backward()
+            for param in ref_model.parameters():
+                if param.grad is not None:
+                    dist.all_reduce(
+                        param.grad,
+                        group=global_mesh.get_group("dp"),
+                        op=dist.ReduceOp.AVG,
+                    )
+
+            # Specially check the TP placement for `pos_embeddings.weight` and
+            # its gradient since the grad naturally has replicate placement,
+            # requiring FSDP to redistribute it to shard placement before FSDP
+            # runs its reduce-scatter
+            self.assertIsInstance(model.pos_embeddings.weight.placements[1], Shard)
+            self.assertIsInstance(model.pos_embeddings.weight.grad.placements[1], Shard)
+            for ref_param, (param_name, param) in zip(
+                ref_model.parameters(), model.named_parameters()
+            ):
+                full_grad = param.grad.full_tensor()
+                ref_grad = ref_param.grad
+                self.assertEqual(ref_grad, full_grad)
+
+            ref_optim.step()
+            optim.step()
+            ref_optim.zero_grad()
+            optim.zero_grad()
+
+            for param, ref_param in zip(model.parameters(), ref_model.parameters()):
+                full_param = param.full_tensor()
+                self.assertEqual(full_param, ref_param)
+
     @skip_if_lt_x_gpu(2)
     @skipIfRocm
     def test_tp_with_fsdp_offloading(self):
diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py
index 59c2ff7589d3a..fef1c865b6b08 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_param.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_param.py
@@ -695,10 +695,11 @@ def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor:
             if isinstance(grad, AsyncCollectiveTensor):
                 grad = grad.wait()
             assert isinstance(grad, DTensor), f"{type(grad)}"
-            if any(pl.is_partial() for pl in grad.placements):
-                placements = [
-                    Replicate() if pl.is_partial() else pl for pl in grad.placements
-                ]
+            placements = self._tp_spec.placements
+            if placements != grad.placements:
+                assert len(self._tp_spec.placements) == len(
+                    grad.placements
+                ), f"{self._tp_spec=} {grad.placements=}"
                 grad = grad.redistribute(placements=placements)
             grad = grad._local_tensor
         return grad
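
As a reading aid, below is a minimal sketch (not the actual PyTorch implementation) of the placement reconciliation that the patched `_get_grad_inner_tensor` now performs: instead of only converting `Partial` placements to `Replicate`, the gradient is redistributed whenever its placements differ from the parameter's TP placements (e.g. replicate to shard for `pos_embeddings.weight`), so the local gradient lines up with the sharded parameter before FSDP's reduce-scatter. The function name `reconcile_grad_placements` and the `param_tp_placements` argument are illustrative stand-ins for `FSDPParam._tp_spec.placements`.

```
# Minimal sketch under the assumptions stated above; names are illustrative.
from typing import Sequence

import torch
from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import Placement


def reconcile_grad_placements(
    grad: DTensor, param_tp_placements: Sequence[Placement]
) -> torch.Tensor:
    # Redistribute whenever the gradient's placements differ from the
    # parameter's TP placements (previously only Partial -> Replicate was
    # handled), then hand back the local shard for FSDP's reduce-scatter.
    if tuple(grad.placements) != tuple(param_tp_placements):
        assert len(grad.placements) == len(param_tp_placements)
        grad = grad.redistribute(placements=list(param_tp_placements))
    return grad.to_local()
```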