fix AutoUnit FSDP + fp16 (#634)
Summary:
Pull Request resolved: #634

Make sure ShardedGradScaler is used for FSDP + fp16

Reviewed By: dongyuzheng, JKSenthil

Differential Revision: D51489771

fbshipit-source-id: d84e8909cd8b18c0393d47c5e3111035244079fe
galrotem authored and facebook-github-bot committed Nov 21, 2023
1 parent 03b8ec6 commit a780246
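
For context on the fix: fp16 training needs a gradient scaler, and when the module has been wrapped by FSDP that scaler must be a ShardedGradScaler so that unscaling and inf/NaN checks operate on the sharded gradients each rank holds. The snippet below is only a rough sketch of that selection logic, assuming a dtype-based precision argument and the hypothetical name _select_grad_scaler; the real _get_grad_scaler_from_precision in torchtnt/framework/auto_unit.py is not shown in this diff and may differ.

import torch
from torch.cuda.amp import GradScaler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def _select_grad_scaler(precision: torch.dtype, module: torch.nn.Module):
    # Sketch only, not TorchTNT's implementation: only fp16 needs
    # loss/gradient scaling; bf16 and fp32 get no scaler.
    if precision != torch.float16:
        return None
    # An FSDP-wrapped module shards its gradients across ranks, so the
    # sharded scaler is required for correct unscaling and inf/NaN checks.
    if isinstance(module, FSDP):
        return ShardedGradScaler()
    return GradScaler()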
Showing 2 changed files with 31 additions and 1 deletion.
tests/framework/test_auto_unit.py (30 additions, 0 deletions)
@@ -10,6 +10,7 @@
 from unittest.mock import MagicMock, patch

 import torch
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

 from torchtnt.utils.version import is_torch_version_geq_1_13

@@ -79,6 +80,35 @@ def test_app_state_mixin(self) -> None:
         for key in ("module", "optimizer", "lr_scheduler", "grad_scaler"):
             self.assertIn(key, auto_unit.app_state())

+    @unittest.skipUnless(
+        condition=distributed_available, reason="Torch distributed is needed to run"
+    )
+    @unittest.skipUnless(
+        condition=cuda_available, reason="This test needs a GPU host to run."
+    )
+    def test_fsdp_fp16(self) -> None:
+        """
+        Test that FSDP + FP16 uses ShardedGradScaler
+        """
+        spawn_multi_process(
+            2,
+            "nccl",
+            self._test_fsdp_fp16,
+        )
+
+    @staticmethod
+    def _test_fsdp_fp16() -> None:
+        device = init_from_env()
+        my_module = torch.nn.Linear(2, 2)
+        auto_unit_fsdp = DummyAutoUnit(
+            module=my_module,
+            device=device,
+            strategy=FSDPStrategy(),
+            precision="fp16",
+        )
+        tc = unittest.TestCase()
+        tc.assertTrue(isinstance(auto_unit_fsdp.grad_scaler, ShardedGradScaler))
+
     def test_lr_scheduler_step(self) -> None:
         """
         Test that the lr scheduler is stepped every optimizer step when step_lr_interval="step"
torchtnt/framework/auto_unit.py (1 addition, 1 deletion)
@@ -476,7 +476,7 @@ def __init__(
         if self.precision:
             self.grad_scaler = _get_grad_scaler_from_precision(
                 self.precision,
-                module,
+                self.module,
             )

         self.step_lr_interval = step_lr_interval
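
Once the FSDP-wrapped self.module is what gets inspected, the scaler chosen for fp16 is a ShardedGradScaler, which exposes the same scale/step/update interface as torch.cuda.amp.GradScaler but understands FSDP's sharded gradients. The sketch below shows that standard fp16 scaler flow; the function name train_step_fp16 and its signature are illustrative and not part of the TorchTNT AutoUnit API.

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def train_step_fp16(
    module: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: ShardedGradScaler,
    inputs: torch.Tensor,
    targets: torch.Tensor,
) -> None:
    # Illustrative step, assuming module is FSDP-wrapped and on a CUDA device.
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(module(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales (per shard) and skips the step on inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration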
