diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py
index 999bf0fa67..7a4cd25982 100644
--- a/tests/test_lion8b.py
+++ b/tests/test_lion8b.py
@@ -6,15 +6,23 @@
 import warnings
 
 import numpy as np
+import packaging.version as version
 import pytest
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import fsdp
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.api import ( # type:ignore .api not in public API
-    FullOptimStateDictConfig, LocalOptimStateDictConfig,
-    ShardedOptimStateDictConfig)
+
+if version.parse(torch.__version__) >= version.parse('2.0.1'):
+    from torch.distributed.fsdp.api import ( # type:ignore .api not in public API
+        FullOptimStateDictConfig, LocalOptimStateDictConfig,
+        ShardedOptimStateDictConfig)
+else:
+    from unittest.mock import MagicMock  # for pyright so vars aren't None
+    FullOptimStateDictConfig = MagicMock()
+    LocalOptimStateDictConfig = MagicMock()
+    ShardedOptimStateDictConfig = MagicMock()
 
 from llmfoundry.optim import DecoupledLionW_8bit as Lion8bit
 
@@ -403,6 +411,8 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
     device = 'cuda'
     if torch.cuda.device_count() < 2:
         pytest.skip(f'This test requires 2+ GPUs.')
+    if version.parse(torch.__version__) < version.parse('2.0.1'):
+        pytest.skip(f'This test requires torch 2.0.1 or greater.')
 
     torch.cuda.set_device(f'cuda:{os.environ["RANK"]}')  # needed for fsdp
     if not dist.is_initialized():
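The diff gates the `torch.distributed.fsdp.api` import on the installed torch version and binds `MagicMock` placeholders otherwise, so the module still imports (and pyright still resolves the names) on older torch; the FSDP save/load test is then skipped on torch < 2.0.1. Below is a minimal standalone sketch of that pattern. The module path, class names, and the `2.0.1` threshold come from the diff; the helper name `supports_fsdp_optim_state_configs` is hypothetical and only illustrates the idea.

```python
# Sketch: version-gated import with a MagicMock fallback (assumptions noted above).
import packaging.version as version
import torch


def supports_fsdp_optim_state_configs() -> bool:
    """True if this torch exposes torch.distributed.fsdp.api (torch >= 2.0.1)."""
    return version.parse(torch.__version__) >= version.parse('2.0.1')


if supports_fsdp_optim_state_configs():
    from torch.distributed.fsdp.api import (  # type: ignore .api not in public API
        FullOptimStateDictConfig, ShardedOptimStateDictConfig)
else:
    # On older torch, bind placeholder objects so module-level references
    # still resolve; tests that actually need them are skipped anyway.
    from unittest.mock import MagicMock
    FullOptimStateDictConfig = MagicMock()
    ShardedOptimStateDictConfig = MagicMock()
```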