Remove support for deprecated DCP APIs in DCPSaver callback #890

Closed · wants to merge 3 commits
11 changes: 3 additions & 8 deletions tests/framework/callbacks/test_dcp_saver.py
@@ -7,18 +7,11 @@

# pyre-strict

import unittest

from torchtnt.framework.callbacks.dcp_saver import _LATEST_DCP_AVAIL
from torchtnt.framework.state import State

if not _LATEST_DCP_AVAIL:
raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")

import math
import os
import shutil
import tempfile
import unittest
from typing import Any, Dict, Iterator, List, Optional
from unittest import mock
from unittest.mock import MagicMock, patch
@@ -40,6 +33,8 @@
)
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver

from torchtnt.framework.state import State
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.env import seed
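For reference, the guard deleted above follows a common availability-gate pattern: attempt the optional import at module load time and skip the whole test module if it is missing. A minimal sketch of that pattern, with an illustrative flag name and skip message (not part of this PR):

```python
import unittest

# Probe for the optional fsspec-backed DCP storage layer.
try:
    from torch.distributed.checkpoint._fsspec_filesystem import FsspecWriter  # noqa: F401

    _FSSPEC_DCP_AVAIL = True
except ModuleNotFoundError:
    _FSSPEC_DCP_AVAIL = False

if not _FSSPEC_DCP_AVAIL:
    # Raising SkipTest at import time skips every test in this module.
    raise unittest.SkipTest("A recent PyTorch with fsspec DCP support is required")
```

Since the callback now imports `FsspecReader` / `FsspecWriter` unconditionally, the tests assume a recent PyTorch and the gate (along with the duplicated `import unittest`) is dropped.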
33 changes: 5 additions & 28 deletions torchtnt/framework/callbacks/base_checkpointer.py
@@ -9,7 +9,7 @@
import abc
import logging
from datetime import timedelta
from typing import Any, cast, Dict, Iterable, List, Literal, Optional, Union
from typing import Any, cast, Iterable, List, Literal, Optional, Union

import torch.distributed as dist
from pyre_extensions import none_throws
@@ -21,7 +21,6 @@
from torchtnt.utils.checkpoint import (
BestCheckpointConfig,
CheckpointManager,
CheckpointPath,
get_best_checkpoint_path,
get_latest_checkpoint_path,
MetricData,
@@ -172,7 +171,7 @@ def _generate_checkpoint_and_upkeep(
value=metric_value,
)

checkpoint_path = self._generate_checkpoint_path(
checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
epoch,
step_mapping,
metric_data,
@@ -185,7 +184,9 @@

if hook == "on_train_end":
# 2.1) Make sure that last checkpoint does not already exist
if self._does_checkpoint_exist(checkpoint_path, self._process_group):
if self._checkpoint_manager.does_checkpoint_exist(
checkpoint_path, self._process_group
):
rank_zero_warn(
"Final checkpoint already exists, skipping.", logger=logger
)
@@ -220,30 +221,6 @@ def _generate_checkpoint_and_upkeep(

return True

def _does_checkpoint_exist(
self,
checkpoint_path: CheckpointPath,
process_group: Optional[dist.ProcessGroup] = None,
) -> bool:
# Only keep this function as a hook for downstream checkpointer
return self._checkpoint_manager.does_checkpoint_exist(
checkpoint_path, process_group
)

def _generate_checkpoint_path(
self,
epoch: int,
step_mapping: Union[int, Dict[Phase, int]],
metric_data: Optional[MetricData] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> CheckpointPath:
return self._checkpoint_manager.generate_checkpoint_path(
epoch,
step_mapping,
metric_data,
process_group=process_group,
)

def _get_tracked_metric_value(
self, unit: Union[TTrainUnit, TEvalUnit]
) -> Optional[float]:
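The net effect of this file's changes is that the thin wrapper hooks (`_does_checkpoint_exist`, `_generate_checkpoint_path`) are gone and the base class calls its `CheckpointManager` directly; a subclass that needs to synchronize before those collectives now overrides the single entry point `_generate_checkpoint_and_upkeep` instead. A minimal sketch of that override pattern, using a hypothetical subclass (the remaining abstract members of `BaseCheckpointer` are elided):

```python
from typing import Union

from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
from torchtnt.framework.state import State
from torchtnt.framework.unit import TEvalUnit, TTrainUnit


class MyAsyncCheckpointer(BaseCheckpointer):
    """Hypothetical checkpointer that saves asynchronously in the background."""

    def _wait(self) -> None:
        # Block until any in-flight async save has released the process group
        # (stubbed out here).
        ...

    def _generate_checkpoint_and_upkeep(
        self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str
    ) -> bool:
        # The base implementation issues collectives on the checkpointing
        # process group (existence checks, path generation, cleanup), so drain
        # any pending async save before letting it run.
        self._wait()
        return super()._generate_checkpoint_and_upkeep(state, unit, hook)
```

This is exactly the shape of the `DistributedCheckpointSaver` override added later in this diff.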
174 changes: 69 additions & 105 deletions torchtnt/framework/callbacks/dcp_saver.py
@@ -14,7 +14,13 @@

import torch
import torch.distributed as dist
from pyre_extensions import none_throws
from torch.distributed import checkpoint as dcp

from torch.distributed.checkpoint._fsspec_filesystem import (
FsspecReader as Reader,
FsspecWriter as Writer,
)
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
DefaultSavePlanner,
@@ -39,35 +45,13 @@
TTrainUnit,
)
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.checkpoint import (
BestCheckpointConfig,
CheckpointPath,
MetricData,
Phase,
)
from torchtnt.utils.checkpoint import BestCheckpointConfig
from torchtnt.utils.optimizer import init_optim_state
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
from torchtnt.utils.stateful import MultiStateful, Stateful


logger: logging.Logger = logging.getLogger(__name__)

_LATEST_DCP_AVAIL: bool = True
try:
from torch.distributed.checkpoint._fsspec_filesystem import (
FsspecReader as Reader,
FsspecWriter as Writer,
)
except ModuleNotFoundError:
logger.warn(
"To use FsspecReader / FsspecWriter, please install latest pytorch version"
)
_LATEST_DCP_AVAIL = False
from torch.distributed.checkpoint import (
FileSystemReader as Reader,
FileSystemWriter as Writer,
)


class DistributedCheckpointSaver(BaseCheckpointer):
"""
@@ -165,7 +149,7 @@ def _checkpoint_impl(
checkpoint_id, app_state, planner, storage_writer
)
if curr_snapshot_wait:
self._wait()
self._wait(log_warning=False)
else:
with get_timing_context(state, f"{self.__class__.__name__}.save"):
checkpoint_success = self._save(
@@ -174,9 +158,42 @@ def _checkpoint_impl(

return checkpoint_success

def _wait(self) -> None:
if self._prev_snapshot is not None:
self._prev_snapshot.result()
def _wait(self, log_warning: bool = True) -> None:
"""
If the previous async checkpoint is still running, wait for it to finish before continuing. Otherwise,
distributed collectives that use the checkpointing process group will result in a stuck job. This also
        computes and logs the time spent waiting on the previous checkpoint to finish, and a togglable warning
for the user to modify checkpointing frequency.

        If the previous checkpoint has already finished, this is a no-op.

Args:
log_warning: Toggle for logging a warning to the user to modify checkpointing frequency. Sometimes
this is not up to the user (e.g. on_exception, on_train_end).
"""
if self._prev_snapshot is None:
return

if self._prev_snapshot.done():
none_throws(self._prev_snapshot).result()
return

if log_warning:
rank_zero_warn(
(
"Waiting on previous checkpoint to finish... Consider modifying checkpointing "
f"frequency if this is an issue. Current value (current {self._save_every_n_train_steps})"
),
logger=logger,
)

t0 = time.monotonic()
none_throws(self._prev_snapshot).result()

rank_zero_warn(
f"Waiting on previous checkpoint for {time.monotonic()-t0:.3f} seconds",
logger=logger,
)

def _async_save(
self,
@@ -192,23 +209,8 @@ def _async_save(
if storage_writer is None:
storage_writer = Writer(checkpoint_id, **self.default_writer_options)

if self._prev_snapshot is not None:
if not self._prev_snapshot.done():
rank_zero_warn(
(
"Waiting on previous checkpoint to finish... Consider modifying checkpointing "
f"frequency if this is an issue. Current value (current {self._save_every_n_train_steps})"
),
logger=logger,
)
t0 = time.monotonic()
self._wait()
rank_zero_warn(
f"Waiting on previous checkpoint for {time.monotonic()-t0:.3f} seconds",
logger=logger,
)
else:
self._wait()
# Redundant check for safety
self._wait(log_warning=True)

self._prev_snapshot = dcp.async_save(
state_dict={"app_state": MultiStateful(app_state)},
@@ -234,24 +236,13 @@ def _save(
if storage_writer is None:
storage_writer = Writer(checkpoint_id, **self.default_writer_options)

try:
dcp.save(
state_dict={"app_state": MultiStateful(app_state)},
checkpoint_id=checkpoint_id,
process_group=self._process_group,
storage_writer=storage_writer,
planner=planner,
)
except AttributeError as ex:
logger.warning(
f"Unable to save checkpoint (will retry saving using deprecated API). Error: {ex}"
)
dcp.save_state_dict(
state_dict={"app_state": MultiStateful(app_state)},
process_group=self._process_group,
storage_writer=storage_writer,
planner=planner,
)
dcp.save(
state_dict={"app_state": MultiStateful(app_state)},
checkpoint_id=checkpoint_id,
process_group=self._process_group,
storage_writer=storage_writer,
planner=planner,
)

return True

@@ -261,7 +252,8 @@ def on_exception(
unit: Union[TTrainUnit, TEvalUnit, TPredictUnit],
exc: BaseException,
) -> None:
self._wait()
rank_zero_info("Ensuring previous async checkpoint finished before exiting.")
self._wait(log_warning=False)

@staticmethod
def restore(
@@ -382,55 +374,27 @@ def restore_with_id(
if isinstance(optimizer, torch.optim.Optimizer):
init_optim_state(optimizer)

try:
dcp.load(
{"app_state": MultiStateful(app_state)},
checkpoint_id=checkpoint_id,
storage_reader=storage_reader,
planner=planner,
process_group=process_group,
)
except AttributeError:
dcp.load_state_dict(
{"app_state": MultiStateful(app_state)},
storage_reader=storage_reader,
process_group=process_group,
planner=planner,
)
dcp.load(
{"app_state": MultiStateful(app_state)},
checkpoint_id=checkpoint_id,
storage_reader=storage_reader,
planner=planner,
process_group=process_group,
)

rank_zero_info(
f"Restored snapshot for checkpoint_id: {checkpoint_id}", logger=logger
)

def _does_checkpoint_exist(
self,
checkpoint_path: CheckpointPath,
process_group: Optional[dist.ProcessGroup] = None,
def _generate_checkpoint_and_upkeep(
self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str
) -> bool:
# if we are still checkpointing, this might cause a collective hang.
# so wait here instead
# if we are still checkpointing, this might cause a collective hang, since several
# operations in the base class use the process group. So wait here instead.
self._wait()

return super()._does_checkpoint_exist(
checkpoint_path=checkpoint_path, process_group=process_group
)

def _generate_checkpoint_path(
self,
epoch: int,
step_mapping: Union[int, Dict[Phase, int]],
metric_data: Optional[MetricData] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> CheckpointPath:
# if we are still checkpointing, this might cause a collective hang.
# so wait here instead
self._wait()

return super()._generate_checkpoint_path(
epoch=epoch,
step_mapping=step_mapping,
metric_data=metric_data,
process_group=process_group,
)
# Note that every async checkpoint will be completed at this point.
return super()._generate_checkpoint_and_upkeep(state, unit, hook)

@property
def default_writer_options(self) -> Dict[str, Any]:
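Taken together, the saver now targets only the current DCP entry points (`dcp.save` / `dcp.async_save` / `dcp.load` with a `checkpoint_id`) and always drains the previous async save before issuing further collectives, instead of falling back to the deprecated `save_state_dict` / `load_state_dict` calls. A minimal standalone sketch of that usage pattern, assuming a single-process run (or an already-initialized process group) and illustrative paths and model:

```python
import torch
import torch.distributed.checkpoint as dcp

# A toy module standing in for the application state; the callback itself
# wraps all tracked objects in MultiStateful under an "app_state" key.
model = torch.nn.Linear(4, 4)
state_dict = {"model": model}

# Kick off an asynchronous checkpoint; dcp.async_save returns a future-like
# handle whose .done() / .result() the callback polls in _wait().
prev_snapshot = dcp.async_save(
    state_dict=state_dict,
    checkpoint_id="/tmp/ckpt/epoch_0_train_step_100",
)

# ... training continues ...

# Before any collective that shares the checkpointing process group (or before
# saving again), drain the in-flight save; result() blocks until it finishes
# and re-raises any exception from the background save.
prev_snapshot.result()

# Synchronous save and restore go through the same non-deprecated entry points.
dcp.save(
    state_dict=state_dict,
    checkpoint_id="/tmp/ckpt/epoch_0_train_step_200",
)
dcp.load(
    state_dict,
    checkpoint_id="/tmp/ckpt/epoch_0_train_step_200",
)
```

In the callback this handle is stored on `self._prev_snapshot`, and `_wait()` adds the timing measurement and the togglable checkpointing-frequency warning shown in the diff above.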
2 changes: 1 addition & 1 deletion torchtnt/utils/checkpoint.py
@@ -206,7 +206,7 @@ def _is_phase_aware(self) -> bool:
def newer_than(self, other: "CheckpointPath") -> bool:
"""
Given another CheckpointPath instance, determine if this checkpoint is strictly newer than the other.
Note that recency is determine in terms of the epoch, phase, and number of steps. It is NOT related to
Note that recency is determined in terms of the epoch, phase, and number of steps. It is NOT related to
the timestamp the checkpoint was saved.

Returns: