diff --git a/tests/framework/callbacks/test_module_summary.py b/tests/framework/callbacks/test_module_summary.py
index 217705ec57..81417b0fc3 100644
--- a/tests/framework/callbacks/test_module_summary.py
+++ b/tests/framework/callbacks/test_module_summary.py
@@ -19,11 +19,6 @@
 from torchtnt.framework.callbacks.module_summary import ModuleSummary
 from torchtnt.framework.state import EntryPoint, PhaseState, State
-from torchtnt.utils.version import is_torch_version_geq_1_13
-
-MODULE_SUMMARY_FLOPS_AVAILABLE = False
-if is_torch_version_geq_1_13():
-    MODULE_SUMMARY_FLOPS_AVAILABLE = True
 
 
 class ModuleSummaryTest(unittest.TestCase):
@@ -85,10 +80,6 @@ def forward(self, x):
         self.assertTrue("b1" in ms.submodule_summaries)
         self.assertTrue("l2" in ms.submodule_summaries)
 
-    @unittest.skipUnless(
-        condition=MODULE_SUMMARY_FLOPS_AVAILABLE,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
     def test_module_summary_retrieve_module_summaries_module_inputs(self) -> None:
         """
         Test ModuleSummary callback in train
diff --git a/tests/framework/test_auto_unit.py b/tests/framework/test_auto_unit.py
index fe2142fe21..c059ef32c4 100644
--- a/tests/framework/test_auto_unit.py
+++ b/tests/framework/test_auto_unit.py
@@ -12,15 +12,6 @@
 from unittest.mock import MagicMock, patch
 
 import torch
-from torchtnt.framework.auto_unit import TrainStepResults
-from torchtnt.utils.test_utils import skip_if_not_distributed
-
-from torchtnt.utils.version import is_torch_version_geq_1_13
-
-COMPILE_AVAIL = False
-if is_torch_version_geq_1_13():
-    COMPILE_AVAIL = True
-    import torch._dynamo
 
 from pyre_extensions import none_throws, ParameterSpecification as ParamSpec
 
@@ -37,6 +28,7 @@
     AutoUnit,
     SWALRParams,
     SWAParams,
+    TrainStepResults,
 )
 from torchtnt.framework.evaluate import evaluate
 from torchtnt.framework.predict import predict
@@ -49,6 +41,7 @@
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.prepare_module import DDPStrategy
 from torchtnt.utils.progress import Progress
+from torchtnt.utils.test_utils import skip_if_not_distributed
 from torchtnt.utils.timer import Timer
 
 TParams = ParamSpec("TParams")
diff --git a/tests/framework/test_auto_unit_gpu.py b/tests/framework/test_auto_unit_gpu.py
index 6f8bdf710b..8223508614 100644
--- a/tests/framework/test_auto_unit_gpu.py
+++ b/tests/framework/test_auto_unit_gpu.py
@@ -8,24 +8,16 @@
 # pyre-strict
 
 import unittest
+
+from copy import deepcopy
 from typing import TypeVar
 from unittest.mock import MagicMock, patch
 
 import torch
-from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
-
-from torchtnt.utils.version import is_torch_version_geq_1_13
-
-COMPILE_AVAIL = False
-if is_torch_version_geq_1_13():
-    COMPILE_AVAIL = True
-    import torch._dynamo
-
-from copy import deepcopy
 from pyre_extensions import ParameterSpecification as ParamSpec
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
     generate_random_dataloader,
@@ -40,6 +32,7 @@
 from torchtnt.utils.distributed import spawn_multi_process
 from torchtnt.utils.env import init_from_env, seed
 from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy, TorchCompileParams
+from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
 
 TParams = ParamSpec("TParams")
 T = TypeVar("T")
@@ -320,10 +313,6 @@ def test_predict_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
             device_type="cuda", dtype=torch.float16, enabled=True
         )
 
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
     @skip_if_not_gpu
     @patch("torch.compile")
     def test_compile_predict(self, mock_dynamo: MagicMock) -> None:
diff --git a/tests/utils/test_memory_snapshot_profiler.py b/tests/utils/test_memory_snapshot_profiler.py
index 00c14d983c..f277cc9334 100644
--- a/tests/utils/test_memory_snapshot_profiler.py
+++ b/tests/utils/test_memory_snapshot_profiler.py
@@ -14,17 +14,9 @@
     MemorySnapshotParams,
     MemorySnapshotProfiler,
 )
-from torchtnt.utils.version import is_torch_version_geq_2_0
 
 
 class MemorySnapshotProfilerTest(unittest.TestCase):
-
-    torch_version_geq_2_0: bool = is_torch_version_geq_2_0()
-
-    @unittest.skipUnless(
-        condition=torch_version_geq_2_0,
-        reason="This test needs changes from PyTorch 2.0 to run.",
-    )
     def test_validation(self) -> None:
         """Test parameter validation."""
         with tempfile.TemporaryDirectory() as temp_dir:
diff --git a/tests/utils/test_memory_snapshot_profiler_gpu.py b/tests/utils/test_memory_snapshot_profiler_gpu.py
index d4060d6b4d..3087a2e826 100644
--- a/tests/utils/test_memory_snapshot_profiler_gpu.py
+++ b/tests/utils/test_memory_snapshot_profiler_gpu.py
@@ -18,18 +18,10 @@
     MemorySnapshotProfiler,
 )
 from torchtnt.utils.test_utils import skip_if_not_gpu
-from torchtnt.utils.version import is_torch_version_geq_2_0
 
 
 class MemorySnapshotProfilerGPUTest(unittest.TestCase):
-
-    torch_version_geq_2_0: bool = is_torch_version_geq_2_0()
-
     @skip_if_not_gpu
-    @unittest.skipUnless(
-        condition=torch_version_geq_2_0,
-        reason="This test needs changes from PyTorch 2.0 to run.",
-    )
     def test_stop_step(self) -> None:
         """Test that a memory snapshot is saved when stop_step is reached."""
         with tempfile.TemporaryDirectory() as temp_dir:
diff --git a/tests/utils/test_oom_gpu.py b/tests/utils/test_oom_gpu.py
index 0ab67ce19d..850199d762 100644
--- a/tests/utils/test_oom_gpu.py
+++ b/tests/utils/test_oom_gpu.py
@@ -16,15 +16,10 @@
 from torchtnt.utils.oom import log_memory_snapshot
 from torchtnt.utils.test_utils import skip_if_not_gpu
-from torchtnt.utils.version import is_torch_version_geq_2_0
 
 
 class OomGPUTest(unittest.TestCase):
     @skip_if_not_gpu
-    @unittest.skipUnless(
-        condition=bool(is_torch_version_geq_2_0()),
-        reason="This test needs changes from PyTorch 2.0 to run.",
-    )
     def test_log_memory_snapshot(self) -> None:
         with tempfile.TemporaryDirectory() as temp_dir:
             # Record history
diff --git a/tests/utils/test_prepare_module.py b/tests/utils/test_prepare_module.py
index 2924d82f8e..65543feebf 100644
--- a/tests/utils/test_prepare_module.py
+++ b/tests/utils/test_prepare_module.py
@@ -22,12 +22,7 @@
     TorchCompileParams,
 )
 from torchtnt.utils.test_utils import skip_if_not_distributed
-from torchtnt.utils.version import is_torch_version_geq_1_13, Version
-
-COMPILE_AVAIL = False
-if is_torch_version_geq_1_13():
-    COMPILE_AVAIL = True
-    import torch._dynamo
+from torchtnt.utils.version import Version
 
 
 class PrepareModelTest(unittest.TestCase):
@@ -170,10 +165,6 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
             torch_compile_params=TorchCompileParams(backend="inductor"),
         )
 
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
     def test_prepare_module_compile_invalid_backend(self) -> None:
         """
         verify error is thrown on invalid backend
@@ -199,10 +190,6 @@ def test_prepare_module_incompatible_FSDP_torchcompile_params(self) -> None:
                 torch_compile_params=TorchCompileParams(),
             )
 
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
     def test_prepare_module_compile_module_state_dict(self) -> None:
         device = init_from_env()
         my_module = torch.nn.Linear(2, 2, device=device)
diff --git a/tests/utils/test_prepare_module_gpu.py b/tests/utils/test_prepare_module_gpu.py
index 7668341c2e..d583a7085d 100644
--- a/tests/utils/test_prepare_module_gpu.py
+++ b/tests/utils/test_prepare_module_gpu.py
@@ -7,9 +7,10 @@
 # pyre-strict
 
 import unittest
-from unittest.mock import patch
 
 import torch
+
+from torch.distributed._composable import fully_shard
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -24,15 +25,6 @@
     prepare_module,
 )
 from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
-from torchtnt.utils.version import is_torch_version_geq_1_13, is_torch_version_geq_2_0
-
-COMPILE_AVAIL = False
-if is_torch_version_geq_1_13():
-    COMPILE_AVAIL = True
-    import torch._dynamo
-
-if is_torch_version_geq_2_0():
-    from torch.distributed._composable import fully_shard
 
 
 class PrepareModelGPUTest(unittest.TestCase):
@@ -81,33 +73,6 @@ def _test_prepare_fsdp() -> None:
         tc = unittest.TestCase()
         tc.assertTrue(isinstance(fsdp_module, FSDP))
 
-    @skip_if_not_distributed
-    @skip_if_not_gpu
-    def test_fsdp_pytorch_version(self) -> None:
-        """
-        Test that a RuntimeError is thrown when using FSDP, and PyTorch < v1.12
-        """
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_fsdp_pytorch_version,
-        )
-
-    @staticmethod
-    def _test_fsdp_pytorch_version() -> None:
-        device = init_from_env()
-        module = torch.nn.Linear(2, 2).to(device)
-
-        tc = unittest.TestCase()
-        with patch(
-            "torchtnt.utils.prepare_module.is_torch_version_geq_1_12",
-            return_value=False,
-        ), tc.assertRaisesRegex(
-            RuntimeError,
-            "Please install PyTorch 1.12 or higher to use FSDP: https://pytorch.org/get-started/locally/",
-        ):
-            _ = prepare_fsdp(module, device, FSDPStrategy())
-
     @skip_if_not_distributed
     @unittest.skipUnless(
         condition=bool(torch.cuda.device_count() >= 2),
@@ -128,9 +93,8 @@ def _test_is_fsdp_module() -> None:
         model = FSDP(torch.nn.Linear(1, 1, device=device))
         assert _is_fsdp_module(model)
         model = torch.nn.Linear(1, 1, device=device)
-        if is_torch_version_geq_2_0():
-            fully_shard(model)
-            assert _is_fsdp_module(model)
+        fully_shard(model)
+        assert _is_fsdp_module(model)
 
     @skip_if_not_distributed
     @skip_if_not_gpu
diff --git a/tests/utils/test_version.py b/tests/utils/test_version.py
index 0fda7c916b..c30b6d3070 100644
--- a/tests/utils/test_version.py
+++ b/tests/utils/test_version.py
@@ -48,48 +48,5 @@ def test_get_torch_version(self) -> None:
             self.assertEqual(version.get_torch_version(), Version("1.12.0"))
 
     def test_torch_version_comparators(self) -> None:
-        with patch.object(torch, "__version__", "1.7.0"):
-            self.assertFalse(version.is_torch_version_geq_1_8())
-            self.assertFalse(version.is_torch_version_geq_1_9())
-            self.assertFalse(version.is_torch_version_geq_1_10())
-            self.assertFalse(version.is_torch_version_geq_1_11())
-            self.assertFalse(version.is_torch_version_geq_1_12())
-
-        with patch.object(torch, "__version__", "1.8.0"):
-            self.assertTrue(version.is_torch_version_geq_1_8())
-            self.assertFalse(version.is_torch_version_geq_1_9())
-            self.assertFalse(version.is_torch_version_geq_1_10())
-            self.assertFalse(version.is_torch_version_geq_1_11())
-            self.assertFalse(version.is_torch_version_geq_1_12())
-
-        with patch.object(torch, "__version__", "1.9.0"):
-            self.assertTrue(version.is_torch_version_geq_1_8())
-            self.assertTrue(version.is_torch_version_geq_1_9())
-            self.assertFalse(version.is_torch_version_geq_1_10())
-            self.assertFalse(version.is_torch_version_geq_1_11())
-            self.assertFalse(version.is_torch_version_geq_1_12())
-
-        with patch.object(torch, "__version__", "1.10.0"):
-            self.assertTrue(version.is_torch_version_geq_1_8())
-            self.assertTrue(version.is_torch_version_geq_1_9())
-            self.assertTrue(version.is_torch_version_geq_1_10())
-            self.assertFalse(version.is_torch_version_geq_1_11())
-            self.assertFalse(version.is_torch_version_geq_1_12())
-
-        with patch.object(torch, "__version__", "1.11.0"):
-            self.assertTrue(version.is_torch_version_geq_1_8())
-            self.assertTrue(version.is_torch_version_geq_1_9())
-            self.assertTrue(version.is_torch_version_geq_1_10())
-            self.assertTrue(version.is_torch_version_geq_1_11())
-            self.assertFalse(version.is_torch_version_geq_1_12())
-
-        with patch.object(torch, "__version__", "1.12.0"):
-            self.assertTrue(version.is_torch_version_geq_1_8())
-            self.assertTrue(version.is_torch_version_geq_1_9())
-            self.assertTrue(version.is_torch_version_geq_1_10())
-            self.assertTrue(version.is_torch_version_geq_1_11())
-            self.assertTrue(version.is_torch_version_geq_1_12())
-
         with patch.object(torch, "__version__", "2.0.0a0"):
-            self.assertTrue(version.is_torch_version_ge_1_13_1())
-            self.assertFalse(version.is_torch_version_geq_2_0())
+            self.assertFalse(version.is_torch_version_geq_2_1())
diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index 24c8507677..dd061a01de 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -50,7 +50,6 @@
     TorchCompileParams,
 )
 from torchtnt.utils.swa import AveragedModel
-from torchtnt.utils.version import is_torch_version_ge_1_13_1
 
 from typing_extensions import Literal
 
@@ -166,8 +165,6 @@ def __init__(
         torch_compile_params: Optional[TorchCompileParams] = None,
     ) -> None:
         super().__init__()
-        if torch_compile_params:
-            _validate_torch_compile_available()
 
         self.device: torch.device = device or init_from_env()
         self.precision: Optional[torch.dtype] = (
@@ -879,11 +876,3 @@ def _update_lr_and_swa(self, state: State, number_of_steps_or_epochs: int) -> No
             state, f"{self.__class__.__name__}.lr_scheduler_step"
         ):
             self.step_lr_scheduler()
-
-
-def _validate_torch_compile_available() -> None:
-    if not is_torch_version_ge_1_13_1():
-        raise RuntimeError(
-            "Torch compile support is available only in PyTorch 2.0 or higher. "
-            "Please install PyTorch 2.0 or higher to continue: https://pytorch.org/get-started/locally/"
-        )
diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py
index 00d6f74f5a..c1d13d28ef 100644
--- a/torchtnt/utils/__init__.py
+++ b/torchtnt/utils/__init__.py
@@ -74,15 +74,6 @@
 from .version import (
     get_python_version,
     get_torch_version,
-    is_torch_version_ge_1_13_1,
-    is_torch_version_geq_1_10,
-    is_torch_version_geq_1_11,
-    is_torch_version_geq_1_12,
-    is_torch_version_geq_1_13,
-    is_torch_version_geq_1_14,
-    is_torch_version_geq_1_8,
-    is_torch_version_geq_1_9,
-    is_torch_version_geq_2_0,
     is_torch_version_geq_2_1,
     is_windows,
 )
@@ -153,15 +144,6 @@
     "TLRScheduler",
     "get_python_version",
     "get_torch_version",
-    "is_torch_version_ge_1_13_1",
-    "is_torch_version_geq_1_10",
-    "is_torch_version_geq_1_11",
-    "is_torch_version_geq_1_12",
-    "is_torch_version_geq_1_13",
-    "is_torch_version_geq_1_14",
-    "is_torch_version_geq_1_8",
-    "is_torch_version_geq_1_9",
-    "is_torch_version_geq_2_0",
     "is_torch_version_geq_2_1",
     "is_windows",
     "get_pet_launch_config",
diff --git a/torchtnt/utils/device.py b/torchtnt/utils/device.py
index cc9f16d0ef..5f80cd2989 100644
--- a/torchtnt/utils/device.py
+++ b/torchtnt/utils/device.py
@@ -16,7 +16,6 @@
 from typing import Any, Dict, Mapping, TypeVar
 
 import torch
-from torchtnt.utils.version import is_torch_version_geq_1_12
 from typing_extensions import Protocol, runtime_checkable, TypedDict
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -42,11 +41,7 @@ def get_device_from_env() -> torch.device:
         )
         device = torch.device(f"cuda:{local_rank}")
         torch.cuda.set_device(device)
-    elif (
-        is_torch_version_geq_1_12()
-        and torch.backends.mps.is_built()
-        and torch.backends.mps.is_available()
-    ):
+    elif torch.backends.mps.is_built() and torch.backends.mps.is_available():
         device = torch.device("mps")
     else:
         device = torch.device("cpu")
@@ -340,10 +335,7 @@ def set_float32_precision(precision: str = "high") -> None:
     Args:
         precision: The setting to determine which datatypes to use for matrix multiplication and convolution operations.
     """
-    if not (
-        torch.cuda.is_available()  # Not relevant for non-CUDA devices
-        and is_torch_version_geq_1_12()  # API exposed from PyTorch 1.12 onward
-    ):
+    if not (torch.cuda.is_available()):  # Not relevant for non-CUDA devices
         return
     # set precision for matrix multiplications
     torch.set_float32_matmul_precision(precision)
diff --git a/torchtnt/utils/memory_snapshot_profiler.py b/torchtnt/utils/memory_snapshot_profiler.py
index 09d2076fd6..a94890c390 100644
--- a/torchtnt/utils/memory_snapshot_profiler.py
+++ b/torchtnt/utils/memory_snapshot_profiler.py
@@ -14,7 +14,6 @@
 
 import torch
 from torchtnt.utils.oom import attach_oom_observer, log_memory_snapshot
-from torchtnt.utils.version import is_torch_version_geq_2_0
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -133,8 +132,6 @@ def __init__(
         self.step_num: int = 0
         self.is_started: bool = False
 
-        if not is_torch_version_geq_2_0():
-            raise RuntimeError("CUDA memory snapshot requires torch>=2.0")
         if self.params.enable_oom_observer:
             attach_oom_observer(
                 output_dir=output_dir, trace_max_entries=self.params.max_entries
diff --git a/torchtnt/utils/module_summary.py b/torchtnt/utils/module_summary.py
index eacdbe8bbb..84d937a694 100644
--- a/torchtnt/utils/module_summary.py
+++ b/torchtnt/utils/module_summary.py
@@ -30,8 +30,8 @@
 from torch.nn.parameter import UninitializedParameter
 from torch.utils._pytree import PyTree, tree_flatten
 from torch.utils.hooks import RemovableHandle
+from torchtnt.utils.flops import FlopTensorDispatchMode
 
-from torchtnt.utils.version import is_torch_version_geq_1_13
 from typing_extensions import Literal
 
 _ATTRIB_TO_COL_HEADER = {
@@ -244,35 +244,25 @@ def _get_module_flops_and_activation_sizes(
     module_kwargs = module_kwargs or {}
     flops_forward = None
     flops_backward = None
-    if not is_torch_version_geq_1_13():
-        warnings.warn(
-            "Please install PyTorch 1.13 or higher to compute FLOPs: https://pytorch.org/get-started/locally/"
-        )
-        module(*module_args, **module_kwargs)
+
+    with FlopTensorDispatchMode(module) as ftdm:
+        # Count for forward flops (+ compute activation sizes)
+        res = module(*module_args, **module_kwargs)
+
+        # detach activation size hook handles
         for hook_handle in activation_size_handles:
             hook_handle.remove()
-    else:
-        from torchtnt.utils.flops import FlopTensorDispatchMode
-
-        with FlopTensorDispatchMode(module) as ftdm:
-            # Count for forward flops (+ compute activation sizes)
-            res = module(*module_args, **module_kwargs)
-
-            # detach activation size hook handles
-            for hook_handle in activation_size_handles:
-                hook_handle.remove()
-
-            flops_forward = copy.deepcopy(ftdm.flop_counts)
-            if isinstance(res, torch.Tensor):
-                # Count for backward flops
-                ftdm.reset()
-                res.mean().backward()
-                flops_backward = copy.deepcopy(ftdm.flop_counts)
-            else:
-                warnings.warn(
-                    "Backward FLOPs are only computed if module foward returns a tensor."
-                )
+
+        flops_forward = copy.deepcopy(ftdm.flop_counts)
+        if isinstance(res, torch.Tensor):
+            # Count for backward flops
+            ftdm.reset()
+            res.mean().backward()
+            flops_backward = copy.deepcopy(ftdm.flop_counts)
+        else:
+            warnings.warn(
+                "Backward FLOPs are only computed if module foward returns a tensor."
+            )
 
     # remove forward time elapsed handles
     for hook_handle in forward_elapsed_time_handles:
diff --git a/torchtnt/utils/oom.py b/torchtnt/utils/oom.py
index 3a47c61933..9e65a0bed8 100644
--- a/torchtnt/utils/oom.py
+++ b/torchtnt/utils/oom.py
@@ -15,7 +15,6 @@
 import torch
 from torchtnt.utils.distributed import get_global_rank
 from torchtnt.utils.fsspec import get_filesystem
-from torchtnt.utils.version import is_torch_version_geq_2_0
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -90,11 +89,6 @@ def log_memory_snapshot(output_dir: str, file_prefix: Optional[str] = None) -> N
     if not torch.cuda.is_available():
         logger.info("CUDA unavailable. Not logging snapshot")
         return
-    if not is_torch_version_geq_2_0():
-        logger.warning(
-            "CUDA memory snapshot utilities are unavailable. Not logging snapshot"
-        )
-        return
 
     rank = get_global_rank()
     if file_prefix is None:
@@ -134,11 +128,6 @@ def attach_oom_observer(output_dir: str, trace_max_entries: int = 1000000) -> No
     if not torch.cuda.is_available():
         logger.info("CUDA unavailable. Not attaching OOM observer.")
         return
-    if not is_torch_version_geq_2_0():
-        logger.warning(
-            "CUDA memory snapshot utilities are unavailable. Not attaching OOM observer."
-        )
-        return
 
     torch.cuda.memory._record_memory_history(
         enabled="all", max_entries=trace_max_entries
diff --git a/torchtnt/utils/prepare_module.py b/torchtnt/utils/prepare_module.py
index 8f5f1cf529..702bd05e1b 100644
--- a/torchtnt/utils/prepare_module.py
+++ b/torchtnt/utils/prepare_module.py
@@ -13,10 +13,18 @@
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+
+from torch.distributed._composable_state import _get_module_state
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    apply_activation_checkpointing,
+    checkpoint_wrapper,
+    CheckpointImpl,
+)
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     StateDictType as _StateDictType,
 )
+from torch.distributed.fsdp._common_utils import _FSDPState
 from torch.distributed.fsdp.api import OptimStateDictConfig, StateDictConfig
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     BackwardPrefetch as _BackwardPrefetch,
@@ -33,20 +41,7 @@
 )
 
 from torchtnt.utils.rank_zero_log import rank_zero_warn
-from torchtnt.utils.version import (
-    is_torch_version_geq_1_12,
-    is_torch_version_geq_2_0,
-    is_torch_version_geq_2_1,
-)
-
-if is_torch_version_geq_2_0():
-    from torch.distributed._composable_state import _get_module_state
-    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-        apply_activation_checkpointing,
-        checkpoint_wrapper,
-        CheckpointImpl,
-    )
-    from torch.distributed.fsdp._common_utils import _FSDPState
+from torchtnt.utils.version import is_torch_version_geq_2_1
 
 
 @dataclass
@@ -237,10 +232,6 @@ def prepare_fsdp(
         device = torch.device("cuda")
         fsdp_module = prepare_fsdp(module, device, strategy)
     """
-    if not is_torch_version_geq_1_12():
-        raise RuntimeError(
-            "Please install PyTorch 1.12 or higher to use FSDP: https://pytorch.org/get-started/locally/"
-        )
     strategy = strategy if strategy is not None else FSDPStrategy()
 
     # we use __dict__ and not asdict() here because asdict() is recursively applied on nested objects
@@ -291,11 +282,10 @@ def _is_fsdp_module(module: torch.nn.Module) -> bool:
     if isinstance(module, FSDP):
         return True
 
-    if is_torch_version_geq_2_0():
-        # Also check for composable FSDP API
-        maybe_composable_state = _get_module_state(module)
-        if maybe_composable_state is not None:
-            return isinstance(maybe_composable_state, _FSDPState)
+    # Also check for composable FSDP API
+    maybe_composable_state = _get_module_state(module)
+    if maybe_composable_state is not None:
+        return isinstance(maybe_composable_state, _FSDPState)
 
     return False
 
@@ -359,8 +349,6 @@ def prepare_module(
         module = module.to(device)
 
     if activation_checkpoint_params:
-        if not is_torch_version_geq_2_0():
-            raise RuntimeError("Activation checkpointing requires torch>=2.0")
         checkpoint_impl = activation_checkpoint_params.checkpoint_impl
         check_fn = activation_checkpoint_params.check_fn
         auto_wrap_policy = activation_checkpoint_params.auto_wrap_policy
diff --git a/torchtnt/utils/version.py b/torchtnt/utils/version.py
index 4e56cb8abc..c140361826 100644
--- a/torchtnt/utils/version.py
+++ b/torchtnt/utils/version.py
@@ -56,41 +56,5 @@ def get_torch_version() -> Version:
     return pkg_version
 
 
-def is_torch_version_geq_1_8() -> bool:
-    return get_torch_version() >= Version("1.8.0")
-
-
-def is_torch_version_geq_1_9() -> bool:
-    return get_torch_version() >= Version("1.9.0")
-
-
-def is_torch_version_geq_1_10() -> bool:
-    return get_torch_version() >= Version("1.10.0")
-
-
-def is_torch_version_geq_1_11() -> bool:
-    return get_torch_version() >= Version("1.11.0")
-
-
-def is_torch_version_geq_1_12() -> bool:
-    return get_torch_version() >= Version("1.12.0")
-
-
-def is_torch_version_geq_1_13() -> bool:
-    return get_torch_version() >= Version("1.13.0")
-
-
-def is_torch_version_ge_1_13_1() -> bool:
-    return get_torch_version() > Version("1.13.1")
-
-
-def is_torch_version_geq_1_14() -> bool:
-    return get_torch_version() >= Version("1.14.0")
-
-
-def is_torch_version_geq_2_0() -> bool:
-    return get_torch_version() >= Version("2.0.0")
-
-
 def is_torch_version_geq_2_1() -> bool:
     return get_torch_version() >= Version("2.1.0")