diff --git a/docs/source/checkpointing.rst b/docs/source/checkpointing.rst index 9596c0d776..466b4b262e 100644 --- a/docs/source/checkpointing.rst +++ b/docs/source/checkpointing.rst @@ -1,24 +1,24 @@ Checkpointing ================================ -TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver` which uses `TorchSnapshot `_ under the hood. +TorchTNT offers checkpointing via :class:`~torchtnt.framework.callbacks.DistributedCheckpointSaver` which uses `DCP `_ under the hood. .. code-block:: python module = nn.Linear(input_dim, 1) unit = MyUnit(module=module) - tss = TorchSnapshotSaver( + dcp = DistributedCheckpointSaver( dirpath=your_dirpath_here, save_every_n_train_steps=100, save_every_n_epochs=2, ) # loads latest checkpoint, if it exists if latest_checkpoint_dir: - tss.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader) + dcp.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader) train( unit, dataloader, - callbacks=[tss] + callbacks=[dcp] ) There is built-in support for saving and loading distributed models (DDP, FSDP). @@ -37,7 +37,7 @@ The state dict type to be used for checkpointing FSDP modules can be specified i ) module = prepare_fsdp(module, strategy=fsdp_strategy) unit = MyUnit(module=module) - tss = TorchSnapshotSaver( + dcp = DistributedCheckpointSaver( dirpath=your_dirpath_here, save_every_n_epochs=2, ) @@ -45,7 +45,7 @@ The state dict type to be used for checkpointing FSDP modules can be specified i unit, dataloader, # checkpointer callback will use state dict type specified in FSDPStrategy - callbacks=[tss] + callbacks=[dcp] ) Or you can manually set this using `FSDP.set_state_dict_type `_. @@ -56,14 +56,14 @@ Or you can manually set this using `FSDP.set_state_dict_type