diff --git a/docs/source/checkpointing.rst b/docs/source/checkpointing.rst
index 9596c0d776..466b4b262e 100644
--- a/docs/source/checkpointing.rst
+++ b/docs/source/checkpointing.rst
@@ -1,24 +1,24 @@
Checkpointing
================================
-TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver` which uses `TorchSnapshot `_ under the hood.
+TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.DistributedCheckpointSaver`, which uses `DCP <https://pytorch.org/docs/stable/distributed.checkpoint.html>`_ under the hood.
.. code-block:: python
module = nn.Linear(input_dim, 1)
unit = MyUnit(module=module)
- tss = TorchSnapshotSaver(
+ dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_train_steps=100,
save_every_n_epochs=2,
)
# loads latest checkpoint, if it exists
if latest_checkpoint_dir:
- tss.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
+ dcp.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
train(
unit,
dataloader,
- callbacks=[tss]
+ callbacks=[dcp]
)
There is built-in support for saving and loading distributed models (DDP, FSDP).
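As a minimal illustration of the DDP case (a sketch that is not part of the original docs; the device handling and saver arguments are assumptions reusing names from the example above), wrapping the module with ``DistributedDataParallel`` before constructing the unit is enough for the callback to save and restore it:

.. code-block:: python

   from torch.nn.parallel import DistributedDataParallel as DDP

   # assumes torch.distributed is already initialized and `device` is this rank's device
   module = DDP(nn.Linear(input_dim, 1).to(device))
   unit = MyUnit(module=module)
   dcp = DistributedCheckpointSaver(
       dirpath=your_dirpath_here,
       save_every_n_epochs=2,
   )
   train(unit, dataloader, callbacks=[dcp])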
@@ -37,7 +37,7 @@ The state dict type to be used for checkpointing FSDP modules can be specified i
)
module = prepare_fsdp(module, strategy=fsdp_strategy)
unit = MyUnit(module=module)
- tss = TorchSnapshotSaver(
+ dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_epochs=2,
)
@@ -45,7 +45,7 @@ The state dict type to be used for checkpointing FSDP modules can be specified i
unit,
dataloader,
# checkpointer callback will use state dict type specified in FSDPStrategy
- callbacks=[tss]
+ callbacks=[dcp]
)
Or you can manually set this using `FSDP.set_state_dict_type `_.
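For example (a minimal sketch, not taken from the original docs; the sharded state dict type shown is just one possible choice), the state dict type can be set directly on the FSDP-wrapped module:

.. code-block:: python

   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

   # assumes `module` has already been wrapped with FSDP
   FSDP.set_state_dict_type(module, StateDictType.SHARDED_STATE_DICT)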
@@ -56,14 +56,14 @@ Or you can manually set this using `FSDP.set_state_dict_type