make DCPSaver OSS compatible
Summary: This diff makes the DCP saver OSS compatible (with any stable PyTorch version >= 2.0.0).

Differential Revision: D56537398
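
The core of the change is a version-guarded import: try the fsspec-backed storage classes that ship with recent PyTorch builds, and fall back to the stable filesystem classes (available since 2.0.0) when that module is missing. A minimal standalone sketch of the pattern, mirroring the `Reader` / `Writer` aliases used in the diff below:

    import logging

    logger = logging.getLogger(__name__)

    _LATEST_DCP_AVAIL: bool = True
    try:
        # Recent PyTorch builds ship an fsspec-backed storage layer
        # (note the private module path).
        from torch.distributed.checkpoint._fsspec_filesystem import (
            FsspecReader as Reader,
            FsspecWriter as Writer,
        )
    except ModuleNotFoundError:
        # Stable releases >= 2.0.0 always provide the local filesystem
        # classes, so alias those under the same names and record the downgrade.
        logger.warning("FsspecReader / FsspecWriter unavailable; using filesystem classes")
        _LATEST_DCP_AVAIL = False
        from torch.distributed.checkpoint import (
            FileSystemReader as Reader,
            FileSystemWriter as Writer,
        )

Callers then construct `Reader` / `Writer` unconditionally, and test modules can key off `_LATEST_DCP_AVAIL` to skip fsspec-only coverage.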
JKSenthil authored and facebook-github-bot committed Apr 25, 2024
1 parent 37b1070 commit e27fce7
Showing 2 changed files with 49 additions and 14 deletions.
tests/framework/callbacks/test_dcp_saver.py (7 additions, 1 deletion)
@@ -26,12 +26,18 @@
     generate_random_dataloader,
 )
 from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
-from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
+from torchtnt.framework.callbacks.dcp_saver import (
+    _LATEST_DCP_AVAIL,
+    DistributedCheckpointSaver,
+)
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
 from torchtnt.utils.env import seed
 from torchtnt.utils.test_utils import skip_if_not_distributed
 
+if not _LATEST_DCP_AVAIL:
+    raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
+
 
 class DistributedCheckpointSaverTest(unittest.TestCase):
     def test_save_restore(self) -> None:
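
Raising `unittest.SkipTest` at module scope, as the hunk above does, makes the test loader report the whole module as skipped rather than failing on import. A minimal illustration, with a hypothetical `HAS_FEATURE` flag standing in for `_LATEST_DCP_AVAIL`:

    import unittest

    HAS_FEATURE = False  # e.g. an import-availability flag computed at import time

    if not HAS_FEATURE:
        # Raised during import; unittest discovery records the module as skipped.
        raise unittest.SkipTest("feature unavailable; skipping module")


    class FeatureTest(unittest.TestCase):
        def test_something(self) -> None:
            self.assertTrue(True)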
torchtnt/framework/callbacks/dcp_saver.py (42 additions, 13 deletions)
@@ -15,7 +15,6 @@
 import torch.distributed as dist
 from torch.distributed import checkpoint as dcp
 
-from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _prepare_app_state_for_checkpoint,
     _prepare_app_state_for_restore,
@@ -44,6 +43,22 @@
 
 logger: logging.Logger = logging.getLogger(__name__)
 
+_LATEST_DCP_AVAIL: bool = True
+try:
+    from torch.distributed.checkpoint._fsspec_filesystem import (
+        FsspecReader as Reader,
+        FsspecWriter as Writer,
+    )
+except ModuleNotFoundError:
+    logger.warn(
+        "To use FsspecReader / FsspecWriter, please install latest pytorch version"
+    )
+    _LATEST_DCP_AVAIL = False
+    from torch.distributed.checkpoint import (
+        FileSystemReader as Reader,
+        FileSystemWriter as Writer,
+    )
+
 
 class DistributedCheckpointSaver(BaseCheckpointer):
     """
@@ -166,17 +181,24 @@ def _async_save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
         self._prev_snapshot = dcp.async_save(
             state_dict={"app_state": MultiStateful(app_state)},
             process_group=self._process_group,
-            storage_writer=FsspecWriter(checkpoint_id, **self.default_writer_options),
+            storage_writer=Writer(checkpoint_id, **self.default_writer_options),
         )
 
         return True
 
     def _save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
-        dcp.save(
-            state_dict={"app_state": MultiStateful(app_state)},
-            process_group=self._process_group,
-            storage_writer=FsspecWriter(checkpoint_id, **self.default_writer_options),
-        )
+        try:
+            dcp.save(
+                state_dict={"app_state": MultiStateful(app_state)},
+                process_group=self._process_group,
+                storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+            )
+        except AttributeError:
+            dcp.save_state_dict(
+                state_dict={"app_state": MultiStateful(app_state)},
+                process_group=self._process_group,
+                storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+            )
 
         return True
 
@@ -217,7 +239,7 @@ def restore(
                 "Ignoring `knob_options` which was passed to DistributedCheckpointSaver.restore, but is not supported."
             )
 
-        storage_reader = FsspecReader(path)
+        storage_reader = Reader(path)
 
         restore_options = restore_options or RestoreOptions()
         app_state = _prepare_app_state_for_restore(unit, restore_options)
@@ -250,11 +272,18 @@ def restore(
                 if isinstance(optimizer, torch.optim.Optimizer):
                     init_optim_state(optimizer)
 
-        dcp.load(
-            {"app_state": MultiStateful(app_state)},
-            storage_reader=storage_reader,
-            process_group=process_group,
-        )
+        try:
+            dcp.load(
+                {"app_state": MultiStateful(app_state)},
+                storage_reader=storage_reader,
+                process_group=process_group,
+            )
+        except AttributeError:
+            dcp.load_state_dict(
+                {"app_state": MultiStateful(app_state)},
+                storage_reader=storage_reader,
+                process_group=process_group,
+            )
         rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)
 
     def _does_checkpoint_exist(
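
The `try` / `except AttributeError` in `_save` and `restore` covers the second compatibility gap: `dcp.save` and `dcp.load` exist only in recent PyTorch releases, while 2.0.x exposes `save_state_dict` / `load_state_dict`, so a missing module attribute is the signal to retry with the legacy entry points. A minimal standalone sketch of the same fallback (the `save_app_state` helper name is illustrative, not from the diff):

    from typing import Any, Dict

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint import FileSystemWriter


    def save_app_state(app_state: Dict[str, Any], checkpoint_id: str) -> None:
        """Save app_state with whichever DCP entry point this PyTorch provides."""
        writer = FileSystemWriter(checkpoint_id)
        try:
            # dcp.save exists only in recent PyTorch releases.
            dcp.save(state_dict=app_state, storage_writer=writer)
        except AttributeError:
            # Older stable releases (e.g. 2.0.x) expose save_state_dict instead.
            dcp.save_state_dict(state_dict=app_state, storage_writer=writer)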

