From f03fd9b276bfba1ec87a640c01172e41d816fe07 Mon Sep 17 00:00:00 2001
From: Saurabh Mishra
Date: Tue, 24 Sep 2024 08:19:56 -0700
Subject: [PATCH] Update Checkpoint docs with DCP based checkpointer (#904)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/904

Update Checkpoint docs with DCP based checkpointer

Reviewed By: JKSenthil

Differential Revision: D63278746

fbshipit-source-id: 0c01e21fb516996001d3c12fe74504f65f5ed783
---
 docs/source/checkpointing.rst | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/docs/source/checkpointing.rst b/docs/source/checkpointing.rst
index 9596c0d776..466b4b262e 100644
--- a/docs/source/checkpointing.rst
+++ b/docs/source/checkpointing.rst
@@ -1,24 +1,24 @@
 Checkpointing
 ================================
 
-TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver` which uses `TorchSnapshot `_ under the hood.
+TorchTNT offers checkpointing via :class:`~torchtnt.framework.callbacks.DistributedCheckpointSaver` which uses `DCP `_ under the hood.
 
 .. code-block:: python
 
     module = nn.Linear(input_dim, 1)
     unit = MyUnit(module=module)
-    tss = TorchSnapshotSaver(
+    dcp = DistributedCheckpointSaver(
         dirpath=your_dirpath_here,
         save_every_n_train_steps=100,
         save_every_n_epochs=2,
     )
     # loads latest checkpoint, if it exists
     if latest_checkpoint_dir:
-        tss.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
+        dcp.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
     train(
         unit,
         dataloader,
-        callbacks=[tss]
+        callbacks=[dcp]
     )
 
 There is built-in support for saving and loading distributed models (DDP, FSDP).
@@ -37,7 +37,7 @@ The state dict type to be used for checkpointing FSDP modules can be specified i
     )
     module = prepare_fsdp(module, strategy=fsdp_strategy)
     unit = MyUnit(module=module)
-    tss = TorchSnapshotSaver(
+    dcp = DistributedCheckpointSaver(
         dirpath=your_dirpath_here,
         save_every_n_epochs=2,
     )
@@ -45,7 +45,7 @@ The state dict type to be used for checkpointing FSDP modules can be specified i
         unit,
         dataloader,
         # checkpointer callback will use state dict type specified in FSDPStrategy
-        callbacks=[tss]
+        callbacks=[dcp]
     )
 
 Or you can manually set this using `FSDP.set_state_dict_type `_.
@@ -56,14 +56,14 @@ Or you can manually set this using `FSDP.set_state_dict_type
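
For reference, below is a minimal end-to-end sketch of the DistributedCheckpointSaver usage that this patch documents. It is an illustrative addition, not part of the patch: MyUnit, the dataset, the checkpoint directory, and the hyperparameters are placeholders, and the TrainUnit/train import paths and signatures are assumed from the TorchTNT package layout; only the DistributedCheckpointSaver arguments and the restore_from_latest call are taken directly from the diff above.

.. code-block:: python

    # Illustrative sketch (not part of the patch). Assumes TorchTNT's
    # TrainUnit/train APIs; saver arguments mirror the documented change.
    import os
    import tempfile
    from typing import Tuple

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    from torchtnt.framework import train
    from torchtnt.framework.callbacks import DistributedCheckpointSaver
    from torchtnt.framework.state import State
    from torchtnt.framework.unit import TrainUnit

    Batch = Tuple[torch.Tensor, torch.Tensor]


    class MyUnit(TrainUnit[Batch]):
        """Stand-in for the MyUnit placeholder used in the docs."""

        def __init__(self, module: nn.Module) -> None:
            super().__init__()
            # attributes of nn.Module / Optimizer type are tracked as app state
            self.module = module
            self.optimizer = torch.optim.SGD(module.parameters(), lr=0.01)

        def train_step(self, state: State, data: Batch) -> None:
            inputs, targets = data
            loss = nn.functional.mse_loss(self.module(inputs), targets)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()


    input_dim = 8
    module = nn.Linear(input_dim, 1)
    unit = MyUnit(module=module)

    dirpath = os.path.join(tempfile.gettempdir(), "tnt_checkpoints")
    dataloader = DataLoader(
        TensorDataset(torch.randn(64, input_dim), torch.randn(64, 1)), batch_size=8
    )

    dcp = DistributedCheckpointSaver(
        dirpath=dirpath,
        save_every_n_train_steps=100,
        save_every_n_epochs=2,
    )

    # Resume from the most recent checkpoint in dirpath, if one exists
    # (assumed to be a no-op when the directory holds no checkpoints).
    dcp.restore_from_latest(dirpath, unit, train_dataloader=dataloader)

    train(unit, dataloader, max_epochs=2, callbacks=[dcp])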