diff --git a/tests/trainer/test_tp.py b/tests/trainer/test_tp.py
index 4e75129550..0ba390e178 100644
--- a/tests/trainer/test_tp.py
+++ b/tests/trainer/test_tp.py
@@ -21,7 +21,8 @@
     SimpleModel,
     world_size,
 )
-
+from icecream import install
+install()

 @pytest.mark.gpu
 @world_size(4)
@@ -274,7 +275,7 @@ def forward_pass(trainer):
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
 @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
 def test_tp_forward(world_size: int):
-    """Test that the forward pass with DDP, FSDP, TP-FSDP all output the same tensor."""
+    """Test that DDP, FSDP, and TP-FSDP produce the same forward pass output."""

     # DDP forward pass
     ddp_trainer = get_ddp_trainer()
@@ -297,6 +298,44 @@ def test_tp_forward(world_size: int):
     ), f'Outputs have different values: {ddp_out=} and {tp_fsdp_out=}'


+@pytest.mark.gpu
+@world_size(4)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
+@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
+def test_tp_gradients(world_size: int):
+    """Test that DDP, FSDP, and TP-FSDP produce the same gradients."""
+
+    # DDP gradients
+    ddp_trainer = get_ddp_trainer()
+    ddp_out = forward_pass(ddp_trainer)
+    torch.sum(ddp_out).backward()
+
+    ic('ddp_trainer')
+    for name, param in ddp_trainer.state.model.named_parameters():
+        if param.grad is not None:
+            ic(name, param.shape, param.grad.shape, param, param.grad)
+
+    # FSDP gradients
+    fsdp_trainer = get_fsdp_trainer()
+    fsdp_out = forward_pass(fsdp_trainer)
+    torch.sum(fsdp_out).backward()
+
+    ic('fsdp_trainer')
+    for name, param in fsdp_trainer.state.model.named_parameters():
+        if param.grad is not None:
+            ic(name, param.shape, param.grad.shape, param, param.grad)
+
+    # TP-FSDP gradients
+    tp_fsdp_trainer = get_tp_fsdp_trainer()
+    tp_fsdp_out = forward_pass(tp_fsdp_trainer)
+    torch.sum(tp_fsdp_out).backward()
+
+    ic('tp_fsdp_trainer')
+    for name, param in tp_fsdp_trainer.state.model.named_parameters():
+        if param.grad is not None:
+            ic(name, param.shape, param.grad.shape, param, param.grad)
+
+
 def get_stats(trainer: Trainer) -> dict[str, np.ndarray]:
     logger = trainer.logger.destinations[0]
     stats = {
@@ -311,7 +350,7 @@ def get_stats(trainer: Trainer) -> dict[str, np.ndarray]:
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
 @pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
 def test_tp_fit(world_size: int):
-    """Test that trainer.fit() with DDP, FSDP, TP-FSDP all output the same loss and accuracy."""
+    """Test that DDP, FSDP, and TP-FSDP produce the same loss and accuracy from trainer.fit()."""

     size = 1024  # make enough data to train for multiple steps
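
Note (not part of the diff): as written, `test_tp_gradients` only prints gradients via `ic` and asserts nothing. Below is a minimal sketch of how the collected gradients could instead be compared across trainers. The helpers `_strip_wrappers` and `_dense_grads` are hypothetical; the sketch assumes that TP-sharded gradients are `DTensor`s whose `full_tensor()` gathers the shards (PyTorch 2.3+), and that stripping the DDP/FSDP wrapper prefixes is enough to line parameter names up across the three wrapping strategies.

```python
import torch
from torch.distributed._tensor import DTensor  # DTensor location in PyTorch 2.3


def _strip_wrappers(name: str) -> str:
    # Hypothetical helper: drop DDP/FSDP wrapper prefixes so parameter names
    # from differently wrapped models can be matched against each other.
    return name.replace('module.', '').replace('_fsdp_wrapped_module.', '')


def _dense_grads(trainer) -> dict[str, torch.Tensor]:
    # Hypothetical helper: collect gradients keyed by cleaned parameter name,
    # gathering sharded DTensor gradients into regular tensors first.
    grads = {}
    for name, param in trainer.state.model.named_parameters():
        if param.grad is None:
            continue
        grad = param.grad
        if isinstance(grad, DTensor):
            grad = grad.full_tensor()  # all-gather the shards on every rank
        grads[_strip_wrappers(name)] = grad
    return grads


# After the three backward passes in the test, assert instead of print:
ddp_grads = _dense_grads(ddp_trainer)
for name, grad in _dense_grads(tp_fsdp_trainer).items():
    torch.testing.assert_close(grad, ddp_grads[name])
```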