diff --git a/tests/framework/test_auto_unit.py b/tests/framework/test_auto_unit.py
index 92b9bffca0..593db6eecb 100644
--- a/tests/framework/test_auto_unit.py
+++ b/tests/framework/test_auto_unit.py
@@ -10,6 +10,7 @@
 from unittest.mock import MagicMock, patch
 
 import torch
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torchtnt.utils.version import is_torch_version_geq_1_13
@@ -79,6 +80,35 @@ def test_app_state_mixin(self) -> None:
         for key in ("module", "optimizer", "lr_scheduler", "grad_scaler"):
             self.assertIn(key, auto_unit.app_state())
 
+    @unittest.skipUnless(
+        condition=distributed_available, reason="Torch distributed is needed to run"
+    )
+    @unittest.skipUnless(
+        condition=cuda_available, reason="This test needs a GPU host to run."
+    )
+    def test_fsdp_fp16(self) -> None:
+        """
+        Test that FSDP + FP16 uses ShardedGradScaler
+        """
+        spawn_multi_process(
+            2,
+            "nccl",
+            self._test_fsdp_fp16,
+        )
+
+    @staticmethod
+    def _test_fsdp_fp16() -> None:
+        device = init_from_env()
+        my_module = torch.nn.Linear(2, 2)
+        auto_unit_fsdp = DummyAutoUnit(
+            module=my_module,
+            device=device,
+            strategy=FSDPStrategy(),
+            precision="fp16",
+        )
+        tc = unittest.TestCase()
+        tc.assertTrue(isinstance(auto_unit_fsdp.grad_scaler, ShardedGradScaler))
+
     def test_lr_scheduler_step(self) -> None:
         """
         Test that the lr scheduler is stepped every optimizer step when step_lr_interval="step"
diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index 9d9a85a164..8e4dd4c6d5 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -476,7 +476,7 @@ def __init__(
         if self.precision:
            self.grad_scaler = _get_grad_scaler_from_precision(
                 self.precision,
-                module,
+                self.module,
             )
 
         self.step_lr_interval = step_lr_interval
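
Note on the `auto_unit.py` change: `_get_grad_scaler_from_precision` needs to see the module *after* any FSDP wrapping (`self.module`); passing the raw `module` argument meant an FSDP run with fp16 got a plain `GradScaler` instead of a `ShardedGradScaler`, which is exactly what the new test asserts. The sketch below is a minimal illustration of that selection logic, not torchtnt's actual implementation; the helper body and the `torch.dtype` precision convention are assumptions.

```python
from typing import Optional

import torch
from torch.cuda.amp import GradScaler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def _get_grad_scaler_from_precision(
    precision: torch.dtype, module: torch.nn.Module
) -> Optional[GradScaler]:
    """Illustrative only: choose a grad scaler from precision and module wrapping."""
    if precision == torch.float16:
        # FSDP shards gradients across ranks, so the sharded scaler is required;
        # this check only works if the module has already been wrapped.
        if isinstance(module, FSDP):
            return ShardedGradScaler()
        return GradScaler()
    # bf16 / fp32 need no loss scaling
    return None
```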