Skip to content

Commit

Permalink
Enable DCP checkpoints with on_eval_epoch_end hook (pytorch#851)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#851

Reviewed By: zedsdead01, JKSenthil

Differential Revision: D58967392

fbshipit-source-id: 0ecdd0bec0804167127daa2530bf77b63bf7ca6e
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Jun 25, 2024
1 parent 8afe26d commit 1d118c4
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torchtnt/framework/callbacks/dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,12 @@ def _checkpoint_impl(
planner: Optional[SavePlanner] = None,
storage_writer: Optional[StorageWriter] = None,
) -> bool:
if hook not in ["on_train_step_end", "on_train_epoch_end", "on_train_end"]:
if hook not in [
"on_train_step_end",
"on_train_epoch_end",
"on_train_end",
"on_eval_epoch_end",
]:
raise RuntimeError(f"Unexpected hook encountered '{hook}'")

intra_epoch = hook == "on_train_step_end"
Expand Down

0 comments on commit 1d118c4

Please sign in to comment.