From 4c048acd6de3fc75b39ae3be44787af8c0d6d40b Mon Sep 17 00:00:00 2001 From: Saurabh Mishra Date: Mon, 15 Jul 2024 07:28:24 -0700 Subject: [PATCH] Use checkpoint_id in the internal APIs (#860) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/860 Use checkpoint_id in the internal APIs instead of checkpoint paths. ID is a more generic parameter which will be used in the subsequent diffs to represent Meta internal abstractions like model entity id to identify a checkpoint. Reviewed By: galrotem Differential Revision: D59638742 --- tests/framework/callbacks/test_base_checkpointer.py | 8 ++++---- torchtnt/framework/callbacks/base_checkpointer.py | 6 +++--- torchtnt/framework/callbacks/dcp_saver.py | 6 +++--- torchtnt/framework/callbacks/torchsnapshot_saver.py | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py index 3998ff62b7..1bd2739970 100644 --- a/tests/framework/callbacks/test_base_checkpointer.py +++ b/tests/framework/callbacks/test_base_checkpointer.py @@ -74,11 +74,11 @@ def __init__( self._latest_checkpoint_path: str = "" def _checkpoint_impl( - self, state: State, unit: AppStateMixin, checkpoint_path: str, hook: str + self, state: State, unit: AppStateMixin, checkpoint_id: str, hook: str ) -> bool: - self._latest_checkpoint_path = checkpoint_path - if not os.path.exists(checkpoint_path): - os.mkdir(checkpoint_path) + self._latest_checkpoint_path = checkpoint_id + if not os.path.exists(checkpoint_id): + os.mkdir(checkpoint_id) return True @staticmethod diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index cc9bc17481..82ed41f4b7 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -203,7 +203,7 @@ def _generate_checkpoint_and_upkeep( # 3) try to save checkpoint if not self._checkpoint_impl( - state, unit, checkpoint_path=checkpoint_path.path, hook=hook + state, unit, checkpoint_id=checkpoint_path.path, hook=hook ): return False @@ -299,7 +299,7 @@ def _checkpoint_impl( state: State, unit: AppStateMixin, *, - checkpoint_path: str, + checkpoint_id: str, hook: str, ) -> bool: """ @@ -308,7 +308,7 @@ def _checkpoint_impl( Args: state: current application state unit: current unit - checkpoint_path: path to save checkpoint + checkpoint_id: Checkpoint id to save a checkpoint. It can be a path hook: name of callback hook that triggered this function call Returns: diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py index 59d6ded012..232171df9f 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -131,7 +131,7 @@ def _checkpoint_impl( state: State, unit: AppStateMixin, *, - checkpoint_path: str, + checkpoint_id: str, hook: str, planner: Optional[SavePlanner] = None, storage_writer: Optional[StorageWriter] = None, @@ -156,14 +156,14 @@ def _checkpoint_impl( # future, add logic to set successful flag # only when checkpoint is fully written checkpoint_success = self._async_save( - checkpoint_path, app_state, planner, storage_writer + checkpoint_id, app_state, planner, storage_writer ) if curr_snapshot_wait: self._wait() else: with get_timing_context(state, f"{self.__class__.__name__}.save"): checkpoint_success = self._save( - checkpoint_path, app_state, planner, storage_writer + checkpoint_id, app_state, planner, storage_writer ) return checkpoint_success diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py index d9a7153d8c..2ffb423b44 100644 --- a/torchtnt/framework/callbacks/torchsnapshot_saver.py +++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py @@ -151,7 +151,7 @@ def _checkpoint_impl( state: State, unit: AppStateMixin, *, - checkpoint_path: str, + checkpoint_id: str, hook: str, ) -> bool: """ @@ -185,12 +185,12 @@ def _checkpoint_impl( # since this is async checkpointed, so in # future, add logic to set successful flag # only when checkpoint is fully written - checkpoint_success = self._async_snapshot(checkpoint_path, app_state) + checkpoint_success = self._async_snapshot(checkpoint_id, app_state) if curr_snapshot_wait: self._wait() else: with get_timing_context(state, f"{self.__class__.__name__}.take_snapshot"): - checkpoint_success = self._sync_snapshot(checkpoint_path, app_state) + checkpoint_success = self._sync_snapshot(checkpoint_id, app_state) return checkpoint_success def _wait(self) -> None: