diff --git a/torchtnt/framework/callbacks/__init__.py b/torchtnt/framework/callbacks/__init__.py index 35566d1fcd..44d24214cf 100644 --- a/torchtnt/framework/callbacks/__init__.py +++ b/torchtnt/framework/callbacks/__init__.py @@ -7,6 +7,7 @@ # pyre-strict from .base_csv_writer import BaseCSVWriter +from .dcp_saver import DistributedCheckpointSaver from .early_stopping import EarlyStopping from .empty_cuda_cache import EmptyCudaCache from .garbage_collector import GarbageCollector @@ -44,4 +45,5 @@ "TorchSnapshotSaver", "TQDMProgressBar", "TrainProgressMonitor", + "DistributedCheckpointSaver", ]