Skip to content

Commit

Permalink
add and use generic torch version comparator
Browse files Browse the repository at this point in the history
Differential Revision: D56446382
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Apr 25, 2024
1 parent 0f7ab29 commit 4b08a29
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 10 deletions.
4 changes: 1 addition & 3 deletions tests/utils/test_prepare_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,7 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
"""

tc = unittest.TestCase()
with patch(
"torchtnt.utils.version.get_torch_version", return_value=Version("2.0.0")
):
with patch("torchtnt.utils.version.is_torch_version_geq", return_value=False):
with tc.assertRaisesRegex(
RuntimeError,
"Torch version >= 2.1.0 required",
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ def test_get_torch_version(self) -> None:

def test_torch_version_comparators(self) -> None:
with patch.object(torch, "__version__", "2.0.0a0"):
self.assertFalse(version.is_torch_version_geq_2_1())
self.assertFalse(version.is_torch_version_geq("2.1.0"))
4 changes: 2 additions & 2 deletions torchtnt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
from .version import (
get_python_version,
get_torch_version,
is_torch_version_geq_2_1,
is_torch_version_geq,
is_windows,
)

Expand Down Expand Up @@ -144,7 +144,7 @@
"TLRScheduler",
"get_python_version",
"get_torch_version",
"is_torch_version_geq_2_1",
"is_torch_version_geq",
"is_windows",
"get_pet_launch_config",
"spawn_multi_process",
Expand Down
4 changes: 2 additions & 2 deletions torchtnt/utils/prepare_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
)

from torchtnt.utils.rank_zero_log import rank_zero_warn
from torchtnt.utils.version import is_torch_version_geq_2_1
from torchtnt.utils.version import is_torch_version_geq


@dataclass
Expand Down Expand Up @@ -318,7 +318,7 @@ def prepare_module(
if (
torch_compile_params
and strategy.static_graph is True
and not is_torch_version_geq_2_1()
and not is_torch_version_geq("2.1.0")
):
raise RuntimeError(
"Torch version >= 2.1.0 required for Torch compile + DDP with static graph"
Expand Down
4 changes: 2 additions & 2 deletions torchtnt/utils/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ def get_torch_version() -> Version:
return pkg_version


def is_torch_version_geq_2_1() -> bool:
return get_torch_version() >= Version("2.1.0")
def is_torch_version_geq(version: str) -> bool:
return get_torch_version() >= Version(version)

0 comments on commit 4b08a29

Please sign in to comment.