diff --git a/docs/source/checkpointing.rst b/docs/source/checkpointing.rst
index 6f7c9c0ef8..9596c0d776 100644
--- a/docs/source/checkpointing.rst
+++ b/docs/source/checkpointing.rst
@@ -119,6 +119,7 @@ By specifying the monitored metric to be "train_loss", the checkpointer will exp
 Later on, the best checkpoint can be loaded via
 
 .. code-block:: python
+
     TorchSnapshotSaver.restore_from_best(your_dirpath_here, unit, metric_name="train_loss", mode="min")
 
 If you'd like to monitor a validation metric (say validation loss after each eval epoch during :py:func:`~torchtnt.framework.fit.fit`), you can use the `save_every_n_eval_epochs` flag instead, like so
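The hunk above only inserts the blank line that reStructuredText needs between the ``.. code-block:: python`` directive and its content, so that the `restore_from_best` call renders as code. For context, here is a minimal sketch of how saving and then restoring a best checkpoint could fit together end to end. This is an illustration, not part of the patch: the `BestCheckpointConfig` class and `best_checkpoint_config` argument are assumptions about how the monitored metric is configured, and `unit`, `train_dataloader`, `eval_dataloader`, and `your_dirpath_here` are placeholders you would supply.

.. code-block:: python

    from torchtnt.framework.callbacks import BestCheckpointConfig, TorchSnapshotSaver
    from torchtnt.framework.fit import fit

    # Save a snapshot after every eval epoch during fit(), tracking an
    # (assumed) "val_loss" attribute recorded on the unit; mode="min"
    # means a lower loss counts as a better checkpoint.
    saver = TorchSnapshotSaver(
        your_dirpath_here,
        save_every_n_eval_epochs=1,
        best_checkpoint_config=BestCheckpointConfig(  # assumed config API
            monitored_metric="val_loss",
            mode="min",
        ),
    )
    fit(unit, train_dataloader, eval_dataloader, callbacks=[saver])

    # Later, reload the best checkpoint by metric value, as shown in the doc.
    TorchSnapshotSaver.restore_from_best(
        your_dirpath_here, unit, metric_name="val_loss", mode="min"
    )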