add _checkpoint_impl to TSS (#622)
Summary:
Pull Request resolved: #622

# Context
The `on_train_step_end`, `on_train_epoch_end`, and `on_train_end` hooks all use a nearly identical code block to save a checkpoint:
```
with get_timing_context(
    state, f"{self.__class__.__name__}.take_async_snapshot"
):
    checkpoint_success = self._async_snapshot(
        snapshot_path, app_state, wait=True
    )

if checkpoint_success:
    if self._should_remove_snapshot():
        self._remove_snapshot(state)
    self._ckpt_dirpaths.append(snapshot_path)
```
# This Diff
Moves this logic into a `_checkpoint_impl` function to reduce code duplication. The three hooks now differ only in the flag values they pass; see the condensed sketch below.
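For quick reference, here is a condensed sketch of the new call pattern, abridged from the diff below (the `get_timing_context` wrapper and the docstring are omitted for brevity); it summarizes rather than replaces the full implementation:
```
# Condensed from the diff below. Each train hook now delegates to the
# shared helper, varying only the flags:
#   on_train_step_end  -> intra_epoch=True,  prev_snapshot_wait=False, curr_snapshot_wait=False
#   on_train_epoch_end -> intra_epoch=False, prev_snapshot_wait=True,  curr_snapshot_wait=False
#   on_train_end       -> intra_epoch=False, prev_snapshot_wait=True,  curr_snapshot_wait=True
def _checkpoint_impl(
    self,
    state: State,
    unit: AppStateMixin,
    *,
    snapshot_path: str,
    intra_epoch: bool,
    prev_snapshot_wait: bool,
    curr_snapshot_wait: bool,
) -> None:
    app_state = _get_app_state(state, unit, self._replicated, intra_epoch=intra_epoch)
    # per the docstring, wait=prev_snapshot_wait controls whether we wait
    # for the previous snapshot to finish writing before taking this one
    checkpoint_success = self._async_snapshot(snapshot_path, app_state, wait=prev_snapshot_wait)
    if curr_snapshot_wait:
        self._wait()  # block until the current snapshot is fully written
    if checkpoint_success:
        if self._should_remove_snapshot():
            self._remove_snapshot(state)
        self._ckpt_dirpaths.append(snapshot_path)
```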

Reviewed By: galrotem

Differential Revision: D51285751

fbshipit-source-id: 0dbb5a024e9f2e05de0b8c41432d9951874e1eef
JKSenthil authored and facebook-github-bot committed Nov 14, 2023
1 parent b57c4ea · commit fa9e007
Showing 1 changed file with 61 additions and 35 deletions.
torchtnt/framework/callbacks/torchsnapshot_saver.py

```
@@ -196,7 +196,6 @@ def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
         ):
             return
 
-        app_state = _get_app_state(state, unit, self._replicated, intra_epoch=True)
         epoch = unit.train_progress.num_epochs_completed
         if state.entry_point == EntryPoint.FIT:
             num_steps_completed += cast(
@@ -205,25 +204,21 @@ def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
         snapshot_path = _get_snapshot_save_path(
             self._dirpath, epoch, num_steps_completed
         )
-        with get_timing_context(
-            state, f"{self.__class__.__name__}.take_async_snapshot"
-        ):
-            checkpoint_success = self._async_snapshot(
-                snapshot_path, app_state, wait=False
-            )
-
-        if checkpoint_success:
-            if self._should_remove_snapshot():
-                self._remove_snapshot(state)
-            self._ckpt_dirpaths.append(snapshot_path)
+        self._checkpoint_impl(
+            state,
+            unit,
+            snapshot_path=snapshot_path,
+            intra_epoch=True,
+            prev_snapshot_wait=False,
+            curr_snapshot_wait=False,
+        )
 
     def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
         epoch = unit.train_progress.num_epochs_completed
         save_every_n_epochs = self._save_every_n_epochs
         if save_every_n_epochs is None or epoch % save_every_n_epochs != 0:
             return
 
-        app_state = _get_app_state(state, unit, self._replicated, intra_epoch=False)
         num_steps_completed = unit.train_progress.num_steps_completed
         if state.entry_point == EntryPoint.FIT:
             num_steps_completed += cast(
@@ -232,20 +227,16 @@ def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
         snapshot_path = _get_snapshot_save_path(
             self._dirpath, epoch, num_steps_completed
         )
-        with get_timing_context(
-            state, f"{self.__class__.__name__}.take_async_snapshot"
-        ):
-            checkpoint_success = self._async_snapshot(
-                snapshot_path, app_state, wait=True
-            )
-
-        if checkpoint_success:
-            if self._should_remove_snapshot():
-                self._remove_snapshot(state)
-            self._ckpt_dirpaths.append(snapshot_path)
+        self._checkpoint_impl(
+            state,
+            unit,
+            snapshot_path=snapshot_path,
+            intra_epoch=False,
+            prev_snapshot_wait=True,
+            curr_snapshot_wait=False,
+        )
 
     def on_train_end(self, state: State, unit: TTrainUnit) -> None:
-        app_state = _get_app_state(state, unit, self._replicated, intra_epoch=False)
         num_steps_completed = unit.train_progress.num_steps_completed
         if state.entry_point == EntryPoint.FIT:
             num_steps_completed += cast(
@@ -255,6 +246,47 @@ def on_train_end(self, state: State, unit: TTrainUnit) -> None:
         snapshot_path = _get_snapshot_save_path(
             self._dirpath, epoch, num_steps_completed
         )
+        self._checkpoint_impl(
+            state,
+            unit,
+            snapshot_path=snapshot_path,
+            intra_epoch=False,
+            prev_snapshot_wait=True,
+            curr_snapshot_wait=True,
+        )
+
+    def on_exception(
+        self,
+        state: State,
+        unit: Union[TTrainUnit, TEvalUnit, TPredictUnit],
+        exc: BaseException,
+    ) -> None:
+        self._wait()
+
+    def _checkpoint_impl(
+        self,
+        state: State,
+        unit: AppStateMixin,
+        *,
+        snapshot_path: str,
+        intra_epoch: bool,
+        prev_snapshot_wait: bool,
+        curr_snapshot_wait: bool,
+    ) -> None:
+        """
+        Checkpoint the current state of the application.
+
+        Args:
+            state: State of the application
+            unit: The training/evaluation/prediction unit
+            snapshot_path: Path to save the snapshot
+            intra_epoch: Whether in middle of epoch or not
+            prev_snapshot_wait: Whether to wait for previous snapshot to finish writing
+            curr_snapshot_wait: Whether to wait for current snapshot to finish writing
+        """
+        app_state = _get_app_state(
+            state, unit, self._replicated, intra_epoch=intra_epoch
+        )
         with get_timing_context(
             state, f"{self.__class__.__name__}.take_async_snapshot"
         ):
@@ -263,23 +295,17 @@ def on_train_end(self, state: State, unit: TTrainUnit) -> None:
+            # future, add logic to set successful flag
+            # only when checkpoint is fully written
             checkpoint_success = self._async_snapshot(
-                snapshot_path, app_state, wait=True
+                snapshot_path, app_state, wait=prev_snapshot_wait
             )
-        self._wait()
+        if curr_snapshot_wait:
+            self._wait()
 
         # remove and book keep snapshots related to keep_last_n_checkpoints
         if checkpoint_success:
             if self._should_remove_snapshot():
                 self._remove_snapshot(state)
             self._ckpt_dirpaths.append(snapshot_path)
 
-    def on_exception(
-        self,
-        state: State,
-        unit: Union[TTrainUnit, TEvalUnit, TPredictUnit],
-        exc: BaseException,
-    ) -> None:
-        self._wait()
-
     def _should_remove_snapshot(self) -> bool:
         keep_last_n_checkpoints = self._keep_last_n_checkpoints
         return (
```

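To make the shape of the refactor easy to run in isolation, here is a self-contained toy version of the same pattern. Every name in it (`ToySaver`, `on_step_end`, and so on) is illustrative only and is not the torchtnt API; it just mirrors the hoist-into-one-helper structure of the diff above.
```
# Toy illustration of the refactor (illustrative names, not torchtnt APIs):
# save/cleanup logic that several hooks previously repeated is hoisted into
# one helper parameterized by wait flags.
from typing import List


class ToySaver:
    def __init__(self, keep_last_n: int = 2) -> None:
        self._keep_last_n = keep_last_n
        self._ckpt_dirpaths: List[str] = []

    def on_step_end(self, path: str) -> None:
        # Mid-epoch save: don't block on the previous or current snapshot.
        self._checkpoint_impl(path, prev_wait=False, curr_wait=False)

    def on_epoch_end(self, path: str) -> None:
        # Epoch boundary: block on the previous snapshot only.
        self._checkpoint_impl(path, prev_wait=True, curr_wait=False)

    def on_train_end(self, path: str) -> None:
        # End of training: block until everything is written.
        self._checkpoint_impl(path, prev_wait=True, curr_wait=True)

    def _checkpoint_impl(self, path: str, *, prev_wait: bool, curr_wait: bool) -> None:
        # The single shared save/cleanup sequence, as in the diff above.
        success = self._async_snapshot(path, wait=prev_wait)
        if curr_wait:
            self._wait()
        if success:
            while len(self._ckpt_dirpaths) >= self._keep_last_n:
                self._ckpt_dirpaths.pop(0)  # stand-in for _remove_snapshot
            self._ckpt_dirpaths.append(path)

    def _async_snapshot(self, path: str, *, wait: bool) -> bool:
        print(f"snapshot {path} (wait_for_previous={wait})")
        return True

    def _wait(self) -> None:
        print("waiting for pending snapshot to finish")


saver = ToySaver()
saver.on_step_end("ckpt/epoch_0_step_10")
saver.on_epoch_end("ckpt/epoch_1_step_20")
saver.on_train_end("ckpt/epoch_2_step_30")
print(saver._ckpt_dirpaths)  # ['ckpt/epoch_1_step_20', 'ckpt/epoch_2_step_30']
```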