Commit
enforce pytorch >= 2.0
Differential Revision: D56446353
JKSenthil authored and facebook-github-bot committed Apr 25, 2024
1 parent e7b9e64 commit 0f7ab29
Showing 17 changed files with 44 additions and 293 deletions.
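
This commit removes the per-release runtime gating now that torchtnt assumes PyTorch >= 2.0; the diff itself only deletes the old checks. For context, a minimal sketch of what a single import-time guard could look like, built on the version helpers that survive this change (the function name _enforce_min_torch_version is hypothetical and not part of the commit):

    from torchtnt.utils.version import get_torch_version, Version

    def _enforce_min_torch_version(minimum: str = "2.0.0") -> None:
        # Fail fast if the installed torch is older than the supported minimum.
        installed = get_torch_version()
        if installed < Version(minimum):
            raise RuntimeError(
                f"torchtnt requires PyTorch >= {minimum}, found {installed}. "
                "See https://pytorch.org/get-started/locally/ to upgrade."
            )

    _enforce_min_torch_version()

In practice the minimum is more likely to be pinned in the package's install requirements than checked at runtime; this excerpt does not show which mechanism the commit relies on.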
9 changes: 0 additions & 9 deletions tests/framework/callbacks/test_module_summary.py
@@ -19,11 +19,6 @@

from torchtnt.framework.callbacks.module_summary import ModuleSummary
from torchtnt.framework.state import EntryPoint, PhaseState, State
from torchtnt.utils.version import is_torch_version_geq_1_13

MODULE_SUMMARY_FLOPS_AVAILABLE = False
if is_torch_version_geq_1_13():
MODULE_SUMMARY_FLOPS_AVAILABLE = True


class ModuleSummaryTest(unittest.TestCase):
@@ -85,10 +80,6 @@ def forward(self, x):
self.assertTrue("b1" in ms.submodule_summaries)
self.assertTrue("l2" in ms.submodule_summaries)

@unittest.skipUnless(
condition=MODULE_SUMMARY_FLOPS_AVAILABLE,
reason="This test needs PyTorch 1.13 or greater to run.",
)
def test_module_summary_retrieve_module_summaries_module_inputs(self) -> None:
"""
Test ModuleSummary callback in train
11 changes: 2 additions & 9 deletions tests/framework/test_auto_unit.py
@@ -12,15 +12,6 @@
from unittest.mock import MagicMock, patch

import torch
from torchtnt.framework.auto_unit import TrainStepResults
from torchtnt.utils.test_utils import skip_if_not_distributed

from torchtnt.utils.version import is_torch_version_geq_1_13

COMPILE_AVAIL = False
if is_torch_version_geq_1_13():
COMPILE_AVAIL = True
import torch._dynamo

from pyre_extensions import none_throws, ParameterSpecification as ParamSpec

@@ -37,6 +28,7 @@
AutoUnit,
SWALRParams,
SWAParams,
TrainStepResults,
)
from torchtnt.framework.evaluate import evaluate
from torchtnt.framework.predict import predict
@@ -49,6 +41,7 @@
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import DDPStrategy
from torchtnt.utils.progress import Progress
from torchtnt.utils.test_utils import skip_if_not_distributed
from torchtnt.utils.timer import Timer

TParams = ParamSpec("TParams")
19 changes: 4 additions & 15 deletions tests/framework/test_auto_unit_gpu.py
@@ -8,24 +8,16 @@
# pyre-strict

import unittest

from copy import deepcopy
from typing import TypeVar
from unittest.mock import MagicMock, patch

import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu

from torchtnt.utils.version import is_torch_version_geq_1_13

COMPILE_AVAIL = False
if is_torch_version_geq_1_13():
COMPILE_AVAIL = True
import torch._dynamo

from copy import deepcopy

from pyre_extensions import ParameterSpecification as ParamSpec
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torchtnt.framework._test_utils import (
DummyAutoUnit,
generate_random_dataloader,
@@ -40,6 +32,7 @@
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env, seed
from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy, TorchCompileParams
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu

TParams = ParamSpec("TParams")
T = TypeVar("T")
@@ -320,10 +313,6 @@ def test_predict_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
device_type="cuda", dtype=torch.float16, enabled=True
)

@unittest.skipUnless(
condition=COMPILE_AVAIL,
reason="This test needs PyTorch 1.13 or greater to run.",
)
@skip_if_not_gpu
@patch("torch.compile")
def test_compile_predict(self, mock_dynamo: MagicMock) -> None:
8 changes: 0 additions & 8 deletions tests/utils/test_memory_snapshot_profiler.py
@@ -14,17 +14,9 @@
MemorySnapshotParams,
MemorySnapshotProfiler,
)
from torchtnt.utils.version import is_torch_version_geq_2_0


class MemorySnapshotProfilerTest(unittest.TestCase):

torch_version_geq_2_0: bool = is_torch_version_geq_2_0()

@unittest.skipUnless(
condition=torch_version_geq_2_0,
reason="This test needs changes from PyTorch 2.0 to run.",
)
def test_validation(self) -> None:
"""Test parameter validation."""
with tempfile.TemporaryDirectory() as temp_dir:
8 changes: 0 additions & 8 deletions tests/utils/test_memory_snapshot_profiler_gpu.py
@@ -18,18 +18,10 @@
MemorySnapshotProfiler,
)
from torchtnt.utils.test_utils import skip_if_not_gpu
from torchtnt.utils.version import is_torch_version_geq_2_0


class MemorySnapshotProfilerGPUTest(unittest.TestCase):

torch_version_geq_2_0: bool = is_torch_version_geq_2_0()

@skip_if_not_gpu
@unittest.skipUnless(
condition=torch_version_geq_2_0,
reason="This test needs changes from PyTorch 2.0 to run.",
)
def test_stop_step(self) -> None:
"""Test that a memory snapshot is saved when stop_step is reached."""
with tempfile.TemporaryDirectory() as temp_dir:
5 changes: 0 additions & 5 deletions tests/utils/test_oom_gpu.py
@@ -16,15 +16,10 @@
from torchtnt.utils.oom import log_memory_snapshot

from torchtnt.utils.test_utils import skip_if_not_gpu
from torchtnt.utils.version import is_torch_version_geq_2_0


class OomGPUTest(unittest.TestCase):
@skip_if_not_gpu
@unittest.skipUnless(
condition=bool(is_torch_version_geq_2_0()),
reason="This test needs changes from PyTorch 2.0 to run.",
)
def test_log_memory_snapshot(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
# Record history
15 changes: 1 addition & 14 deletions tests/utils/test_prepare_module.py
@@ -22,12 +22,7 @@
TorchCompileParams,
)
from torchtnt.utils.test_utils import skip_if_not_distributed
from torchtnt.utils.version import is_torch_version_geq_1_13, Version

COMPILE_AVAIL = False
if is_torch_version_geq_1_13():
COMPILE_AVAIL = True
import torch._dynamo
from torchtnt.utils.version import Version


class PrepareModelTest(unittest.TestCase):
@@ -170,10 +165,6 @@ def _test_prepare_module_ddp_throws_with_compile_params_and_static_graph() -> No
torch_compile_params=TorchCompileParams(backend="inductor"),
)

@unittest.skipUnless(
condition=COMPILE_AVAIL,
reason="This test needs PyTorch 1.13 or greater to run.",
)
def test_prepare_module_compile_invalid_backend(self) -> None:
"""
verify error is thrown on invalid backend
@@ -199,10 +190,6 @@ def test_prepare_module_incompatible_FSDP_torchcompile_params(self) -> None:
torch_compile_params=TorchCompileParams(),
)

@unittest.skipUnless(
condition=COMPILE_AVAIL,
reason="This test needs PyTorch 1.13 or greater to run.",
)
def test_prepare_module_compile_module_state_dict(self) -> None:
device = init_from_env()
my_module = torch.nn.Linear(2, 2, device=device)
44 changes: 4 additions & 40 deletions tests/utils/test_prepare_module_gpu.py
@@ -7,9 +7,10 @@

# pyre-strict
import unittest
from unittest.mock import patch

import torch

from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -24,15 +25,6 @@
prepare_module,
)
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
from torchtnt.utils.version import is_torch_version_geq_1_13, is_torch_version_geq_2_0

COMPILE_AVAIL = False
if is_torch_version_geq_1_13():
COMPILE_AVAIL = True
import torch._dynamo

if is_torch_version_geq_2_0():
from torch.distributed._composable import fully_shard


class PrepareModelGPUTest(unittest.TestCase):
@@ -81,33 +73,6 @@ def _test_prepare_fsdp() -> None:
tc = unittest.TestCase()
tc.assertTrue(isinstance(fsdp_module, FSDP))

@skip_if_not_distributed
@skip_if_not_gpu
def test_fsdp_pytorch_version(self) -> None:
"""
Test that a RuntimeError is thrown when using FSDP, and PyTorch < v1.12
"""
spawn_multi_process(
2,
"nccl",
self._test_fsdp_pytorch_version,
)

@staticmethod
def _test_fsdp_pytorch_version() -> None:
device = init_from_env()
module = torch.nn.Linear(2, 2).to(device)

tc = unittest.TestCase()
with patch(
"torchtnt.utils.prepare_module.is_torch_version_geq_1_12",
return_value=False,
), tc.assertRaisesRegex(
RuntimeError,
"Please install PyTorch 1.12 or higher to use FSDP: https://pytorch.org/get-started/locally/",
):
_ = prepare_fsdp(module, device, FSDPStrategy())

@skip_if_not_distributed
@unittest.skipUnless(
condition=bool(torch.cuda.device_count() >= 2),
@@ -128,9 +93,8 @@ def _test_is_fsdp_module() -> None:
model = FSDP(torch.nn.Linear(1, 1, device=device))
assert _is_fsdp_module(model)
model = torch.nn.Linear(1, 1, device=device)
if is_torch_version_geq_2_0():
fully_shard(model)
assert _is_fsdp_module(model)
fully_shard(model)
assert _is_fsdp_module(model)

@skip_if_not_distributed
@skip_if_not_gpu
45 changes: 1 addition & 44 deletions tests/utils/test_version.py
@@ -48,48 +48,5 @@ def test_get_torch_version(self) -> None:
self.assertEqual(version.get_torch_version(), Version("1.12.0"))

def test_torch_version_comparators(self) -> None:
with patch.object(torch, "__version__", "1.7.0"):
self.assertFalse(version.is_torch_version_geq_1_8())
self.assertFalse(version.is_torch_version_geq_1_9())
self.assertFalse(version.is_torch_version_geq_1_10())
self.assertFalse(version.is_torch_version_geq_1_11())
self.assertFalse(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "1.8.0"):
self.assertTrue(version.is_torch_version_geq_1_8())
self.assertFalse(version.is_torch_version_geq_1_9())
self.assertFalse(version.is_torch_version_geq_1_10())
self.assertFalse(version.is_torch_version_geq_1_11())
self.assertFalse(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "1.9.0"):
self.assertTrue(version.is_torch_version_geq_1_8())
self.assertTrue(version.is_torch_version_geq_1_9())
self.assertFalse(version.is_torch_version_geq_1_10())
self.assertFalse(version.is_torch_version_geq_1_11())
self.assertFalse(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "1.10.0"):
self.assertTrue(version.is_torch_version_geq_1_8())
self.assertTrue(version.is_torch_version_geq_1_9())
self.assertTrue(version.is_torch_version_geq_1_10())
self.assertFalse(version.is_torch_version_geq_1_11())
self.assertFalse(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "1.11.0"):
self.assertTrue(version.is_torch_version_geq_1_8())
self.assertTrue(version.is_torch_version_geq_1_9())
self.assertTrue(version.is_torch_version_geq_1_10())
self.assertTrue(version.is_torch_version_geq_1_11())
self.assertFalse(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "1.12.0"):
self.assertTrue(version.is_torch_version_geq_1_8())
self.assertTrue(version.is_torch_version_geq_1_9())
self.assertTrue(version.is_torch_version_geq_1_10())
self.assertTrue(version.is_torch_version_geq_1_11())
self.assertTrue(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "2.0.0a0"):
self.assertTrue(version.is_torch_version_ge_1_13_1())
self.assertFalse(version.is_torch_version_geq_2_0())
self.assertFalse(version.is_torch_version_geq_2_1())
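
With the per-release comparators below 2.1 gone, version checks in torchtnt reduce to get_torch_version(), Version, and the still-exported is_torch_version_geq_2_1 (see torchtnt/utils/__init__.py below). A hedged sketch of how a call site might express such a check explicitly (the 2.1-gated feature here is a placeholder, not something this commit adds):

    from torchtnt.utils.version import get_torch_version, is_torch_version_geq_2_1, Version

    # Explicit comparison replaces the removed is_torch_version_geq_1_x helpers.
    if get_torch_version() >= Version("2.1.0"):
        pass  # gate a 2.1-only code path here

    # The >= 2.1 boolean helper remains available for the same purpose.
    if is_torch_version_geq_2_1():
        pass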
11 changes: 0 additions & 11 deletions torchtnt/framework/auto_unit.py
@@ -50,7 +50,6 @@
TorchCompileParams,
)
from torchtnt.utils.swa import AveragedModel
from torchtnt.utils.version import is_torch_version_ge_1_13_1
from typing_extensions import Literal


@@ -166,8 +165,6 @@ def __init__(
torch_compile_params: Optional[TorchCompileParams] = None,
) -> None:
super().__init__()
if torch_compile_params:
_validate_torch_compile_available()

self.device: torch.device = device or init_from_env()
self.precision: Optional[torch.dtype] = (
@@ -879,11 +876,3 @@ def _update_lr_and_swa(self, state: State, number_of_steps_or_epochs: int) -> No
state, f"{self.__class__.__name__}.lr_scheduler_step"
):
self.step_lr_scheduler()


def _validate_torch_compile_available() -> None:
if not is_torch_version_ge_1_13_1():
raise RuntimeError(
"Torch compile support is available only in PyTorch 2.0 or higher. "
"Please install PyTorch 2.0 or higher to continue: https://pytorch.org/get-started/locally/"
)
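
Because PyTorch >= 2.0 is now assumed, torch.compile is always importable and the availability check above becomes dead code. A minimal standalone sketch of the now-unconditional usage (plain PyTorch, not torchtnt-specific):

    import torch

    # torch.compile ships with every supported PyTorch release (>= 2.0),
    # so no feature detection is needed before calling it.
    model = torch.nn.Linear(2, 2)
    compiled_model = torch.compile(model)
    output = compiled_model(torch.randn(4, 2))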
18 changes: 0 additions & 18 deletions torchtnt/utils/__init__.py
@@ -74,15 +74,6 @@
from .version import (
get_python_version,
get_torch_version,
is_torch_version_ge_1_13_1,
is_torch_version_geq_1_10,
is_torch_version_geq_1_11,
is_torch_version_geq_1_12,
is_torch_version_geq_1_13,
is_torch_version_geq_1_14,
is_torch_version_geq_1_8,
is_torch_version_geq_1_9,
is_torch_version_geq_2_0,
is_torch_version_geq_2_1,
is_windows,
)
@@ -153,15 +144,6 @@
"TLRScheduler",
"get_python_version",
"get_torch_version",
"is_torch_version_ge_1_13_1",
"is_torch_version_geq_1_10",
"is_torch_version_geq_1_11",
"is_torch_version_geq_1_12",
"is_torch_version_geq_1_13",
"is_torch_version_geq_1_14",
"is_torch_version_geq_1_8",
"is_torch_version_geq_1_9",
"is_torch_version_geq_2_0",
"is_torch_version_geq_2_1",
"is_windows",
"get_pet_launch_config",