diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py index 249e405727..c84e0e8264 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -136,7 +136,12 @@ def _checkpoint_impl( planner: Optional[SavePlanner] = None, storage_writer: Optional[StorageWriter] = None, ) -> bool: - if hook not in ["on_train_step_end", "on_train_epoch_end", "on_train_end"]: + if hook not in [ + "on_train_step_end", + "on_train_epoch_end", + "on_train_end", + "on_eval_epoch_end", + ]: raise RuntimeError(f"Unexpected hook encountered '{hook}'") intra_epoch = hook == "on_train_step_end"