diff --git a/torchtnt/framework/callbacks/torch_compile.py b/torchtnt/framework/callbacks/torch_compile.py index fc5bcf60f3..fb2c744013 100644 --- a/torchtnt/framework/callbacks/torch_compile.py +++ b/torchtnt/framework/callbacks/torch_compile.py @@ -8,7 +8,17 @@ import logging -from torch._inductor.codecache import shutdown_compile_workers +try: + from torch._inductor.codecache import shutdown_compile_workers +except ImportError: + + def shutdown_compile_workers() -> None: + logging.warning( + "shutdown_compile_workers is not available in your version of PyTorch. \ + Please use nightly version to enable this feature." + ) + + from torchtnt.framework.callback import Callback from torchtnt.framework.state import State from torchtnt.framework.unit import TTrainUnit