diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py index d40cd2c236..c94033a0db 100644 --- a/tests/framework/callbacks/test_base_checkpointer.py +++ b/tests/framework/callbacks/test_base_checkpointer.py @@ -25,6 +25,7 @@ Batch, DummyAutoUnit, DummyFitUnit, + DummyPredictUnit, DummyTrainUnit, generate_random_dataloader, get_dummy_fit_state, @@ -35,7 +36,9 @@ ) from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions from torchtnt.framework.callbacks.lambda_callback import Lambda +from torchtnt.framework.evaluate import evaluate from torchtnt.framework.fit import fit +from torchtnt.framework.predict import predict from torchtnt.framework.state import ActivePhase, State from torchtnt.framework.train import train @@ -57,7 +60,9 @@ def __init__( *, save_every_n_train_steps: Optional[int] = None, save_every_n_epochs: Optional[int] = None, + save_every_n_eval_steps: Optional[int] = None, save_every_n_eval_epochs: Optional[int] = None, + save_every_n_predict_steps: Optional[int] = None, keep_last_n_checkpoints: Optional[int] = None, best_checkpoint_config: Optional[BestCheckpointConfig] = None, process_group: Optional[dist.ProcessGroup] = None, @@ -66,7 +71,9 @@ def __init__( dirpath, save_every_n_train_steps=save_every_n_train_steps, save_every_n_epochs=save_every_n_epochs, + save_every_n_eval_steps=save_every_n_eval_steps, save_every_n_eval_epochs=save_every_n_eval_epochs, + save_every_n_predict_steps=save_every_n_predict_steps, keep_last_n_checkpoints=keep_last_n_checkpoints, best_checkpoint_config=best_checkpoint_config, process_group=process_group, @@ -243,6 +250,83 @@ def test_save_fit_entrypoint(self) -> None: checkpointer._latest_checkpoint_path, ) + @patch.object(BaseCheckpointSaver, "_checkpoint_impl") + def test_save_eval_entrypoint(self, mock_checkpoint_impl: MagicMock) -> None: + my_unit = DummyFitUnit(input_dim=2) + with tempfile.TemporaryDirectory() as temp_dir: + checkpointer = BaseCheckpointSaver( + temp_dir, + save_every_n_eval_steps=2, + best_checkpoint_config=BestCheckpointConfig( + monitored_metric="val_loss", mode="min" + ), + keep_last_n_checkpoints=1, + ) + + ckpt_container: List[str] = [] + + def _checkpoint_impl_side_effect( + state: State, unit: AppStateMixin, checkpoint_id: str, hook: str + ) -> bool: + ckpt_container.append(checkpoint_id) + return True + + mock_checkpoint_impl.side_effect = _checkpoint_impl_side_effect + + eval_dataloader = generate_random_dataloader(10, 2, 1) + + warning_container: List[str] = [] + with patch( + "torchtnt.framework.callbacks.base_checkpointer.logging.Logger.warning", + side_effect=warning_container.append, + ): + evaluate(my_unit, eval_dataloader, callbacks=[checkpointer]) + + # Verify that checkpoint optimality tracking was disabled + self.assertIn( + "Disabling best_checkpoint_config, since it is not supported for eval or predict entrypoints.", + warning_container, + ) + self.assertIn( + "Disabling keep_last_n_checkpoints, since is not supported for eval or predict entrypoints.", + warning_container, + ) + + # Make sure that the correct checkpoints were saved, without tracked metrics + expected_ckpts = [ + f"{temp_dir}/epoch_0_eval_step_{i*2}" for i in range(1, 6) + ] + self.assertEqual(ckpt_container, expected_ckpts) + + @patch.object(BaseCheckpointSaver, "_checkpoint_impl") + def test_save_predict_entrypoint(self, mock_checkpoint_impl: MagicMock) -> None: + my_unit = DummyPredictUnit(input_dim=2) + with tempfile.TemporaryDirectory() as temp_dir: + checkpointer = BaseCheckpointSaver( + temp_dir, + save_every_n_predict_steps=1, + ) + + ckpt_container: List[str] = [] + + def _checkpoint_impl_side_effect( + state: State, unit: AppStateMixin, checkpoint_id: str, hook: str + ) -> bool: + ckpt_container.append(checkpoint_id) + return True + + mock_checkpoint_impl.side_effect = _checkpoint_impl_side_effect + + predict_dataloader = generate_random_dataloader(10, 2, 1) + + predict(my_unit, predict_dataloader, callbacks=[checkpointer]) + + # Make sure that the correct checkpoints were saved + expected_ckpts = [ + f"{temp_dir}/epoch_0_predict_step_{i}" for i in range(1, 11) + ] + self.assertEqual(ckpt_container, expected_ckpts) + @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) def test_restore_from_latest(self, mock_stdout: MagicMock) -> None: my_unit = DummyTrainUnit(input_dim=2) diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index 0b98737d65..a411e80163 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -15,10 +15,19 @@ import torch.distributed as dist from pyre_extensions import none_throws from torchtnt.framework.callback import Callback -from torchtnt.framework.callbacks._checkpoint_utils import _get_step_phase_mapping +from torchtnt.framework.callbacks._checkpoint_utils import ( + _get_epoch, + _get_step_phase_mapping, +) from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions -from torchtnt.framework.state import EntryPoint, State -from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit +from torchtnt.framework.state import ActivePhase, EntryPoint, State +from torchtnt.framework.unit import ( + AppStateMixin, + TEvalUnit, + TPredictUnit, + TTrainData, + TTrainUnit, +) from torchtnt.utils.checkpoint import ( BestCheckpointConfig, CheckpointManager, @@ -51,8 +60,11 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta): save_every_n_train_steps: Frequency of steps with which to save checkpoints during the train epoch. If None, no intra-epoch checkpoints are generated. save_every_n_epochs: Frequency of epochs with which to save checkpoints during training. If None, no end-of-epoch checkpoints are generated. save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit. - keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead. - best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint. + save_every_n_eval_steps: Frequency of evaluation steps with which to save checkpoints during training. Use this if wanting to save checkpoints during evaluate. + save_every_n_predict_steps: Frequency of prediction steps with which to save checkpoints during training. Use this if wanting to save checkpoints during using predict entrypoint. + keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted + to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead. Only supported for train or fit entrypoints. + best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint. This param is ignored if not in train or fit entrypoints. process_group: The process group on which the ranks will communicate on. If the process group is not gloo-based, a new gloo-based process group will be created. Note: @@ -78,6 +90,8 @@ def __init__( save_every_n_train_steps: Optional[int] = None, save_every_n_epochs: Optional[int] = None, save_every_n_eval_epochs: Optional[int] = None, + save_every_n_eval_steps: Optional[int] = None, + save_every_n_predict_steps: Optional[int] = None, keep_last_n_checkpoints: Optional[int] = None, best_checkpoint_config: Optional[BestCheckpointConfig] = None, process_group: Optional[dist.ProcessGroup] = None, @@ -90,12 +104,23 @@ def __init__( raise ValueError( f"Invalid value passed for save_every_n_epochs. Expected to receive either None or positive number, but received {save_every_n_epochs}" ) + if save_every_n_eval_steps is not None and save_every_n_eval_steps <= 0: + raise ValueError( + f"Invalid value passed for save_every_n_eval_steps. Expected to receive either None or positive number, but received {save_every_n_eval_steps}" + ) + if save_every_n_eval_epochs is not None and save_every_n_eval_epochs <= 0: + raise ValueError( + f"Invalid value passed for save_every_n_eval_epochs. Expected to receive either None or positive number, but received {save_every_n_eval_epochs}" + ) + if save_every_n_predict_steps is not None and save_every_n_predict_steps <= 0: + raise ValueError( + f"Invalid value passed for save_every_n_predict_steps. Expected to receive either None or positive number, but received {save_every_n_predict_steps}" + ) if keep_last_n_checkpoints is not None and keep_last_n_checkpoints <= 0: raise ValueError( f"Invalid value passed for keep_last_n_checkpoints. Expected to receive either None or positive number, but received {keep_last_n_checkpoints}" ) - self._best_checkpoint_config = best_checkpoint_config if best_checkpoint_config and best_checkpoint_config.mode not in {"min", "max"}: raise ValueError( f"Invalid value passed for best_checkpoint_config.mode. Expected to receive 'min' or 'max', but received {best_checkpoint_config.mode}" @@ -104,7 +129,10 @@ def __init__( self._save_every_n_train_steps = save_every_n_train_steps self._save_every_n_epochs = save_every_n_epochs self._save_every_n_eval_epochs = save_every_n_eval_epochs + self._save_every_n_eval_steps = save_every_n_eval_steps + self._save_every_n_predict_steps = save_every_n_predict_steps self._keep_last_n_checkpoints = keep_last_n_checkpoints + self._best_checkpoint_config = best_checkpoint_config self._process_group: Optional[dist.ProcessGroup] = None self._setup_gloo_pg(process_group) @@ -147,7 +175,7 @@ def dirpath(self) -> str: return self._checkpoint_manager.dirpath def _generate_checkpoint_and_upkeep( - self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str + self, state: State, unit: Union[TTrainUnit, TEvalUnit, TPredictUnit], hook: str ) -> bool: """ Implementation for saving checkpoint while taking care of checkpoint @@ -162,11 +190,16 @@ def _generate_checkpoint_and_upkeep( True if checkpoint was successfully saved. False otherwise. """ # 1) generate checkpoint name - epoch = cast(TTrainUnit, unit).train_progress.num_epochs_completed + epoch = _get_epoch(state, unit) step_mapping = _get_step_phase_mapping(state, unit) + # 1.1) append metric data only for train checkpoints, if best_checkpoint_config is defined metric_data: Optional[MetricData] = None - if metric_value := self._get_tracked_metric_value(unit): + if ( + self._best_checkpoint_config + and state.active_phase == ActivePhase.TRAIN + and (metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit))) + ): metric_data = MetricData( name=none_throws(self._best_checkpoint_config).monitored_metric, value=metric_value, @@ -179,7 +212,8 @@ def _generate_checkpoint_and_upkeep( process_group=self._process_group, ) - # 2) Determine if we should save checkpoint + # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints + # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported. if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path): return False @@ -222,9 +256,7 @@ def _generate_checkpoint_and_upkeep( return True - def _get_tracked_metric_value( - self, unit: Union[TTrainUnit, TEvalUnit] - ) -> Optional[float]: + def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]: """ If the checkpointer has a tracked metric, look the value in the unit using reflection, and cast to float. @@ -271,10 +303,9 @@ def on_train_start(self, state: State, unit: TTrainUnit) -> None: def on_train_step_end(self, state: State, unit: TTrainUnit) -> None: num_steps_completed = unit.train_progress.num_steps_completed - save_every_n_train_steps = self._save_every_n_train_steps if ( - save_every_n_train_steps is None - or num_steps_completed % save_every_n_train_steps != 0 + not self._save_every_n_train_steps + or num_steps_completed % self._save_every_n_train_steps != 0 ): return @@ -282,22 +313,70 @@ def on_train_step_end(self, state: State, unit: TTrainUnit) -> None: def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None: epoch = unit.train_progress.num_epochs_completed - save_every_n_epochs = self._save_every_n_epochs - if save_every_n_epochs is None or epoch % save_every_n_epochs != 0: + if not self._save_every_n_epochs or epoch % self._save_every_n_epochs != 0: return self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_epoch_end") + def on_train_end(self, state: State, unit: TTrainUnit) -> None: + self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_end") + + def on_eval_start(self, state: State, unit: TEvalUnit) -> None: + if state.entry_point == EntryPoint.EVALUATE: + self._disable_ckpt_optimality_tracking() + + def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None: + num_steps_completed = unit.eval_progress.num_steps_completed + if ( + not self._save_every_n_eval_steps + or num_steps_completed % self._save_every_n_eval_steps != 0 + ): + return + + self._generate_checkpoint_and_upkeep(state, unit, hook="on_eval_step_end") + def on_eval_epoch_end(self, state: State, unit: TEvalUnit) -> None: epoch = unit.eval_progress.num_epochs_completed - save_every_n_eval_epochs = self._save_every_n_eval_epochs - if save_every_n_eval_epochs is None or epoch % save_every_n_eval_epochs != 0: + if ( + not self._save_every_n_eval_epochs + or epoch % self._save_every_n_eval_epochs != 0 + ): return self._generate_checkpoint_and_upkeep(state, unit, hook="on_eval_epoch_end") - def on_train_end(self, state: State, unit: TTrainUnit) -> None: - self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_end") + def on_predict_start(self, state: State, unit: TPredictUnit) -> None: + self._disable_ckpt_optimality_tracking() + + def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None: + num_steps_completed = unit.predict_progress.num_steps_completed + if ( + not self._save_every_n_predict_steps + or num_steps_completed % self._save_every_n_predict_steps != 0 + ): + return + + self._generate_checkpoint_and_upkeep(state, unit, hook="on_predict_step_end") + + def _disable_ckpt_optimality_tracking(self) -> None: + """ + Disables checkpoint optimality tracking. This means that best_checkpoint and keep_last_n_checkpoints + will not be used. This is useful for eval and predict entrypoints, since checkpoints do not include + model parameters. + """ + if self._best_checkpoint_config: + logger.warning( + "Disabling best_checkpoint_config, since it is not supported for eval or predict entrypoints." + ) + self._best_checkpoint_config = None + self._checkpoint_manager._best_checkpoint_config = None + + if self._keep_last_n_checkpoints: + logger.warning( + "Disabling keep_last_n_checkpoints, since is not supported for eval or predict entrypoints." + ) + self._keep_last_n_checkpoints = None + self._checkpoint_manager._keep_last_n_checkpoints = None @abc.abstractmethod def _checkpoint_impl( diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py index 14e99a1677..7cebc7f5a2 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -318,7 +318,7 @@ def restore_with_id( ) def _generate_checkpoint_and_upkeep( - self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str + self, state: State, unit: Union[TTrainUnit, TEvalUnit, TPredictUnit], hook: str ) -> bool: # if we are still checkpointing, this might cause a collective hang, since several # operations in the base class use the process group. So wait here instead. diff --git a/torchtnt/framework/unit.py b/torchtnt/framework/unit.py index 8a1d7ff1e4..7e470f100d 100644 --- a/torchtnt/framework/unit.py +++ b/torchtnt/framework/unit.py @@ -586,6 +586,16 @@ def on_predict_epoch_end(self, state: State) -> None: """ pass + def on_checkpoint_save(self, state: State, checkpoint_id: str) -> None: + """Hook called after successfully saving a checkpoint. + + Args: + state: a :class:`~torchtnt.framework.state.State` object containing metadata about the training run. + checkpoint_id: the ID of the checkpoint that was saved. Depending on the storage type, this may be + a path, a URL or a unique identifier. + """ + pass + def on_predict_end(self, state: State) -> None: """Hook called after prediction ends.