Remove support for deprecated DCP APIs in DCPSaver callback (#890)
Summary: Pull Request resolved: #890

Removes the _LATEST_DCP_AVAIL guard and the try/except fallbacks to the deprecated dcp.save_state_dict / dcp.load_state_dict APIs. DistributedCheckpointSaver now imports FsspecReader / FsspecWriter unconditionally and always calls dcp.save / dcp.load, so a PyTorch version that provides these APIs is required.

Reviewed By: anshulverma, JKSenthil

Differential Revision: D61887203

fbshipit-source-id: 17bd899a9b88033feb0285f3a395be6edbf82d5a
diego-urgell authored and facebook-github-bot committed Aug 28, 2024
1 parent 3345ed9 commit d3e85dc
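
For context, a minimal usage sketch of the callback this commit touches. The unit, dataloader, output directory, and knob values below are illustrative rather than taken from this commit, and the keyword names are assumed to match the current DistributedCheckpointSaver and train() signatures; after this change the callback unconditionally uses dcp.save / dcp.load with FsspecReader / FsspecWriter, so a recent PyTorch build is required.

# Sketch only: MyTrainUnit and train_dataloader are hypothetical stand-ins
# for a torchtnt TrainUnit subclass and an iterable of batches.
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
from torchtnt.framework.train import train

unit = MyTrainUnit()

# Write a distributed checkpoint at the end of every epoch
# (save_every_n_epochs is an assumed knob inherited from BaseCheckpointer).
dcp_saver = DistributedCheckpointSaver(
    dirpath="/tmp/checkpoints",  # illustrative output directory
    save_every_n_epochs=1,
)

train(unit, train_dataloader, max_epochs=2, callbacks=[dcp_saver])
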
Showing 2 changed files with 23 additions and 58 deletions.
11 changes: 3 additions & 8 deletions tests/framework/callbacks/test_dcp_saver.py
@@ -7,18 +7,11 @@
 
 # pyre-strict
 
-import unittest
-
-from torchtnt.framework.callbacks.dcp_saver import _LATEST_DCP_AVAIL
-from torchtnt.framework.state import State
-
-if not _LATEST_DCP_AVAIL:
-    raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
-
 import math
 import os
 import shutil
 import tempfile
+import unittest
 from typing import Any, Dict, Iterator, List, Optional
 from unittest import mock
 from unittest.mock import MagicMock, patch
@@ -40,6 +33,8 @@
 )
 from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
 from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
+
+from torchtnt.framework.state import State
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
 from torchtnt.utils.env import seed
70 changes: 20 additions & 50 deletions torchtnt/framework/callbacks/dcp_saver.py
@@ -16,6 +16,11 @@
 import torch.distributed as dist
 from pyre_extensions import none_throws
 from torch.distributed import checkpoint as dcp
+
+from torch.distributed.checkpoint._fsspec_filesystem import (
+    FsspecReader as Reader,
+    FsspecWriter as Writer,
+)
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
     DefaultSavePlanner,
@@ -45,25 +50,8 @@
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 from torchtnt.utils.stateful import MultiStateful, Stateful
 
-
 logger: logging.Logger = logging.getLogger(__name__)
 
-_LATEST_DCP_AVAIL: bool = True
-try:
-    from torch.distributed.checkpoint._fsspec_filesystem import (
-        FsspecReader as Reader,
-        FsspecWriter as Writer,
-    )
-except ModuleNotFoundError:
-    logger.warn(
-        "To use FsspecReader / FsspecWriter, please install latest pytorch version"
-    )
-    _LATEST_DCP_AVAIL = False
-    from torch.distributed.checkpoint import (
-        FileSystemReader as Reader,
-        FileSystemWriter as Writer,
-    )
-
 
 class DistributedCheckpointSaver(BaseCheckpointer):
     """
@@ -248,24 +236,13 @@ def _save(
         if storage_writer is None:
             storage_writer = Writer(checkpoint_id, **self.default_writer_options)
 
-        try:
-            dcp.save(
-                state_dict={"app_state": MultiStateful(app_state)},
-                checkpoint_id=checkpoint_id,
-                process_group=self._process_group,
-                storage_writer=storage_writer,
-                planner=planner,
-            )
-        except AttributeError as ex:
-            logger.warning(
-                f"Unable to save checkpoint (will retry saving using deprecated API). Error: {ex}"
-            )
-            dcp.save_state_dict(
-                state_dict={"app_state": MultiStateful(app_state)},
-                process_group=self._process_group,
-                storage_writer=storage_writer,
-                planner=planner,
-            )
+        dcp.save(
+            state_dict={"app_state": MultiStateful(app_state)},
+            checkpoint_id=checkpoint_id,
+            process_group=self._process_group,
+            storage_writer=storage_writer,
+            planner=planner,
+        )
 
         return True
 
@@ -397,21 +374,14 @@ def restore_with_id(
             if isinstance(optimizer, torch.optim.Optimizer):
                 init_optim_state(optimizer)
 
-        try:
-            dcp.load(
-                {"app_state": MultiStateful(app_state)},
-                checkpoint_id=checkpoint_id,
-                storage_reader=storage_reader,
-                planner=planner,
-                process_group=process_group,
-            )
-        except AttributeError:
-            dcp.load_state_dict(
-                {"app_state": MultiStateful(app_state)},
-                storage_reader=storage_reader,
-                process_group=process_group,
-                planner=planner,
-            )
+        dcp.load(
+            {"app_state": MultiStateful(app_state)},
+            checkpoint_id=checkpoint_id,
+            storage_reader=storage_reader,
+            planner=planner,
+            process_group=process_group,
+        )
+
         rank_zero_info(
             f"Restored snapshot for checkpoint_id: {checkpoint_id}", logger=logger
         )

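For reference, the non-deprecated DCP entry points that the callback now relies on can also be called directly. A minimal sketch, assuming a recent PyTorch with torch.distributed.checkpoint and the fsspec package available; the model and checkpoint path are illustrative:

# Sketch only: saves and reloads a plain module state dict with dcp.save / dcp.load,
# the calls that replace the removed dcp.save_state_dict / dcp.load_state_dict fallbacks.
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter

model = torch.nn.Linear(4, 2)
checkpoint_id = "/tmp/dcp_example"  # illustrative path

dcp.save(
    state_dict={"model": model.state_dict()},
    checkpoint_id=checkpoint_id,
    storage_writer=FsspecWriter(checkpoint_id),
)

# dcp.load restores tensors in place into the provided state_dict.
state_dict = {"model": model.state_dict()}
dcp.load(
    state_dict,
    checkpoint_id=checkpoint_id,
    storage_reader=FsspecReader(checkpoint_id),
)
model.load_state_dict(state_dict["model"])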