make DCPSaver OSS compatible #806

Open · wants to merge 1 commit into base: master
8 changes: 7 additions & 1 deletion tests/framework/callbacks/test_dcp_saver.py
@@ -26,12 +26,18 @@
    generate_random_dataloader,
)
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
-from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
+from torchtnt.framework.callbacks.dcp_saver import (
+    _LATEST_DCP_AVAIL,
+    DistributedCheckpointSaver,
+)
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.env import seed
from torchtnt.utils.test_utils import skip_if_not_distributed

+if not _LATEST_DCP_AVAIL:
+    raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
+

class DistributedCheckpointSaverTest(unittest.TestCase):
    def test_save_restore(self) -> None:
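A note on the test change: raising unittest.SkipTest at module import time makes the runner skip every test in the file, which is how this PR avoids exercising the DCP tests on an older PyTorch. Below is a minimal, self-contained sketch of the same pattern against a generic optional dependency; the module name some_optional_dep and the test class are hypothetical, not part of torchtnt:

import unittest

# Probe for an optional dependency; this flag plays the role of _LATEST_DCP_AVAIL.
_DEP_AVAILABLE = True
try:
    import some_optional_dep  # hypothetical optional package
except ModuleNotFoundError:
    _DEP_AVAILABLE = False

if not _DEP_AVAILABLE:
    # SkipTest raised at import time skips the whole module, not a single test.
    raise unittest.SkipTest("some_optional_dep is required to run these tests")


class OptionalDepTest(unittest.TestCase):
    def test_dependency_is_usable(self) -> None:
        self.assertTrue(hasattr(some_optional_dep, "__name__"))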
55 changes: 42 additions & 13 deletions torchtnt/framework/callbacks/dcp_saver.py
@@ -15,7 +15,6 @@
import torch.distributed as dist
from torch.distributed import checkpoint as dcp

-from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
from torchtnt.framework.callbacks._checkpoint_utils import (
    _prepare_app_state_for_checkpoint,
    _prepare_app_state_for_restore,
@@ -44,6 +43,22 @@

logger: logging.Logger = logging.getLogger(__name__)

+_LATEST_DCP_AVAIL: bool = True
+try:
+    from torch.distributed.checkpoint._fsspec_filesystem import (
+        FsspecReader as Reader,
+        FsspecWriter as Writer,
+    )
+except ModuleNotFoundError:
+    logger.warn(
+        "To use FsspecReader / FsspecWriter, please install latest pytorch version"
+    )
+    _LATEST_DCP_AVAIL = False
+    from torch.distributed.checkpoint import (
+        FileSystemReader as Reader,
+        FileSystemWriter as Writer,
+    )
+

class DistributedCheckpointSaver(BaseCheckpointer):
    """
@@ -166,17 +181,24 @@ def _async_save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
        self._prev_snapshot = dcp.async_save(
            state_dict={"app_state": MultiStateful(app_state)},
            process_group=self._process_group,
-            storage_writer=FsspecWriter(checkpoint_id, **self.default_writer_options),
+            storage_writer=Writer(checkpoint_id, **self.default_writer_options),
        )

        return True

    def _save(self, checkpoint_id: str, app_state: Dict[str, Stateful]) -> bool:
-        dcp.save(
-            state_dict={"app_state": MultiStateful(app_state)},
-            process_group=self._process_group,
-            storage_writer=FsspecWriter(checkpoint_id, **self.default_writer_options),
-        )
+        try:
+            dcp.save(
+                state_dict={"app_state": MultiStateful(app_state)},
+                process_group=self._process_group,
+                storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+            )
+        except AttributeError:
+            dcp.save_state_dict(
+                state_dict={"app_state": MultiStateful(app_state)},
+                process_group=self._process_group,
+                storage_writer=Writer(checkpoint_id, **self.default_writer_options),
+            )

        return True

@@ -217,7 +239,7 @@ def restore(
"Ignoring `knob_options` which was passed to DistributedCheckpointSaver.restore, but is not supported."
)

storage_reader = FsspecReader(path)
storage_reader = Reader(path)

restore_options = restore_options or RestoreOptions()
app_state = _prepare_app_state_for_restore(unit, restore_options)
@@ -250,11 +272,18 @@ def restore(
            if isinstance(optimizer, torch.optim.Optimizer):
                init_optim_state(optimizer)

-        dcp.load(
-            {"app_state": MultiStateful(app_state)},
-            storage_reader=storage_reader,
-            process_group=process_group,
-        )
+        try:
+            dcp.load(
+                {"app_state": MultiStateful(app_state)},
+                storage_reader=storage_reader,
+                process_group=process_group,
+            )
+        except AttributeError:
+            dcp.load_state_dict(
+                {"app_state": MultiStateful(app_state)},
+                storage_reader=storage_reader,
+                process_group=process_group,
+            )
        rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)

    def _does_checkpoint_exist(
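Stepping back, the dcp_saver.py changes combine two common compatibility patterns: an import-time fallback that binds whichever implementation is available to a single alias (Reader / Writer here), and a call-time fallback that catches AttributeError when the newer dcp.save / dcp.load entry points are missing and retries with save_state_dict / load_state_dict. The following is a rough, self-contained sketch of both patterns using json and ujson purely as stand-ins; fastjson, dumps_pretty, and dump_state are illustrative names with no relation to torch or torchtnt:

import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)

# Pattern 1 (import-time fallback): bind whichever module is importable to one
# alias, the same way the PR binds FsspecWriter-or-FileSystemWriter to Writer.
_UJSON_AVAIL: bool = True
try:
    import ujson as fastjson
except ModuleNotFoundError:
    logger.warning("ujson not installed, falling back to the stdlib json module")
    _UJSON_AVAIL = False
    import json as fastjson


# Pattern 2 (call-time fallback): try the newer entry point first, and fall back
# when the attribute does not exist on older versions of the library.
def dump_state(state: Dict[str, Any]) -> str:
    try:
        # dumps_pretty is deliberately nonexistent, so this always falls back;
        # in the PR the equivalent pair is dcp.save vs. dcp.save_state_dict.
        return fastjson.dumps_pretty(state)  # type: ignore[attr-defined]
    except AttributeError:
        return fastjson.dumps(state)


if __name__ == "__main__":
    print(dump_state({"epoch": 1, "step": 10}))

One caveat of the call-time pattern, which the PR shares, is that the except clause also catches an AttributeError raised from inside the newer call itself, not only the missing-attribute case.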