Skip to content

Commit

Permalink
Always wait for the previous snapshot to finish (#708)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #708

This avoids concurrency issues when the storage backend doesn't support concurrent snapshot writes.

Reviewed By: JKSenthil

Differential Revision: D53880548

fbshipit-source-id: da81fe4d379aeb09e0f52caecb80f4a3fdb2a616
  • Loading branch information
schwarzmx authored and facebook-github-bot committed Feb 20, 2024
1 parent 96ecbe4 commit 3705462
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,18 +158,16 @@ def _checkpoint_impl(
"""
Checkpoint the current state of the application.
"""
if hook not in ["on_train_step_end", "on_train_epoch_end", "on_train_end"]:
raise RuntimeError(f"Unexpected hook encountered '{hook}'")

intra_epoch = False
prev_snapshot_wait = False
curr_snapshot_wait = False

if hook == "on_train_step_end":
intra_epoch = True
elif hook == "on_train_epoch_end":
prev_snapshot_wait = True
elif hook == "on_train_end":
prev_snapshot_wait = True
curr_snapshot_wait = True
else:
raise RuntimeError(f"Unexpected hook encountered '{hook}'")

app_state = _prepare_app_state_for_checkpoint(state, unit, intra_epoch)
rng_state = torchsnapshot.RNGState()
Expand All @@ -183,9 +181,7 @@ def _checkpoint_impl(
# since this is async checkpointed, so in
# future, add logic to set successful flag
# only when checkpoint is fully written
checkpoint_success = self._async_snapshot(
checkpoint_path, app_state, wait=prev_snapshot_wait
)
checkpoint_success = self._async_snapshot(checkpoint_path, app_state)
if curr_snapshot_wait:
self._wait()
else:
Expand All @@ -198,7 +194,9 @@ def _wait(self) -> None:
self._prev_snapshot.wait()

def _async_snapshot(
self, snapshot_path: str, app_state: Dict[str, _TStateful], *, wait: bool
self,
snapshot_path: str,
app_state: Dict[str, _TStateful],
) -> bool:
prev_snapshot = self._prev_snapshot
if prev_snapshot is not None:
Expand All @@ -207,14 +205,15 @@ def _async_snapshot(
# This can happen if we call _async_snapshot twice at the same step.
return False
still_pending = not prev_snapshot.done()
if still_pending and wait:
prev_snapshot.wait()
elif still_pending:
if still_pending:
rank_zero_warn(
f"Still writing previous snapshot, will skip this one. Consider increasing 'frequency' (current {self._save_every_n_train_steps})",
(
"Still writing previous snapshot; waiting for it to finish before writing a new one. "
f"Consider increasing 'frequency' (current {self._save_every_n_train_steps})"
),
logger=logger,
)
return False
prev_snapshot.wait()

replicated = self._replicated
if self._replicated == {"**"}:
Expand Down

0 comments on commit 3705462

Please sign in to comment.