Update Checkpoint docs with DCP based checkpointer #904

Closed · wants to merge 1 commit
docs/source/checkpointing.rst: 30 changes (15 additions, 15 deletions)
@@ -1,24 +1,24 @@
Checkpointing
================================

-TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver` which uses `TorchSnapshot <https://pytorch.org/torchsnapshot/main/>`_ under the hood.
+TorchTNT offers checkpointing via :class:`~torchtnt.framework.callbacks.DistributedCheckpointSaver` which uses `DCP <https://github.com/pytorch/pytorch/tree/main/torch/distributed/checkpoint>`_ under the hood.

.. code-block:: python

module = nn.Linear(input_dim, 1)
unit = MyUnit(module=module)
-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_train_steps=100,
save_every_n_epochs=2,
)
# loads latest checkpoint, if it exists
if latest_checkpoint_dir:
-tss.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
+dcp.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
train(
unit,
dataloader,
-callbacks=[tss]
+callbacks=[dcp]
)

There is built-in support for saving and loading distributed models (DDP, FSDP).
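A minimal sketch of the DDP case, reusing ``MyUnit``, ``input_dim``, ``device``, ``your_dirpath_here``, and ``dataloader`` from the examples above, and assuming the process group has already been initialized (``DistributedDataParallel`` here is plain ``torch.nn.parallel``, not a TorchTNT helper):

.. code-block:: python

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    from torchtnt.framework import train
    from torchtnt.framework.callbacks import DistributedCheckpointSaver

    # wrap the module in DDP before handing it to the unit; the
    # checkpointer saves and restores the distributed state dicts
    module = DDP(nn.Linear(input_dim, 1).to(device))
    unit = MyUnit(module=module)
    dcp = DistributedCheckpointSaver(
        dirpath=your_dirpath_here,
        save_every_n_epochs=2,
    )
    train(unit, dataloader, callbacks=[dcp])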
@@ -37,15 +37,15 @@ The state dict type to be used for checkpointing FSDP modules can be specified i
)
module = prepare_fsdp(module, strategy=fsdp_strategy)
unit = MyUnit(module=module)
-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_epochs=2,
)
train(
unit,
dataloader,
# checkpointer callback will use state dict type specified in FSDPStrategy
-callbacks=[tss]
+callbacks=[dcp]
)

Or you can manually set this using `FSDP.set_state_dict_type <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type>`_.
@@ -56,14 +56,14 @@ Or you can manually set this using `FSDP.set_state_dict_type <https://pytorch.or
module = FSDP(module, ....)
FSDP.set_state_dict_type(module, StateDictType.SHARDED_STATE_DICT)
unit = MyUnit(module=module, ...)
-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_epochs=2,
)
train(
unit,
dataloader,
-callbacks=[tss]
+callbacks=[dcp]
)


@@ -74,15 +74,15 @@ When finetuning your models, you can pass RestoreOptions to avoid loading optimi

.. code-block:: python

-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_train_steps=100,
save_every_n_epochs=2,
)

# loads latest checkpoint, if it exists
if latest_checkpoint_dir:
-tss.restore_from_latest(
+dcp.restore_from_latest(
your_dirpath_here,
your_unit,
train_dataloader=dataloader,
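The collapsed lines above cut this call off before the ``restore_options`` argument; a hedged completion, assuming ``RestoreOptions`` exposes ``restore_optimizers`` and ``restore_lr_schedulers`` flags (the field names are an assumption, not confirmed by this diff):

.. code-block:: python

    from torchtnt.framework.callbacks import RestoreOptions

    dcp.restore_from_latest(
        your_dirpath_here,
        your_unit,
        train_dataloader=dataloader,
        # assumed flags: skip optimizer / LR scheduler state while finetuning
        restore_options=RestoreOptions(
            restore_optimizers=False,
            restore_lr_schedulers=False,
        ),
    )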
@@ -99,7 +99,7 @@ Sometimes it may be helpful to keep track of how models perform. This can be don

module = nn.Linear(input_dim, 1)
unit = MyUnit(module=module)
-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_epochs=1,
best_checkpoint_config=BestCheckpointConfig(
@@ -111,7 +111,7 @@ Sometimes it may be helpful to keep track of how models perform. This can be don
train(
unit,
dataloader,
-callbacks=[tss]
+callbacks=[dcp]
)

By specifying the monitored metric to be "train_loss", the checkpointer will expect the :class:`~torchtnt.framework.unit.TrainUnit` to have a "train_loss" attribute at the time of checkpointing, and it will cast this value to a float and append the value to the checkpoint path name. This attribute is expected to be computed and kept up to date appropriately in the unit by the user.
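A minimal sketch of a unit that satisfies this contract; the loss computation is illustrative, and the only part the checkpointer relies on is the ``train_loss`` attribute being kept current:

.. code-block:: python

    from typing import Tuple

    import torch
    from torchtnt.framework.state import State
    from torchtnt.framework.unit import TrainUnit

    Batch = Tuple[torch.Tensor, torch.Tensor]

    class MyUnit(TrainUnit[Batch]):
        def __init__(self, module: torch.nn.Module) -> None:
            super().__init__()
            self.module = module
            self.optimizer = torch.optim.SGD(module.parameters(), lr=1e-2)
            self.train_loss = 0.0  # read by the checkpointer at save time

        def train_step(self, state: State, data: Batch) -> None:
            inputs, targets = data
            loss = torch.nn.functional.mse_loss(self.module(inputs), targets)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            # keep the monitored attribute up to date
            self.train_loss = loss.item()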
@@ -120,13 +120,13 @@ Later on, the best checkpoint can be loaded via

.. code-block:: python

-TorchSnapshotSaver.restore_from_best(your_dirpath_here, unit, metric_name="train_loss", mode="min")
+DistributedCheckpointSaver.restore_from_best(your_dirpath_here, unit, metric_name="train_loss", mode="min")

If you'd like to monitor a validation metric (say validation loss after each eval epoch during :py:func:`~torchtnt.framework.fit.fit`), you can use the `save_every_n_eval_epochs` flag instead, like so

.. code-block:: python

-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_eval_epochs=1,
best_checkpoint_config=BestCheckpointConfig(
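The rest of this block is collapsed in the diff; completed along the lines of the ``train_loss`` example above, it would plausibly read as follows (``eval_loss`` is a stand-in for whatever attribute the unit maintains, and the ``monitored_metric`` / ``mode`` field names are assumed from the earlier usage):

.. code-block:: python

    dcp = DistributedCheckpointSaver(
        dirpath=your_dirpath_here,
        save_every_n_eval_epochs=1,
        best_checkpoint_config=BestCheckpointConfig(
            monitored_metric="eval_loss",  # hypothetical attribute on the unit
            mode="min",
        ),
    )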
@@ -139,7 +139,7 @@ And to save only the top three performing models, you can use the existing `keep

.. code-block:: python

-tss = TorchSnapshotSaver(
+dcp = DistributedCheckpointSaver(
dirpath=your_dirpath_here,
save_every_n_eval_epochs=1,
keep_last_n_checkpoints=3,
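The tail of this block is likewise collapsed; a hedged completion. Per the sentence above, with ``best_checkpoint_config`` set, ``keep_last_n_checkpoints=3`` retains the top three checkpoints by the monitored metric rather than the three most recent:

.. code-block:: python

    dcp = DistributedCheckpointSaver(
        dirpath=your_dirpath_here,
        save_every_n_eval_epochs=1,
        keep_last_n_checkpoints=3,
        best_checkpoint_config=BestCheckpointConfig(
            monitored_metric="eval_loss",  # hypothetical attribute on the unit
            mode="min",
        ),
    )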