remove init_optim_state in dcp checkpointer
Differential Revision: D59661542
JKSenthil authored and facebook-github-bot committed Sep 16, 2024
1 parent 57a4279 commit 9db75de
Showing 1 changed file with 0 additions and 11 deletions.
torchtnt/framework/callbacks/dcp_saver.py (0 additions, 11 deletions)
@@ -12,7 +12,6 @@
 from datetime import timedelta
 from typing import Any, Dict, Iterable, List, Optional, Union
 
-import torch
 import torch.distributed as dist
 from pyre_extensions import none_throws
 from torch.distributed import checkpoint as dcp
@@ -46,7 +45,6 @@
 )
 from torchtnt.framework.utils import get_timing_context
 from torchtnt.utils.checkpoint import BestCheckpointConfig
-from torchtnt.utils.optimizer import init_optim_state
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 from torchtnt.utils.stateful import MultiStateful, Stateful
 
@@ -323,15 +321,6 @@ def restore_with_id(
                 "train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot"
             )
 
-        # necessary for loading optimizers since states are initialized lazily
-        for obj in app_state.values():
-            # sometimes optimizers are actually held in a wrapper which handles calling
-            # state_dict and load_state_dict, as is the case for
-            # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`; this handles that case.
-            optimizer = getattr(obj, "optimizer", obj)
-            if isinstance(optimizer, torch.optim.Optimizer):
-                init_optim_state(optimizer)
-
         dcp.load(
             {"app_state": MultiStateful(app_state)},
             checkpoint_id=checkpoint_id,
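Context: the deleted block existed because PyTorch optimizers allocate their per-parameter state lazily, on the first step() call, so a freshly constructed optimizer exposes an empty state dict and had to have its state materialized before dcp.load could restore into it. Below is a minimal sketch of what such an initializer might look like; it is a hypothetical reconstruction for illustration, not the actual torchtnt.utils.optimizer.init_optim_state implementation.

import torch

def init_optim_state(optimizer: torch.optim.Optimizer) -> None:
    """Force allocation of lazily created optimizer state (hypothetical sketch)."""
    if optimizer.state:
        return  # per-parameter state already exists; nothing to do
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is None:
                # give step() all-zero gradients to consume (note: optimizers
                # with weight decay may still nudge parameters slightly)
                param.grad = torch.zeros_like(param)
    # a single step forces the optimizer to allocate its state tensors
    # (e.g. Adam's exp_avg / exp_avg_sq moments)
    optimizer.step()
    optimizer.zero_grad()

As the removed comment notes, optimizers are sometimes held inside wrappers such as torchtnt.utils.prepare_module.FSDPOptimizerWrapper, which is why the restore path first unwrapped each object with getattr(obj, "optimizer", obj) and only called the initializer on genuine torch.optim.Optimizer instances.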
