diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py
index f091e8de94..b1b4a232a5 100644
--- a/torchtnt/framework/callbacks/dcp_saver.py
+++ b/torchtnt/framework/callbacks/dcp_saver.py
@@ -12,7 +12,6 @@
 from datetime import timedelta
 from typing import Any, Dict, Iterable, List, Optional, Union
 
-import torch
 import torch.distributed as dist
 from pyre_extensions import none_throws
 from torch.distributed import checkpoint as dcp
@@ -46,7 +45,6 @@
 )
 from torchtnt.framework.utils import get_timing_context
 from torchtnt.utils.checkpoint import BestCheckpointConfig
-from torchtnt.utils.optimizer import init_optim_state
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 from torchtnt.utils.stateful import MultiStateful, Stateful
 
@@ -323,15 +321,6 @@ def restore_with_id(
                 "train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot"
             )
 
-        # necessary for loading optimizers since states are initialized lazy
-        for obj in app_state.values():
-            # sometimes optimizers are actually held in a wrapper which handles calling
-            # state_dict and load_state_dict, sa is the case for
-            # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
-            optimizer = getattr(obj, "optimizer", obj)
-            if isinstance(optimizer, torch.optim.Optimizer):
-                init_optim_state(optimizer)
-
         dcp.load(
             {"app_state": MultiStateful(app_state)},
             checkpoint_id=checkpoint_id,