Update Checkpoint docs with DCP based checkpointer (#904)
Summary:
Pull Request resolved: #904

Update Checkpoint docs with DCP based checkpointer

Reviewed By: JKSenthil

Differential Revision: D63278746

fbshipit-source-id: 0c01e21fb516996001d3c12fe74504f65f5ed783
saumishr authored and facebook-github-bot committed Sep 24, 2024
1 parent 843835c commit f03fd9b
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions docs/source/checkpointing.rst
@@ -1,24 +1,24 @@
Checkpointing
================================

-TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver` which uses `TorchSnapshot <https://pytorch.org/torchsnapshot/main/>`_ under the hood.
+TorchTNT offers checkpointing via :class:`~torchtnt.framework.callbacks.DistributedCheckpointSaver` which uses `DCP <https://github.com/pytorch/pytorch/tree/main/torch/distributed/checkpoint>`_ under the hood.

.. code-block:: python
module = nn.Linear(input_dim, 1)
unit = MyUnit(module=module)
-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_train_steps=100,
save_every_n_epochs=2,
)
# loads latest checkpoint, if it exists
if latest_checkpoint_dir:
-tss.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
+dcp.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
train(
unit,
dataloader,
-callbacks=[tss]
+callbacks=[dcp]
)
There is built-in support for saving and loading distributed models (DDP, FSDP).
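For example, a DDP-wrapped module needs no special handling; a minimal sketch reusing the placeholders from the example above (``MyUnit``, ``input_dim``, ``your_dirpath_here``, ``dataloader``), and assuming the process group is already initialized:

.. code-block:: python

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    from torchtnt.framework import train
    from torchtnt.framework.callbacks import DistributedCheckpointSaver

    # assumes torch.distributed is already initialized and the module is on the right device
    module = DDP(nn.Linear(input_dim, 1))
    unit = MyUnit(module=module)

    dcp = DistributedCheckpointSaver(
        dirpath=your_dirpath_here,
        save_every_n_epochs=2,
    )
    train(unit, dataloader, callbacks=[dcp])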
@@ -37,15 +37,15 @@ The state dict type to be used for checkpointing FSDP modules can be specified i
)
module = prepare_fsdp(module, strategy=fsdp_strategy)
unit = MyUnit(module=module)
-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_epochs=2,
)
train(
unit,
dataloader,
# checkpointer callback will use state dict type specified in FSDPStrategy
-callbacks=[tss]
+callbacks=[dcp]
)
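The construction of the ``FSDPStrategy`` itself is collapsed in this hunk; a minimal sketch of the idea, assuming the strategy exposes a ``state_dict_type`` field (the exact field name and accepted values may differ between releases):

.. code-block:: python

    from torch.distributed.fsdp import StateDictType
    from torchtnt.utils.prepare_module import FSDPStrategy, prepare_fsdp

    # assumption: the checkpoint state dict type is configured on the strategy
    fsdp_strategy = FSDPStrategy(state_dict_type=StateDictType.SHARDED_STATE_DICT)
    module = prepare_fsdp(module, strategy=fsdp_strategy)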
Or you can manually set this using `FSDP.set_state_dict_type <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type>`_.
@@ -56,14 +56,14 @@ Or you can manually set this using `FSDP.set_state_dict_type <https://pytorch.or
module = FSDP(module, ....)
FSDP.set_state_dict_type(module, StateDictType.SHARDED_STATE_DICT)
unit = MyUnit(module=module, ...)
-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_epochs=2,
)
train(
unit,
dataloader,
-callbacks=[tss]
+callbacks=[dcp]
)
@@ -74,15 +74,15 @@ When finetuning your models, you can pass RestoreOptions to avoid loading optimi

.. code-block:: python
-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_train_steps=100,
save_every_n_epochs=2,
)
# loads latest checkpoint, if it exists
if latest_checkpoint_dir:
-tss.restore_from_latest(
+dcp.restore_from_latest(
your_dirpath_here,
your_unit,
train_dataloader=dataloader,
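The remainder of this call is collapsed in the diff; a minimal sketch of the finetuning flow, assuming ``restore_from_latest`` accepts a ``restore_options`` keyword and that ``RestoreOptions`` exposes ``restore_optimizers``/``restore_lr_schedulers`` flags:

.. code-block:: python

    from torchtnt.framework.callbacks import RestoreOptions

    # assumption: skip optimizer and LR scheduler state when warm-starting for finetuning
    dcp.restore_from_latest(
        your_dirpath_here,
        your_unit,
        train_dataloader=dataloader,
        restore_options=RestoreOptions(
            restore_optimizers=False,
            restore_lr_schedulers=False,
        ),
    )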
@@ -99,7 +99,7 @@ Sometimes it may be helpful to keep track of how models perform. This can be don
module = nn.Linear(input_dim, 1)
unit = MyUnit(module=module)
-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_epochs=1,
best_checkpoint_config=BestCheckpointConfig(
@@ -111,7 +111,7 @@ Sometimes it may be helpful to keep track of how models perform. This can be don
train(
unit,
dataloader,
-callbacks=[tss]
+callbacks=[dcp]
)
By specifying the monitored metric to be "train_loss", the checkpointer will expect the :class:`~torchtnt.framework.unit.TrainUnit` to have a "train_loss" attribute at the time of checkpointing, and it will cast this value to a float and append the value to the checkpoint path name. This attribute is expected to be computed and kept up to date appropriately in the unit by the user.
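A minimal sketch of a unit that keeps such an attribute up to date (the loss bookkeeping here is illustrative, not prescribed by the library):

.. code-block:: python

    from typing import Tuple

    import torch
    from torchtnt.framework.state import State
    from torchtnt.framework.unit import TrainUnit

    Batch = Tuple[torch.Tensor, torch.Tensor]

    class MyUnit(TrainUnit[Batch]):
        def __init__(self, module: torch.nn.Module) -> None:
            super().__init__()
            self.module = module
            self.optimizer = torch.optim.SGD(module.parameters(), lr=0.01)
            # read by the checkpointer at save time and appended to the checkpoint name
            self.train_loss = float("inf")

        def train_step(self, state: State, data: Batch) -> None:
            inputs, targets = data
            loss = torch.nn.functional.mse_loss(self.module(inputs), targets)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            # keep the monitored attribute current; it is cast to float when checkpointing
            self.train_loss = loss.item()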
@@ -120,13 +120,13 @@ Later on, the best checkpoint can be loaded via

.. code-block:: python
-TorchSnapshotSaver.restore_from_best(your_dirpath_here, unit, metric_name="train_loss", mode="min")
+DistributedCheckpointSaver.restore_from_best(your_dirpath_here, unit, metric_name="train_loss", mode="min")
If you'd like to monitor a validation metric (say validation loss after each eval epoch during :py:func:`~torchtnt.framework.fit.fit`), you can use the `save_every_n_eval_epochs` flag instead, like so

.. code-block:: python
-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_eval_epochs=1,
best_checkpoint_config=BestCheckpointConfig(
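The monitored-metric arguments are collapsed here; a minimal sketch of how the configuration might look, assuming ``BestCheckpointConfig`` takes ``monitored_metric`` and ``mode`` arguments and that the unit maintains an ``eval_loss`` attribute (the attribute name is illustrative):

.. code-block:: python

    from torchtnt.framework.callbacks import BestCheckpointConfig, DistributedCheckpointSaver

    dcp = DistributedCheckpointSaver(
        dirpath=your_dirpath_here,
        save_every_n_eval_epochs=1,
        best_checkpoint_config=BestCheckpointConfig(
            monitored_metric="eval_loss",  # attribute expected on the unit after each eval epoch
            mode="min",
        ),
    )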
@@ -139,7 +139,7 @@ And to save only the top three performing models, you can use the existing `keep

.. code-block:: python
-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_eval_epochs=1,
keep_last_n_checkpoints=3,
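The closing arguments are collapsed here as well; continuing the previous sketch, keeping only the top three checkpoints by the monitored metric might look like:

.. code-block:: python

    dcp = DistributedCheckpointSaver(
        dirpath=your_dirpath_here,
        save_every_n_eval_epochs=1,
        keep_last_n_checkpoints=3,  # retain only the three best checkpoints
        best_checkpoint_config=BestCheckpointConfig(
            monitored_metric="eval_loss",  # illustrative attribute name, as above
            mode="min",
        ),
    )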
