wraps DDP models with DSD (#857)
Summary:
Pull Request resolved: #857

Distributed State Dict (DSD) is PyTorch's currently recommended way to ensure that parallelized models' state dicts are compatible with saving and loading in single-process or re-sharding scenarios.
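
As an illustration (not part of this diff), here is a minimal sketch of how the DSD APIs behave for a DDP-wrapped module. The toy model and the assumption of an already-initialized process group are for demonstration only:

```python
import torch
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)
from torch.nn.parallel import DistributedDataParallel

# Assumes torch.distributed.init_process_group(...) has already been called.
ddp_model = DistributedDataParallel(torch.nn.Linear(8, 8))

# Plain state_dict() on DDP prefixes every key with "module.";
# get_model_state_dict returns keys as a single-device model would.
assert all(k.startswith("module.") for k in ddp_model.state_dict())
model_sd = get_model_state_dict(ddp_model)
assert all(not k.startswith("module.") for k in model_sd)

# The counterpart accepts the same un-prefixed format on load.
set_model_state_dict(ddp_model, model_sd)
```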

This diff updates dcp_saver to use DSD for DDP models. A good follow-up would be to wrap all models in TNT with DSD, as this could replace some of the wrapper logic for FSDP and would guarantee future compatibility.

N5551629 also contains a workaround for DDP checkpoints saved before this diff, which manually removes the "module." prefix from keys in the checkpoint.
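
The notebook itself is not reproduced here; a hedged sketch of that kind of prefix-stripping workaround (the helper name is illustrative, not taken from N5551629) could look like:

```python
from typing import Any, Dict


def strip_ddp_prefix(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Drop the leading "module." that DDP adds, so a checkpoint saved
    before this diff matches the un-prefixed keys DSD expects."""
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
```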

Differential Revision: D59234083
LucasLLC authored and facebook-github-bot committed Jul 8, 2024
1 parent 58b6ea7 commit 435b1cb
Showing 1 changed file with 53 additions and 2 deletions.
55 changes: 53 additions & 2 deletions torchtnt/framework/callbacks/dcp_saver.py
@@ -20,7 +20,7 @@
)
from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
from torch.distributed.checkpoint.storage import StorageReader, StorageWriter

from torch.nn.parallel import DistributedDataParallel
from torchtnt.framework.callbacks._checkpoint_utils import (
    _prepare_app_state_for_checkpoint,
    _prepare_app_state_for_restore,
@@ -41,6 +41,7 @@
from torchtnt.utils.checkpoint import BestCheckpointConfig, CheckpointPath
from torchtnt.utils.optimizer import init_optim_state
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn

from torchtnt.utils.stateful import MultiStateful, Stateful


@@ -62,6 +63,48 @@
FileSystemWriter as Writer,
)

# The code below provides BC for PyTorch versions which don't include distributed state dict
# TODO: remove the code below once this path is no longer supported
try:
    import torch.distributed.checkpoint.state_dict as dsd

    # pyre-ignore Incompatible variable type [9]
    get_model_state_dict = dsd.get_model_state_dict

    def set_model_state_dict(mod: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
        return dsd.set_model_state_dict(mod, state_dict)

except ImportError:
    logger.warning(
        "torch.distributed.checkpoint.state_dict is not available, "
        "falling back on defaults. Consider updating PyTorch, as this version "
        "will not be supported in the future."
    )

    def get_model_state_dict(mod: torch.nn.Module) -> Dict[str, Any]:
        return mod.state_dict()

    def set_model_state_dict(mod: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
        return mod.load_state_dict(state_dict)


class DSDModelWrapper(Stateful):
    """This wrapper converts state dicts to Distributed State Dicts, essentially generating
    state dicts as if they were created using single-device methods. This is useful when
    checkpointed models might be re-sharded, or loaded in notebooks or other non-distributed
    settings.
    """

    def __init__(self, mod: torch.nn.Module) -> None:
        self.mod: torch.nn.Module = mod

    def state_dict(self) -> Dict[str, Any]:
        return get_model_state_dict(self.mod)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        set_model_state_dict(self.mod, state_dict)


class DistributedCheckpointSaver(BaseCheckpointer):
"""
@@ -148,6 +191,11 @@ def _checkpoint_impl(
        curr_snapshot_wait = hook == "on_train_end"

        app_state = _prepare_app_state_for_checkpoint(state, unit, intra_epoch)

        for key, obj in app_state.items():
            if isinstance(obj, DistributedDataParallel):
                app_state[key] = DSDModelWrapper(obj)

        # TODO: evaluate whether we need to implement the equivalent of torchsnapshot.RNGState()
        if self._async_checkpoint:
            with get_timing_context(state, f"{self.__class__.__name__}.async_save"):
@@ -315,14 +363,17 @@ def restore(
)

        # necessary for loading optimizers since states are initialized lazily
        for key, obj in app_state.items():  # previously: for obj in app_state.values()
            # sometimes optimizers are actually held in a wrapper which handles calling
            # state_dict and load_state_dict, as is the case for
            # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`; this handles that case.
            optimizer = getattr(obj, "optimizer", obj)
            if isinstance(optimizer, torch.optim.Optimizer):
                init_optim_state(optimizer)

            if isinstance(obj, DistributedDataParallel):
                app_state[key] = DSDModelWrapper(obj)

        try:
            dcp.load(
                {"app_state": MultiStateful(app_state)},

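For reference, a minimal usage sketch of the `DSDModelWrapper` added in the diff above. The saver applies this wrapping automatically for `DistributedDataParallel` entries in `app_state`; the toy model and the assumption of an initialized process group are illustrative only:

```python
import torch
from torch.nn.parallel import DistributedDataParallel

# Defined in torchtnt/framework/callbacks/dcp_saver.py as of this diff.
from torchtnt.framework.callbacks.dcp_saver import DSDModelWrapper

# Assumes torch.distributed.init_process_group(...) has already been called.
ddp_model = DistributedDataParallel(torch.nn.Linear(4, 4))
wrapped = DSDModelWrapper(ddp_model)

# state_dict() delegates to get_model_state_dict, so keys carry no "module."
# prefix and match what a single-device model would produce.
assert all(not k.startswith("module.") for k in wrapped.state_dict())

# load_state_dict() accepts the same un-prefixed format, e.g. a checkpoint
# written from a non-DDP run.
wrapped.load_state_dict(torch.nn.Linear(4, 4).state_dict())
```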