remove progress from replicated if ["**"] passed in
Reviewed By: schwarzmx

Differential Revision: D53284173

fbshipit-source-id: 8874035cce9c2c13dd5a878a538e0f93eb929119
JKSenthil authored and facebook-github-bot committed Feb 1, 2024
1 parent 5b67b42 commit 521984f
Showing 2 changed files with 89 additions and 2 deletions.
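Context for the change: previously, passing replicated=["**"] to TorchSnapshotSaver replicated every key in the unit's app_state(), including the train/eval/predict progress counters; with this commit those progress keys are filtered out before the snapshot is taken. A minimal sketch of the affected call site, assuming any unit whose app_state() is populated (the dirpath value is a placeholder, not part of this commit):

from torchtnt.framework.callbacks.torchsnapshot_saver import TorchSnapshotSaver

saver = TorchSnapshotSaver(
    dirpath="/tmp/checkpoints",  # placeholder directory
    save_every_n_train_steps=1,
    # "**" now expands to one "{key}/**" glob per app_state key, minus the
    # {train,eval,predict}_progress keys (see _exclude_progress_from_replicated below)
    replicated=["**"],
)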
63 changes: 63 additions & 0 deletions tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -27,6 +27,7 @@
)
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
from torchtnt.framework.callbacks.torchsnapshot_saver import (
    _exclude_progress_from_replicated,
    _override_knobs,
    TorchSnapshotSaver,
)
@@ -363,6 +364,68 @@ def test_sync_checkpoint(self, _: MagicMock) -> None:
        snapshot_cb.on_train_step_end(state, my_unit)
        snapshot_cb._sync_snapshot.assert_called_once()

    def test_exclude_progress_from_replicated(self) -> None:
        """
        Tests that replicated is populated correctly with progress excluded
        """

        module = nn.Linear(2, 3)
        my_unit = DummyAutoUnit(module=module)
        keys = my_unit.app_state().keys()

        progress_keys = {"train_progress", "eval_progress", "predict_progress"}

        replicated = _exclude_progress_from_replicated(my_unit.app_state())
        for key in keys:
            if key not in progress_keys:
                self.assertIn(f"{key}/**", replicated)

        # since we exclude 3 keys (train, eval, predict)
        self.assertEqual(len(keys) - 3, len(replicated))

        # check that progress is not included
        for progress_key in progress_keys:
            self.assertNotIn(f"{progress_key}/", replicated)

    @patch("torchtnt.framework.callbacks.torchsnapshot_saver.Snapshot.take")
    def test_exclude_progress_from_replicated_e2e(self, mock_take: MagicMock) -> None:
        """
        Tests that replicated is populated correctly during snapshotting
        """

        module = nn.Linear(2, 3)
        my_unit = DummyAutoUnit(module=module)
        state = get_dummy_train_state()

        with tempfile.TemporaryDirectory() as temp_dir:
            for replicated_value in (None, ["optimizer/**"], ["**"]):
                tss = TorchSnapshotSaver(
                    dirpath=temp_dir,
                    save_every_n_train_steps=1,
                    async_checkpoint=False,
                    replicated=replicated_value,
                )

                progress_keys = {"train_progress", "eval_progress", "predict_progress"}

                tss.on_train_step_end(state, my_unit)
                replicated = mock_take.call_args.kwargs["replicated"]

                if replicated_value is None:
                    self.assertEqual(replicated, [])
                elif replicated_value == ["optimizer/**"]:
                    self.assertEqual(replicated, ["optimizer/**"])
                elif replicated_value == ["**"]:
                    expected_replicated = [
                        f"{key}/**"
                        for key in my_unit.app_state().keys()
                        if key not in progress_keys
                    ]
                    # this is added outside of the unit's app_state so it should be included
                    expected_replicated.append("rng_state/**")

                    self.assertEqual(set(replicated), set(expected_replicated))


class DummyStatefulDataLoader:
    def __init__(self, dataloader: DataLoader) -> None:
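The e2e test above reads the keyword arguments of the patched Snapshot.take through call_args.kwargs. A self-contained sketch of that capture pattern, independent of torchtnt:

from unittest.mock import MagicMock

mock_take = MagicMock()
mock_take("/tmp/snap", replicated=["module/**"])

# call_args records the most recent invocation; .kwargs exposes its keyword
# arguments as a dict (attribute-style access requires Python 3.8+)
assert mock_take.call_args.kwargs["replicated"] == ["module/**"]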
28 changes: 26 additions & 2 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -216,12 +216,16 @@ def _async_snapshot(
            )
            return False

        replicated = self._replicated
        if self._replicated == {"**"}:
            replicated = _exclude_progress_from_replicated(app_state)

        with _override_knobs(self._knob_options):
            self._prev_snapshot = Snapshot.async_take(
                str(snapshot_path),
                app_state=app_state,
                pg=self._process_group,
                replicated=list(replicated),
                storage_options=self._storage_options,
            )
        rank_zero_info(f"Saving snapshot to path: {snapshot_path}", logger=logger)
@@ -232,6 +236,10 @@ def _sync_snapshot(
        snapshot_path: str,
        app_state: Dict[str, _TStateful],
    ) -> bool:
        replicated = self._replicated
        if self._replicated == {"**"}:
            replicated = _exclude_progress_from_replicated(app_state)

        with _override_knobs(self._knob_options):
            rank_zero_info(
                f"Started saving snapshot to path: {snapshot_path}", logger=logger
@@ -240,7 +248,7 @@
            Snapshot.take(
                str(snapshot_path),
                app_state=app_state,
                pg=self._process_group,
                replicated=list(replicated),
                storage_options=self._storage_options,
            )
            rank_zero_info(
@@ -316,6 +324,22 @@ def restore(
        rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)


def _exclude_progress_from_replicated(app_state: Dict[str, _TStateful]) -> Set[str]:
    """
    Excludes progress state from being replicated. Called if replicated=["**"] is passed in.
    Works by expanding "**" into a "{key}/**" glob for every key in app_state, skipping
    the "{train,eval,predict}_progress" keys.
    """

    filtered_replicated = set()
    progress_keys = {"train_progress", "eval_progress", "predict_progress"}
    for key in app_state.keys():
        if key in progress_keys:
            continue
        filtered_replicated.add(f"{key}/**")
    return filtered_replicated


def _validate_snapshot_available() -> None:
    if not _TORCHSNAPSHOT_AVAILABLE:
        raise RuntimeError(
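A quick sketch of the new helper in isolation; the app_state dict below is a stand-in, since only the keys matter to the helper (plain objects substitute for real Stateful values):

app_state = {
    "module": object(),
    "optimizer": object(),
    "train_progress": object(),
    "eval_progress": object(),
    "predict_progress": object(),
}

# progress keys are dropped; everything else becomes a "{key}/**" glob
assert _exclude_progress_from_replicated(app_state) == {"module/**", "optimizer/**"}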
