Skip to content

Commit

Permalink
add and use generic torch version comparator
Browse files Browse the repository at this point in the history
Differential Revision: D56446382
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Apr 25, 2024
1 parent 4594bf0 commit 5e51dd5
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 17 deletions.
18 changes: 8 additions & 10 deletions tests/utils/test_prepare_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
TorchCompileParams,
)
from torchtnt.utils.test_utils import skip_if_not_distributed
from torchtnt.utils.version import Version
from torchtnt.utils.version import is_torch_version_geq


class PrepareModelTest(unittest.TestCase):
torch_version_geq_2_1_0: bool = is_torch_version_geq("2.1.0")

def test_invalid_fsdp_strategy_str_values(self) -> None:
from torchtnt.utils.prepare_module import MixedPrecision as _MixedPrecision

Expand Down Expand Up @@ -144,7 +146,7 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No

tc = unittest.TestCase()
with patch(
"torchtnt.utils.version.get_torch_version", return_value=Version("2.0.0")
"torchtnt.utils.prepare_module.is_torch_version_geq", return_value=False
):
with tc.assertRaisesRegex(
RuntimeError,
Expand All @@ -157,14 +159,6 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
torch_compile_params=TorchCompileParams(backend="inductor"),
)

# no error should be thrown on latest pytorch
prepare_module(
module=torch.nn.Linear(2, 2),
device=init_from_env(),
strategy=DDPStrategy(static_graph=True),
torch_compile_params=TorchCompileParams(backend="inductor"),
)

def test_prepare_module_compile_invalid_backend(self) -> None:
"""
verify error is thrown on invalid backend
Expand All @@ -190,6 +184,10 @@ def test_prepare_module_incompatible_FSDP_torchcompile_params(self) -> None:
torch_compile_params=TorchCompileParams(),
)

@unittest.skipUnless(
torch_version_geq_2_1_0,
reason="Must be on torch 2.1.0+ to run test",
)
def test_prepare_module_compile_module_state_dict(self) -> None:
device = init_from_env()
my_module = torch.nn.Linear(2, 2, device=device)
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ def test_get_torch_version(self) -> None:

def test_torch_version_comparators(self) -> None:
with patch.object(torch, "__version__", "2.0.0a0"):
self.assertFalse(version.is_torch_version_geq_2_1())
self.assertFalse(version.is_torch_version_geq("2.1.0"))
4 changes: 2 additions & 2 deletions torchtnt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
from .version import (
get_python_version,
get_torch_version,
is_torch_version_geq_2_1,
is_torch_version_geq,
is_windows,
)

Expand Down Expand Up @@ -144,7 +144,7 @@
"TLRScheduler",
"get_python_version",
"get_torch_version",
"is_torch_version_geq_2_1",
"is_torch_version_geq",
"is_windows",
"get_pet_launch_config",
"spawn_multi_process",
Expand Down
4 changes: 2 additions & 2 deletions torchtnt/utils/prepare_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
)

from torchtnt.utils.rank_zero_log import rank_zero_warn
from torchtnt.utils.version import is_torch_version_geq_2_1
from torchtnt.utils.version import is_torch_version_geq


@dataclass
Expand Down Expand Up @@ -318,7 +318,7 @@ def prepare_module(
if (
torch_compile_params
and strategy.static_graph is True
and not is_torch_version_geq_2_1()
and not is_torch_version_geq("2.1.0")
):
raise RuntimeError(
"Torch version >= 2.1.0 required for Torch compile + DDP with static graph"
Expand Down
4 changes: 2 additions & 2 deletions torchtnt/utils/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ def get_torch_version() -> Version:
return pkg_version


def is_torch_version_geq_2_1() -> bool:
return get_torch_version() >= Version("2.1.0")
def is_torch_version_geq(version: str) -> bool:
return get_torch_version() >= Version(version)

0 comments on commit 5e51dd5

Please sign in to comment.