Skip to content

Commit

Permalink
skip fsdp checkpoint test for torch 1.13.1 since config classes missing?
Browse files Browse the repository at this point in the history
  • Loading branch information
dblalock committed Aug 19, 2023
1 parent 5082966 commit fbde16b
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions tests/test_lion8b.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,23 @@
import warnings

import numpy as np
import packaging.version as version
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import fsdp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ( # type:ignore .api not in public API
FullOptimStateDictConfig, LocalOptimStateDictConfig,
ShardedOptimStateDictConfig)

if version.parse(torch.__version__) >= version.parse('2.0.1'):
from torch.distributed.fsdp.api import ( # type:ignore .api not in public API
FullOptimStateDictConfig, LocalOptimStateDictConfig,
ShardedOptimStateDictConfig)
else:
from unittest.mock import MagicMock # for pyright so vars aren't None
FullOptimStateDictConfig = MagicMock()
LocalOptimStateDictConfig = MagicMock()
ShardedOptimStateDictConfig = MagicMock()

from llmfoundry.optim import DecoupledLionW_8bit as Lion8bit

Expand Down Expand Up @@ -403,6 +411,8 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
device = 'cuda'
if torch.cuda.device_count() < 2:
pytest.skip(f'This test requires 2+ GPUs.')
if version.parse(torch.__version__) < version.parse('2.0.1'):
pytest.skip(f'This test requires torch 2.0.1 or greater.')

torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp
if not dist.is_initialized():
Expand Down

0 comments on commit fbde16b

Please sign in to comment.