From 0f723338d68567d583e37e7d751399b023bc48d8 Mon Sep 17 00:00:00 2001
From: Diego Urgell
Date: Fri, 14 Jun 2024 14:48:20 -0700
Subject: [PATCH] Don't take final ckpt if no more training was done in FIT (#846)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/846

Reviewed By: JKSenthil

Differential Revision: D58397317

fbshipit-source-id: 31b8f7382059f04cd35f26eafbd957bd53e7f3f0
---
 .../callbacks/test_base_checkpointer.py       | 38 +++++++++++++++++++
 .../framework/callbacks/base_checkpointer.py  | 31 +++++++++++----
 2 files changed, 62 insertions(+), 7 deletions(-)

diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py
index b9918e3cde..3998ff62b7 100644
--- a/tests/framework/callbacks/test_base_checkpointer.py
+++ b/tests/framework/callbacks/test_base_checkpointer.py
@@ -500,6 +500,44 @@ def test_save_on_train_end(self) -> None:
             ],
         )
 
+    def test_save_on_train_end_on_fit(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+        max_epochs = 6
+
+        for save_every_n_eval_epochs, expected_last_ckpt in [
+            (None, "epoch_6_train_step_30_eval_step_25"),
+            (2, "epoch_6_train_step_30_eval_step_30"),
+        ]:
+            my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2))
+            train_dataloader = generate_random_dataloader(
+                dataset_len, input_dim, batch_size
+            )
+            eval_dataloader = generate_random_dataloader(
+                dataset_len, input_dim, batch_size
+            )
+            with tempfile.TemporaryDirectory() as temp_dir:
+                checkpoint_cb = BaseCheckpointSaver(
+                    temp_dir,
+                    save_every_n_epochs=2,
+                    save_every_n_eval_epochs=save_every_n_eval_epochs,
+                )
+                fit(
+                    my_unit,
+                    train_dataloader=train_dataloader,
+                    eval_dataloader=eval_dataloader,
+                    max_epochs=max_epochs,
+                    evaluate_every_n_epochs=1,
+                    callbacks=[checkpoint_cb],
+                )
+                expected_path = os.path.join(temp_dir, expected_last_ckpt)
+                self.assertTrue(os.path.exists(expected_path))
+                self.assertEqual(
+                    checkpoint_cb._checkpoint_manager._ckpt_paths[-1].path,
+                    expected_path,
+                )
+
     @skip_if_not_distributed
     def test_directory_sync_collective(self) -> None:
         spawn_multi_process(
diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py
index 95c61a1ae0..cc9bc17481 100644
--- a/torchtnt/framework/callbacks/base_checkpointer.py
+++ b/torchtnt/framework/callbacks/base_checkpointer.py
@@ -16,7 +16,7 @@
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.callbacks._checkpoint_utils import _get_step_phase_mapping
 from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
-from torchtnt.framework.state import State
+from torchtnt.framework.state import EntryPoint, State
 from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
 from torchtnt.utils.checkpoint import (
     BestCheckpointConfig,
@@ -25,6 +25,7 @@
     get_best_checkpoint_path,
     get_latest_checkpoint_path,
     MetricData,
+    Phase,
 )
 from torchtnt.utils.distributed import PGWrapper
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
@@ -177,12 +178,28 @@ def _generate_checkpoint_and_upkeep(
         if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
             return False
 
-        # 2.1) Make sure that last checkpoint does not already exist
-        if hook == "on_train_end" and self._does_checkpoint_exist(
-            checkpoint_path, self._process_group
-        ):
-            rank_zero_warn("Final checkpoint already exists, skipping.", logger=logger)
-            return False
+        if hook == "on_train_end":
+            # 2.1) Make sure that last checkpoint does not already exist
+            if self._does_checkpoint_exist(checkpoint_path, self._process_group):
+                rank_zero_warn(
+                    "Final checkpoint already exists, skipping.", logger=logger
+                )
+                return False
+
+            # 2.2) If doing fit without eval checkpointing, only consider training progress when
+            #      checking if last checkpoint exists.
+            if (
+                state.entry_point == EntryPoint.FIT
+                and self._save_every_n_eval_epochs is None
+                and self._checkpoint_manager._ckpt_paths
+                and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
+                == cast(TTrainUnit, unit).train_progress.num_steps_completed
+            ):
+                rank_zero_info(
+                    "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
+                    logger=logger,
+                )
+                return False
 
         # 3) try to save checkpoint
         if not self._checkpoint_impl(
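
Below is a minimal, standalone sketch of the skip condition this patch adds to on_train_end, written against simplified stand-in types rather than the real TorchTNT classes; the names FakeCheckpointPath and should_skip_final_checkpoint are hypothetical and not part of the library API.

from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, List, Optional


class Phase(Enum):
    TRAIN = auto()
    EVALUATE = auto()


@dataclass
class FakeCheckpointPath:
    # Hypothetical stand-in for a checkpoint record: step counts keyed by phase.
    step: Dict[Phase, int]


def should_skip_final_checkpoint(
    is_fit: bool,
    save_every_n_eval_epochs: Optional[int],
    ckpt_paths: List[FakeCheckpointPath],
    train_steps_completed: int,
) -> bool:
    # Mirrors the new check: in FIT without eval checkpointing, skip the final
    # checkpoint if the train step count has not advanced since the most recent
    # checkpoint was written.
    return (
        is_fit
        and save_every_n_eval_epochs is None
        and bool(ckpt_paths)
        and ckpt_paths[-1].step[Phase.TRAIN] == train_steps_completed
    )


# The last checkpoint was taken at train step 30 and training ended at step 30,
# so a final checkpoint would be redundant and is skipped.
last = FakeCheckpointPath(step={Phase.TRAIN: 30, Phase.EVALUATE: 25})
assert should_skip_final_checkpoint(True, None, [last], 30)
# With eval checkpointing configured, the final checkpoint is still taken,
# matching the second case in test_save_on_train_end_on_fit.
assert not should_skip_final_checkpoint(True, 2, [last], 30)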