diff --git a/tests/common/datasets.py b/tests/common/datasets.py
index 5c80dc63c4..b2132edb35 100644
--- a/tests/common/datasets.py
+++ b/tests/common/datasets.py
@@ -68,7 +68,6 @@ def __init__(
         self.shape: Sequence[int] = shape
         self.num_classes: int = num_classes
         self.device: Optional[torch.device] = device
-        self.generator: Optional[torch.Generator] = None
         self.x: Optional[torch.Tensor] = None
         self.y: Optional[torch.Tensor] = None

@@ -79,7 +78,11 @@ def __getitem__(self, index: int):
         # Note: lazily generate data so it runs after Composer seeds everything, giving the same
         # dataset across multiple calls when using the same seed.
         if self.x is None:
-            self.x = torch.randn(self.size, *self.shape, device=self.device)
+            self.x = torch.randn(
+                self.size,
+                *self.shape,
+                device=self.device,
+            )
             ic(self.x, self.x.device)
         if self.y is None:
             self.y = torch.randint(0, self.num_classes, size=(self.size,), device=self.device)
diff --git a/tests/trainer/test_tp.py b/tests/trainer/test_tp.py
index d6982e178c..642f9cbfcb 100644
--- a/tests/trainer/test_tp.py
+++ b/tests/trainer/test_tp.py
@@ -2,7 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0

 from contextlib import contextmanager
-from typing import Any, Optional, Sequence
+from typing import Any, Optional, Sequence, TypeVar
+
+E = TypeVar('E', bound=BaseException)

 import numpy as np
 import pytest
@@ -105,6 +107,7 @@ def get_trainer(
     """Trainer for a simple model with any parallelism_config."""
     reproducibility.seed_all(seed)

+
     if replication:
         dataset: Dataset = RandomClassificationDatasetReplicated(
             shape=(num_features,),
@@ -246,6 +249,7 @@ def get_tp_fsdp_trainer(


 def forward_pass(trainer):
+    reproducibility.seed_all(42)
     batch = next(iter(trainer.state.train_dataloader))
     output = trainer.state.model.forward(batch)
     return output
@@ -260,7 +264,7 @@ def _replace_state_dict_name(state_dict: dict[str, Any], old_name: str, new_name
     return state_dict


-def _compare_modules(module1, module2, check_grad: bool = False):
+def _compare_modules(module1: dict[str, Any], module2: dict[str, Any], check_grad: bool = False):
     module_type = 'Gradients' if check_grad else 'Parameters'

     for (param1_name, param1), (param2_name, param2) in zip(module1.items(), module2.items()):
@@ -289,9 +293,8 @@ def compare_models(ddp_trainer: Trainer, fsdp_trainer: Trainer, tp_fsdp_trainer:
     # However, calling `tp_fsdp_trainer.state.state_dict()` directly causes a NCCL timeout
     # due to this pytorch bug: https://github.com/pytorch/pytorch/issues/134095/.
     # As a workaround, we use `tp_fsdp_trainer.state.model.named_parameters()` instead.
-    # This issues only exists in `tp_fsdp_trainer.state.state_dict()` when we use TP and
-    # FSDP together; it does not arise when calling `ddp_trainer.state.state_dict()` or
-    # `fsdp_trainer.state.state_dict()`.
+    # This issue only exists with `tp_fsdp_trainer.state.state_dict()` and does not
+    # arise when calling `ddp_trainer.state.state_dict()` or `fsdp_trainer.state.state_dict()`.
     with FSDP.summon_full_params(fsdp_trainer.state.model, with_grads=check_grad):
         with FSDP.summon_full_params(tp_fsdp_trainer.state.model, with_grads=check_grad):
             ddp_params = dict(ddp_trainer.state.model.named_parameters())
@@ -311,6 +314,15 @@ def compare_models(ddp_trainer: Trainer, fsdp_trainer: Trainer, tp_fsdp_trainer:
     _compare_modules(ddp_params, fsdp_params, check_grad=check_grad)


+@contextmanager
+def fail_without_replication(replication: int, exception: type[E], error_error_msg: str):
+    if replication:
+        yield
+    else:
+        with pytest.raises(exception, match=error_error_msg):
+            yield
+
+
 @pytest.mark.gpu
 @world_size(4)
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+')
@@ -525,15 +537,6 @@ def test_tp_hang(world_size: int):
     # print(tp_fsdp_state_dict_6)


-@contextmanager
-def fail_without_replication(replication: int, exception: Exception, error_error_msg: str):
-    if replication:
-        yield
-    else:
-        with pytest.raises(exception, match=error_error_msg):
-            yield
-
-
 @pytest.mark.gpu
 @world_size(4)
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='Requires PyTorch 2.3+')
@@ -546,58 +549,19 @@ def test_tp_forwards_backwards(world_size: int, replication: int = 0):
         - updated weights
     """

-    # # Initialize trainers with DDP, FSDP, TP-FSDP
-    # ddp_trainer = get_ddp_trainer(replication=replication)
-    # fsdp_trainer = get_fsdp_trainer(replication=replication)
-    # tp_fsdp_trainer = get_tp_fsdp_trainer(replication=replication)
-
-    # # Ensure initial model weights are the same
-    # compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer)
-
-    # # Forward pass
-    # ddp_out = forward_pass(ddp_trainer)
-    # fsdp_out = forward_pass(fsdp_trainer)
-    # tp_fsdp_out = forward_pass(tp_fsdp_trainer)
-
-    # # Ensure output of the forward pass is the same
-    # _compare_modules({'': ddp_out}, {'': fsdp_out})
-    # _compare_modules({'': ddp_out}, {'': tp_fsdp_out})
-    # _compare_modules({'': fsdp_out}, {'': tp_fsdp_out})
-
-    # # Compute gradients
-    # torch.sum(ddp_out).backward()
-    # torch.sum(fsdp_out).backward()
-    # torch.sum(tp_fsdp_out).backward()
-
-    # # Ensure the gradients are the same
-    # compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer, check_grad=True)
-
-    # # Update the model weights
-    # ddp_trainer.state.optimizers[0].step()
-    # fsdp_trainer.state.optimizers[0].step()
-    # tp_fsdp_trainer.state.optimizers[0].step()
-
-    # # Ensure the updated weights are the same
-    # compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer)
-
-    # DDP
+    # Initialize trainers with DDP, FSDP, TP-FSDP
     ddp_trainer = get_ddp_trainer(replication=replication)
-    ddp_out = forward_pass(ddp_trainer)
-    torch.sum(ddp_out).backward()
-
-    # FSDP
     fsdp_trainer = get_fsdp_trainer(replication=replication)
-    fsdp_out = forward_pass(fsdp_trainer)
-    torch.sum(fsdp_out).backward()
-
-    # TP-FSDP
     tp_fsdp_trainer = get_tp_fsdp_trainer(replication=replication)
-    tp_fsdp_out = forward_pass(tp_fsdp_trainer)
-    torch.sum(tp_fsdp_out).backward()

     # Ensure initial model weights are the same
     compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer)

+    # Forward pass
+    ddp_out = forward_pass(ddp_trainer)
+    fsdp_out = forward_pass(fsdp_trainer)
+    tp_fsdp_out = forward_pass(tp_fsdp_trainer)
+
     # Ensure output of the forward pass is the same
     with FSDP.summon_full_params(fsdp_trainer.state.model):
         with FSDP.summon_full_params(tp_fsdp_trainer.state.model):
@@ -605,8 +569,13 @@ def test_tp_forwards_backwards(world_size: int, replication: int = 0):
             _compare_modules({'': ddp_out}, {'': tp_fsdp_out})
             _compare_modules({'': fsdp_out}, {'': tp_fsdp_out})

+    # Compute gradients
+    torch.sum(ddp_out).backward()
+    torch.sum(fsdp_out).backward()
+    torch.sum(tp_fsdp_out).backward()
+
     # Ensure the gradients are the same
-    # We expect this test to fail without replication, i.e. replication=0
+    # We expect this test to fail without replication, i.e. when replication=0
     error_error_msg = 'Gradients are not close enough:*'
     with fail_without_replication(replication, AssertionError, error_error_msg):
         compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer, check_grad=True)
@@ -617,7 +586,7 @@ def test_tp_forwards_backwards(world_size: int, replication: int = 0):
     tp_fsdp_trainer.state.optimizers[0].step()

     # Ensure the updated weights are the same
-    # We expect this test to fail without replication, i.e. replication=0
+    # We expect this test to fail without replication, i.e. when replication=0
     error_error_msg = 'Parameters are not close enough:*'
     with fail_without_replication(replication, AssertionError, error_error_msg):
         compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer)
@@ -903,6 +872,6 @@ def test_tp_fsdp_state_dict(world_size: int):
 if __name__ == '__main__':
     import warnings
     warnings.filterwarnings('ignore')
-    # test_tp_forwards_backwards(4, replication=2)
     test_tp_forwards_backwards(4, replication=2)
+    test_tp_forwards_backwards(4, replication=0)
     # test_tp_fsdp_state_dict(4)
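
Aside (not part of the patch): a minimal sketch of the lazy-generation idea described in the datasets.py comment above — tensors are materialized on first access rather than in `__init__`, so they are created after all RNGs have been seeded, and two runs with the same seed see identical data. The `LazyRandomData` class is a hypothetical stand-in for `RandomClassificationDataset`.

```python
import torch


class LazyRandomData:
    """Hypothetical stand-in: generates its data lazily, on first access."""

    def __init__(self, size: int, shape: tuple[int, ...]):
        self.size = size
        self.shape = shape
        self.x = None  # deferred until after the global seed is set

    def __getitem__(self, index: int) -> torch.Tensor:
        if self.x is None:
            # runs after torch.manual_seed, not at construction time
            self.x = torch.randn(self.size, *self.shape)
        return self.x[index]


torch.manual_seed(42)
a = LazyRandomData(8, (4,))[0]
torch.manual_seed(42)
b = LazyRandomData(8, (4,))[0]
assert torch.equal(a, b)  # same seed -> same lazily generated dataset
```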
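Aside (not part of the patch): how the relocated `fail_without_replication` helper reads at the call site — with replication enabled the wrapped comparison must pass, and without replication the named exception is expected and swallowed by `pytest.raises`. The `_check_close` function below is a hypothetical stand-in for `compare_models`.

```python
from contextlib import contextmanager
from typing import TypeVar

import pytest

E = TypeVar('E', bound=BaseException)


@contextmanager
def fail_without_replication(replication: int, exception: type[E], error_error_msg: str):
    if replication:
        yield  # replication enabled: the wrapped assertions should pass
    else:
        # without replication, the wrapped block is expected to raise
        with pytest.raises(exception, match=error_error_msg):
            yield


def _check_close(replication: int):
    # hypothetical stand-in for compare_models: fails when replication=0
    assert replication, 'Parameters are not close enough: replication is off'


# with replication, the comparison passes and nothing is raised
with fail_without_replication(2, AssertionError, 'Parameters are not close enough:*'):
    _check_close(2)

# without replication, the AssertionError is expected and suppressed
with fail_without_replication(0, AssertionError, 'Parameters are not close enough:*'):
    _check_close(0)
```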