Commit 2a45ee2

Remove unnecessary encapsulation of DCP APIs (#891)
Summary: Pull Request resolved: #891

Reviewed By: JKSenthil

Differential Revision: D61951699
diego-urgell authored and facebook-github-bot committed Aug 29, 2024
1 parent d3e85dc commit 2a45ee2
Showing 2 changed files with 32 additions and 64 deletions.
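
For reference, the DCP entry points this commit calls inline are torch.distributed.checkpoint's save and async_save. The following standalone sketch shows their direct use on a recent PyTorch build; the toy module, the checkpoint path, and the single-process setting are illustrative assumptions, not part of this PR.

# Sketch: direct use of the DCP APIs (illustrative, not repository code).
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner

model = torch.nn.Linear(4, 2)
checkpoint_id = "/tmp/example_ckpt"  # assumed local path

# Synchronous save: returns only after the checkpoint is fully written.
dcp.save(
    state_dict={"model": model.state_dict()},
    checkpoint_id=checkpoint_id,
    storage_writer=FileSystemWriter(checkpoint_id),
    planner=DefaultSavePlanner(),
)

# Asynchronous save: returns a future immediately; the checkpoint is only
# durable once the future completes.
future = dcp.async_save(
    state_dict={"model": model.state_dict()},
    checkpoint_id=checkpoint_id,
)
future.result()
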
15 changes: 11 additions & 4 deletions tests/framework/callbacks/test_dcp_saver.py
@@ -30,6 +30,7 @@
     DummyAutoUnit,
     DummyTrainUnit,
     generate_random_dataloader,
+    get_dummy_train_state,
 )
 from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
 from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
@@ -306,6 +307,7 @@ def test_save_default_planner_storage_components(
         save_every_n_train_steps = 1
 
         my_unit = DummyTrainUnit(input_dim=input_dim)
+        state = get_dummy_train_state()
 
         with tempfile.TemporaryDirectory() as temp_dir:
             dcp_cb = DistributedCheckpointSaver(
@@ -314,9 +316,11 @@
                 knob_options=KnobOptions(1),
             )
 
-            dcp_cb._save(
+            dcp_cb._checkpoint_impl(
+                state=state,
+                unit=my_unit,
                 checkpoint_id=temp_dir,
-                app_state=my_unit.module.state_dict(),
+                hook="on_train_epoch_end",
             )
 
             planner = mock_dist_cp.save.call_args_list[0][1]["planner"]
@@ -331,6 +335,7 @@ def test_save_planner_storage_components(self, mock_dist_cp: MagicMock) -> None:
         save_every_n_train_steps = 1
 
         my_unit = DummyTrainUnit(input_dim=input_dim)
+        state = get_dummy_train_state()
 
         with tempfile.TemporaryDirectory() as temp_dir:
             dcp_cb = DistributedCheckpointSaver(
@@ -339,9 +344,11 @@
                 knob_options=KnobOptions(1),
             )
 
-            dcp_cb._save(
+            dcp_cb._checkpoint_impl(
+                state=state,
+                unit=my_unit,
                 checkpoint_id=temp_dir,
-                app_state=my_unit.module.state_dict(),
+                hook="on_train_epoch_end",
                 planner=DummySavePlanner(),
                 storage_writer=DummyStorageWriter(path=temp_dir),
             )
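
Taken together, the hunks above switch the tests from the deleted _save wrapper to _checkpoint_impl, which now needs state and unit because it derives the application state internally instead of accepting app_state. A condensed sketch of the new call pattern follows; the mock patch target and the test-helper import path are assumptions inferred from the visible imports, not shown verbatim in the diff.

# Sketch: exercising _checkpoint_impl with dcp mocked out (assumed paths).
import tempfile
from unittest.mock import MagicMock, patch

from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver


@patch("torchtnt.framework.callbacks.dcp_saver.dcp")
def check_planner_forwarding(mock_dist_cp: MagicMock) -> None:
    my_unit = DummyTrainUnit(input_dim=2)
    state = get_dummy_train_state()

    with tempfile.TemporaryDirectory() as temp_dir:
        dcp_cb = DistributedCheckpointSaver(
            temp_dir,
            save_every_n_train_steps=1,
            knob_options=KnobOptions(1),
        )
        # app_state is no longer passed in; _checkpoint_impl builds it from
        # state and unit via _prepare_app_state_for_checkpoint.
        dcp_cb._checkpoint_impl(
            state=state,
            unit=my_unit,
            checkpoint_id=temp_dir,
            hook="on_train_epoch_end",
        )
        # The planner actually forwarded to dcp.save is recoverable from the
        # mock's recorded keyword arguments.
        planner = mock_dist_cp.save.call_args_list[0][1]["planner"]
        print(type(planner).__name__)
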
81 changes: 21 additions & 60 deletions torchtnt/framework/callbacks/dcp_saver.py
@@ -137,26 +137,38 @@ def _checkpoint_impl(
         intra_epoch = hook == "on_train_step_end"
         curr_snapshot_wait = hook == "on_train_end"
 
+        if planner is None:
+            planner = DefaultSavePlanner()
+
+        if storage_writer is None:
+            storage_writer = Writer(checkpoint_id, **self.default_writer_options)
+
         app_state = _prepare_app_state_for_checkpoint(state, unit, intra_epoch)
         # TODO: evaluate whether we need to implement the equivalent of torchsnapshot.RNGState()
         if self._async_checkpoint:
             with get_timing_context(state, f"{self.__class__.__name__}.async_save"):
-                # TODO checkpoint is not truly successful
-                # since this is async checkpointed, so in
-                # future, add logic to set successful flag
-                # only when checkpoint is fully written
-                checkpoint_success = self._async_save(
-                    checkpoint_id, app_state, planner, storage_writer
+                # Redundant check for safety
+                self._wait(log_warning=True)
+                self._prev_snapshot = dcp.async_save(
+                    state_dict={"app_state": MultiStateful(app_state)},
+                    checkpoint_id=checkpoint_id,
+                    process_group=self._process_group,
+                    storage_writer=storage_writer,
+                    planner=planner,
                 )
             if curr_snapshot_wait:
                 self._wait(log_warning=False)
         else:
             with get_timing_context(state, f"{self.__class__.__name__}.save"):
-                checkpoint_success = self._save(
-                    checkpoint_id, app_state, planner, storage_writer
+                dcp.save(
+                    state_dict={"app_state": MultiStateful(app_state)},
+                    checkpoint_id=checkpoint_id,
+                    process_group=self._process_group,
+                    storage_writer=storage_writer,
+                    planner=planner,
                 )
 
-        return checkpoint_success
+        return True
 
     def _wait(self, log_warning: bool = True) -> None:
         """
@@ -195,57 +207,6 @@ def _wait(self, log_warning: bool = True) -> None:
                 logger=logger,
             )
 
-    def _async_save(
-        self,
-        checkpoint_id: str,
-        app_state: Dict[str, Stateful],
-        planner: Optional[SavePlanner] = None,
-        storage_writer: Optional[StorageWriter] = None,
-    ) -> bool:
-
-        if planner is None:
-            planner = DefaultSavePlanner()
-
-        if storage_writer is None:
-            storage_writer = Writer(checkpoint_id, **self.default_writer_options)
-
-        # Redundant check for safety
-        self._wait(log_warning=True)
-
-        self._prev_snapshot = dcp.async_save(
-            state_dict={"app_state": MultiStateful(app_state)},
-            checkpoint_id=checkpoint_id,
-            process_group=self._process_group,
-            storage_writer=storage_writer,
-            planner=planner,
-        )
-
-        return True
-
-    def _save(
-        self,
-        checkpoint_id: str,
-        app_state: Dict[str, Stateful],
-        planner: Optional[SavePlanner] = None,
-        storage_writer: Optional[StorageWriter] = None,
-    ) -> bool:
-        # Initialize DefaultSavePlanner and FsspecWriter if not provided
-        if planner is None:
-            planner = DefaultSavePlanner()
-
-        if storage_writer is None:
-            storage_writer = Writer(checkpoint_id, **self.default_writer_options)
-
-        dcp.save(
-            state_dict={"app_state": MultiStateful(app_state)},
-            checkpoint_id=checkpoint_id,
-            process_group=self._process_group,
-            storage_writer=storage_writer,
-            planner=planner,
-        )
-
-        return True
-
     def on_exception(
         self,
         state: State,
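
The inlined async path preserves the bookkeeping the deleted _async_save performed: wait on any in-flight snapshot, start dcp.async_save, and keep the returned future in _prev_snapshot. Reduced to a self-contained class, the pattern looks roughly like this; the names mirror the callback's, but the class itself is an illustration, not TorchTNT code.

# Sketch of the async-save bookkeeping pattern (illustrative only).
from typing import Any, Dict, Optional

import torch.distributed.checkpoint as dcp
from torch.futures import Future


class AsyncSnapshotKeeper:
    def __init__(self) -> None:
        self._prev_snapshot: Optional[Future] = None

    def save_async(self, state_dict: Dict[str, Any], checkpoint_id: str) -> None:
        # Redundant check for safety: never start a save while one is in flight.
        self._wait()
        self._prev_snapshot = dcp.async_save(
            state_dict=state_dict,
            checkpoint_id=checkpoint_id,
        )

    def _wait(self) -> None:
        # Block until the previous snapshot, if any, is fully written.
        if self._prev_snapshot is not None:
            self._prev_snapshot.result()
            self._prev_snapshot = None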
