diff --git a/tests/framework/callbacks/test_dcp_saver.py b/tests/framework/callbacks/test_dcp_saver.py
index f74b4b5369..a92c7f68cb 100644
--- a/tests/framework/callbacks/test_dcp_saver.py
+++ b/tests/framework/callbacks/test_dcp_saver.py
@@ -26,12 +26,18 @@
     generate_random_dataloader,
 )
 from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
-from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
+from torchtnt.framework.callbacks.dcp_saver import (
+    _LATEST_DCP_AVAIL,
+    DistributedCheckpointSaver,
+)
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
 from torchtnt.utils.env import seed
 from torchtnt.utils.test_utils import skip_if_not_distributed
 
+if not _LATEST_DCP_AVAIL:
+    raise unittest.SkipTest("Latest PyTorch is required to run DCP tests")
+
 
 class DistributedCheckpointSaverTest(unittest.TestCase):
     def test_save_restore(self) -> None:
diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py
index 3984d6fb86..66cd699f50 100644
--- a/torchtnt/framework/callbacks/dcp_saver.py
+++ b/torchtnt/framework/callbacks/dcp_saver.py
@@ -15,7 +15,6 @@
 
 import torch.distributed as dist
 from torch.distributed import checkpoint as dcp
-from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _prepare_app_state_for_checkpoint,
     _prepare_app_state_for_restore,
@@ -44,6 +43,22 @@
 
 logger: logging.Logger = logging.getLogger(__name__)
 
+_LATEST_DCP_AVAIL: bool = True
+try:
+    from torch.distributed.checkpoint._fsspec_filesystem import (
+        FsspecReader as Reader,
+        FsspecWriter as Writer,
+    )
+except ModuleNotFoundError:
+    logger.warning(
+        "To use FsspecReader / FsspecWriter, please install the latest PyTorch version"
+    )
+    _LATEST_DCP_AVAIL = False
+    from torch.distributed.checkpoint import (
+        FileSystemReader as Reader,
+        FileSystemWriter as Writer,
+    )
+
 
 class DistributedCheckpointSaver(BaseCheckpointer):
     """
@@ -166,17 +181,24 @@ def _async_save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> boo
         self._prev_snapshot = dcp.async_save(
             state_dict={"app_state": MultiStateful(app_state)},
             process_group=self._process_group,
-            storage_writer=FsspecWriter(checkpoint_id, **self.default_writer_options),
+            storage_writer=Writer(checkpoint_id, **self.default_writer_options),
         )
 
         return True
 
     def _save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
-        dcp.save(
-            state_dict={"app_state": MultiStateful(app_state)},
-            process_group=self._process_group,
-            storage_writer=FsspecWriter(checkpoint_id, **self.default_writer_options),
-        )
+        try:
+            dcp.save(
+                state_dict={"app_state": MultiStateful(app_state)},
+                process_group=self._process_group,
+                storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+            )
+        except AttributeError:
+            dcp.save_state_dict(
+                state_dict={"app_state": MultiStateful(app_state)},
+                process_group=self._process_group,
+                storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+            )
 
         return True
 
@@ -217,7 +239,7 @@ def restore(
                 "Ignoring `knob_options` which was passed to DistributedCheckpointSaver.restore, but is not supported."
             )
 
-        storage_reader = FsspecReader(path)
+        storage_reader = Reader(path)
 
         restore_options = restore_options or RestoreOptions()
         app_state = _prepare_app_state_for_restore(unit, restore_options)
@@ -250,11 +272,18 @@ def restore(
                 if isinstance(optimizer, torch.optim.Optimizer):
                     init_optim_state(optimizer)
 
-        dcp.load(
-            {"app_state": MultiStateful(app_state)},
-            storage_reader=storage_reader,
-            process_group=process_group,
-        )
+        try:
+            dcp.load(
+                {"app_state": MultiStateful(app_state)},
+                storage_reader=storage_reader,
+                process_group=process_group,
+            )
+        except AttributeError:
+            dcp.load_state_dict(
+                {"app_state": MultiStateful(app_state)},
+                storage_reader=storage_reader,
+                process_group=process_group,
+            )
         rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)
 
     def _does_checkpoint_exist(
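Note: the save-path fallback above can also be exercised outside of torchtnt. The snippet below is a minimal standalone sketch of the same pattern (the helper name _save_with_fallback is hypothetical and not part of this patch): prefer dcp.save, and fall back to the legacy dcp.save_state_dict entry point when the installed PyTorch predates it.

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemWriter


def _save_with_fallback(state_dict: dict, checkpoint_id: str, process_group=None) -> None:
    # Newer PyTorch exposes dcp.save; older releases only ship
    # dcp.save_state_dict. Looking up the missing attribute raises
    # AttributeError inside the try block, which routes the call to the
    # legacy API with the same arguments.
    writer = FileSystemWriter(checkpoint_id)
    try:
        dcp.save(
            state_dict=state_dict,
            process_group=process_group,
            storage_writer=writer,
        )
    except AttributeError:
        dcp.save_state_dict(
            state_dict=state_dict,
            process_group=process_group,
            storage_writer=writer,
        )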