Skip to content

Commit

Permalink
replace _init_optim_state w/ tnt's util
Browse files Browse the repository at this point in the history
Differential Revision: D56446429

Split of "[tnt] test against stable pytorch version"
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Apr 25, 2024
1 parent 5e51dd5 commit 7a3a2f4
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions torchtnt/framework/callbacks/dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from torch.distributed import checkpoint as dcp

from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
from torch.distributed.checkpoint.state_dict import _init_optim_state
from torch.distributed.checkpoint.stateful import Stateful
from torchtnt.framework.callbacks._checkpoint_utils import (
_prepare_app_state_for_checkpoint,
_prepare_app_state_for_restore,
Expand All @@ -39,8 +37,9 @@
TTrainUnit,
)
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.optimizer import init_optim_state
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
from torchtnt.utils.stateful import MultiStateful
from torchtnt.utils.stateful import MultiStateful, Stateful


logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -249,7 +248,7 @@ def restore(
# `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
optimizer = getattr(obj, "optimizer", obj)
if isinstance(optimizer, torch.optim.Optimizer):
_init_optim_state(optimizer)
init_optim_state(optimizer)

dcp.load(
{"app_state": MultiStateful(app_state)},
Expand Down

0 comments on commit 7a3a2f4

Please sign in to comment.