make torch compile OSS compatible
Summary: Updates a few torch.compile-related calls in TorchTNT that rely on nightly-only PyTorch APIs so that they are OSS friendly.

Differential Revision: D56483268
JKSenthil authored and facebook-github-bot committed Apr 25, 2024
1 parent e2a4ba9 commit 3c91a81
Showing 2 changed files with 21 additions and 11 deletions.
20 changes: 10 additions & 10 deletions tests/utils/test_prepare_module.py
@@ -22,10 +22,12 @@
     TorchCompileParams,
 )
 from torchtnt.utils.test_utils import skip_if_not_distributed
-from torchtnt.utils.version import Version
+from torchtnt.utils.version import is_torch_version_geq
 
 
 class PrepareModelTest(unittest.TestCase):
+    torch_version_geq_2_1_0: bool = is_torch_version_geq("2.1.0")
+
     def test_invalid_fsdp_strategy_str_values(self) -> None:
         from torchtnt.utils.prepare_module import MixedPrecision as _MixedPrecision
 
@@ -143,7 +145,9 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> None:
         """
 
         tc = unittest.TestCase()
-        with patch("torchtnt.utils.version.is_torch_version_geq", return_value=False):
+        with patch(
+            "torchtnt.utils.prepare_module.is_torch_version_geq", return_value=False
+        ):
             with tc.assertRaisesRegex(
                 RuntimeError,
                 "Torch version >= 2.1.0 required",
@@ -155,14 +159,6 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> None:
                     torch_compile_params=TorchCompileParams(backend="inductor"),
                 )
 
-        # no error should be thrown on latest pytorch
-        prepare_module(
-            module=torch.nn.Linear(2, 2),
-            device=init_from_env(),
-            strategy=DDPStrategy(static_graph=True),
-            torch_compile_params=TorchCompileParams(backend="inductor"),
-        )
-
     def test_prepare_module_compile_invalid_backend(self) -> None:
         """
         verify error is thrown on invalid backend
@@ -188,6 +184,10 @@ def test_prepare_module_incompatible_FSDP_torchcompile_params(self) -> None:
                 torch_compile_params=TorchCompileParams(),
             )
 
+    @unittest.skipUnless(
+        torch_version_geq_2_1_0,
+        reason="Must be on torch 2.1.0+ to run test",
+    )
     def test_prepare_module_compile_module_state_dict(self) -> None:
         device = init_from_env()
         my_module = torch.nn.Linear(2, 2, device=device)
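The test-file change above boils down to a version-gating pattern: evaluate is_torch_version_geq("2.1.0") once as a class attribute, then skip compile-dependent tests on older PyTorch with unittest.skipUnless instead of letting them fail. Below is a minimal standalone sketch of the same pattern; the MyCompileTest class and its test body are illustrative examples, not part of TorchTNT's test suite.

import unittest

import torch
from torchtnt.utils.version import is_torch_version_geq


class MyCompileTest(unittest.TestCase):
    # Evaluated once while the class body runs, mirroring the attribute added above.
    torch_geq_2_1: bool = is_torch_version_geq("2.1.0")

    @unittest.skipUnless(torch_geq_2_1, reason="requires torch>=2.1.0 for torch.compile")
    def test_compile_wraps_module(self) -> None:
        module = torch.nn.Linear(2, 2)
        # torch.compile is lazy; actual compilation only happens on the first forward call.
        compiled = torch.compile(module)
        self.assertIsNotNone(compiled)


if __name__ == "__main__":
    unittest.main()

On installations older than 2.1.0 the test is reported as skipped with the given reason rather than erroring out, which is the behavior the commit wants for OSS users on stable releases.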
12 changes: 11 additions & 1 deletion torchtnt/framework/callbacks/torch_compile.py
@@ -8,7 +8,17 @@
 
 import logging
 
-from torch._inductor.codecache import shutdown_compile_workers
+try:
+    from torch._inductor.codecache import shutdown_compile_workers
+except ImportError:
+
+    def shutdown_compile_workers() -> None:
+        logging.warning(
+            "shutdown_compile_workers is not available in your version of PyTorch. \
+            Please use nightly version to enable this feature."
+        )
+
+
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.state import State
 from torchtnt.framework.unit import TTrainUnit
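The change to torch_compile.py is the usual guard for a nightly-only symbol: try to import it, and on ImportError substitute a stub that only logs a warning, so downstream code can call shutdown_compile_workers() unconditionally on any PyTorch build. Here is a short sketch of the same pattern with an illustrative caller; maybe_shutdown_compile_workers is a made-up helper for this example, not TorchTNT API.

import logging

logger = logging.getLogger(__name__)

try:
    # Private inductor API; only present on sufficiently new (nightly) PyTorch builds.
    from torch._inductor.codecache import shutdown_compile_workers
except ImportError:

    def shutdown_compile_workers() -> None:
        # No-op fallback so callers keep working on stable PyTorch releases.
        logger.warning(
            "shutdown_compile_workers is unavailable in this PyTorch build; skipping."
        )


def maybe_shutdown_compile_workers() -> None:
    """Illustrative helper: safe to call regardless of the installed PyTorch version."""
    shutdown_compile_workers()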
