diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py
index 5072e636af..3984d6fb86 100644
--- a/torchtnt/framework/callbacks/dcp_saver.py
+++ b/torchtnt/framework/callbacks/dcp_saver.py
@@ -16,8 +16,6 @@
 from torch.distributed import checkpoint as dcp
 from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
 
-from torch.distributed.checkpoint.state_dict import _init_optim_state
-from torch.distributed.checkpoint.stateful import Stateful
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _prepare_app_state_for_checkpoint,
     _prepare_app_state_for_restore,
@@ -39,8 +37,9 @@
     TTrainUnit,
 )
 from torchtnt.framework.utils import get_timing_context
+from torchtnt.utils.optimizer import init_optim_state
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
-from torchtnt.utils.stateful import MultiStateful
+from torchtnt.utils.stateful import MultiStateful, Stateful
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -249,7 +248,7 @@ def restore(
             # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
             optimizer = getattr(obj, "optimizer", obj)
             if isinstance(optimizer, torch.optim.Optimizer):
-                _init_optim_state(optimizer)
+                init_optim_state(optimizer)
 
         dcp.load(
             {"app_state": MultiStateful(app_state)},
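
Reviewer note (illustrative, not part of the patch): _init_optim_state is a private
torch.distributed.checkpoint helper, and this change swaps it for torchtnt's own
init_optim_state from torchtnt.utils.optimizer. A minimal sketch of what such a helper
is assumed to do follows: materialize lazily created optimizer state by stepping once
with zero gradients, so that dcp.load has per-parameter state entries to restore into.
The function body below is a hypothetical approximation, not the torchtnt source.

    import torch

    def init_optim_state(optimizer: torch.optim.Optimizer) -> None:
        # Populate lazily created optimizer state so a DCP checkpoint can be loaded into it.
        if optimizer.state:
            # State already materialized (e.g. after a real training step); nothing to do.
            return
        for param_group in optimizer.param_groups:
            for param in param_group["params"]:
                if param.requires_grad and param.grad is None:
                    # Give every trainable parameter a zero gradient so step() will
                    # create its state entries (exp_avg, momentum buffers, ...).
                    param.grad = torch.zeros_like(param)
        optimizer.step()
        # Drop the dummy gradients; the allocated state is overwritten by the checkpoint load.
        optimizer.zero_grad(set_to_none=True)

Usage sketch: calling init_optim_state(optim) on a freshly constructed optimizer fills
optim.state, after which dcp.load({"app_state": MultiStateful(app_state)}, ...) can
restore the checkpointed optimizer state in place, as restore() does above.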