From 3c91a81b3f991a249f55f283673993990d3fc6ca Mon Sep 17 00:00:00 2001 From: Jason Senthil Date: Wed, 24 Apr 2024 19:17:45 -0700 Subject: [PATCH] make torch compile OSS compatible Summary: Touches few torch.compile calls in TorchTNT which rely on nightlies to be OSS friendly Differential Revision: D56483268 --- tests/utils/test_prepare_module.py | 20 +++++++++---------- torchtnt/framework/callbacks/torch_compile.py | 12 ++++++++++- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/tests/utils/test_prepare_module.py b/tests/utils/test_prepare_module.py index 3a34015ec5..7bdf6867e4 100644 --- a/tests/utils/test_prepare_module.py +++ b/tests/utils/test_prepare_module.py @@ -22,10 +22,12 @@ TorchCompileParams, ) from torchtnt.utils.test_utils import skip_if_not_distributed -from torchtnt.utils.version import Version +from torchtnt.utils.version import is_torch_version_geq class PrepareModelTest(unittest.TestCase): + torch_version_geq_2_1_0: bool = is_torch_version_geq("2.1.0") + def test_invalid_fsdp_strategy_str_values(self) -> None: from torchtnt.utils.prepare_module import MixedPrecision as _MixedPrecision @@ -143,7 +145,9 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No """ tc = unittest.TestCase() - with patch("torchtnt.utils.version.is_torch_version_geq", return_value=False): + with patch( + "torchtnt.utils.prepare_module.is_torch_version_geq", return_value=False + ): with tc.assertRaisesRegex( RuntimeError, "Torch version >= 2.1.0 required", @@ -155,14 +159,6 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No torch_compile_params=TorchCompileParams(backend="inductor"), ) - # no error should be thrown on latest pytorch - prepare_module( - module=torch.nn.Linear(2, 2), - device=init_from_env(), - strategy=DDPStrategy(static_graph=True), - torch_compile_params=TorchCompileParams(backend="inductor"), - ) - def test_prepare_module_compile_invalid_backend(self) -> None: """ verify error is thrown on invalid backend @@ -188,6 +184,10 @@ def test_prepare_module_incompatible_FSDP_torchcompile_params(self) -> None: torch_compile_params=TorchCompileParams(), ) + @unittest.skipUnless( + torch_version_geq_2_1_0, + reason="Must be on torch 2.1.0+ to run test", + ) def test_prepare_module_compile_module_state_dict(self) -> None: device = init_from_env() my_module = torch.nn.Linear(2, 2, device=device) diff --git a/torchtnt/framework/callbacks/torch_compile.py b/torchtnt/framework/callbacks/torch_compile.py index fc5bcf60f3..fb2c744013 100644 --- a/torchtnt/framework/callbacks/torch_compile.py +++ b/torchtnt/framework/callbacks/torch_compile.py @@ -8,7 +8,17 @@ import logging -from torch._inductor.codecache import shutdown_compile_workers +try: + from torch._inductor.codecache import shutdown_compile_workers +except ImportError: + + def shutdown_compile_workers() -> None: + logging.warning( + "shutdown_compile_workers is not available in your version of PyTorch. \ + Please use nightly version to enable this feature." + ) + + from torchtnt.framework.callback import Callback from torchtnt.framework.state import State from torchtnt.framework.unit import TTrainUnit