Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Back out "Move BestCheckpointConfig to utils/checkpoint.py" #817

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tests/framework/callbacks/test_base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@
from torchtnt.framework.callbacks.base_checkpointer import (
BaseCheckpointer as BaseCheckpointer,
)
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
from torchtnt.framework.callbacks.checkpointer_types import (
BestCheckpointConfig,
RestoreOptions,
)
from torchtnt.framework.callbacks.lambda_callback import Lambda
from torchtnt.framework.fit import fit
from torchtnt.framework.state import State

from torchtnt.framework.train import train
from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData
from torchtnt.utils.checkpoint import BestCheckpointConfig
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import skip_if_not_distributed
Expand Down
6 changes: 4 additions & 2 deletions torchtnt/framework/callbacks/base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import torch.distributed as dist
from pyre_extensions import none_throws
from torchtnt.framework.callback import Callback
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
from torchtnt.framework.callbacks.checkpointer_types import (
BestCheckpointConfig,
RestoreOptions,
)
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
from torchtnt.framework.utils import get_timing_context
Expand All @@ -25,7 +28,6 @@
_metadata_exists,
_sort_by_metric_value,
_sort_by_recency,
BestCheckpointConfig,
get_best_checkpoint_path,
get_checkpoint_dirpaths,
get_latest_checkpoint_path,
Expand Down
16 changes: 15 additions & 1 deletion torchtnt/framework/callbacks/checkpointer_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pyre-strict

from dataclasses import dataclass
from typing import Optional
from typing import Literal, Optional


# TODO: eventually support overriding all knobs
Expand Down Expand Up @@ -39,3 +39,17 @@ class RestoreOptions:
restore_eval_progress: bool = True
restore_optimizers: bool = True
restore_lr_schedulers: bool = True


@dataclass
class BestCheckpointConfig:
    """
    Config for saving the best checkpoints.

    Args:
        monitored_metric: Metric to monitor for saving best checkpoints. Must be a numerical or tensor attribute on the unit.
        mode: One of `min` or `max`. The save file is overwritten based on the max or min of the monitored metric.
    """

    # Name of the unit attribute tracked when ranking checkpoints.
    monitored_metric: str
    # Whether a smaller ("min") or larger ("max") metric value is considered better.
    mode: Literal["min", "max"] = "min"
7 changes: 5 additions & 2 deletions torchtnt/framework/callbacks/dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
)

from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
from torchtnt.framework.callbacks.checkpointer_types import (
BestCheckpointConfig,
KnobOptions,
RestoreOptions,
)
from torchtnt.framework.state import State
from torchtnt.framework.unit import (
AppStateMixin,
Expand All @@ -32,7 +36,6 @@
TTrainUnit,
)
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.checkpoint import BestCheckpointConfig
from torchtnt.utils.optimizer import init_optim_state
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
from torchtnt.utils.stateful import MultiStateful, Stateful
Expand Down
7 changes: 5 additions & 2 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
)

from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
from torchtnt.framework.callbacks.checkpointer_types import (
BestCheckpointConfig,
KnobOptions,
RestoreOptions,
)
from torchtnt.framework.state import State
from torchtnt.framework.unit import (
AppStateMixin,
Expand All @@ -32,7 +36,6 @@
TTrainUnit,
)
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.checkpoint import BestCheckpointConfig
from torchtnt.utils.optimizer import init_optim_state
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
from torchtnt.utils.stateful import Stateful
Expand Down
2 changes: 0 additions & 2 deletions torchtnt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# pyre-strict

from .checkpoint import (
BestCheckpointConfig,
CheckpointPath,
get_best_checkpoint_path,
get_checkpoint_dirpaths,
Expand Down Expand Up @@ -92,7 +91,6 @@
"get_best_checkpoint_path",
"get_checkpoint_dirpaths",
"get_latest_checkpoint_path",
"BestCheckpointConfig",
"copy_data_to_device",
"CPUStats",
"get_device_from_env",
Expand Down
14 changes: 0 additions & 14 deletions torchtnt/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,6 @@ class MetricData:
value: float


@dataclass
class BestCheckpointConfig:
"""
Config for saving the best checkpoints.

Args:
        monitored_metric: Metric to monitor for saving best checkpoints. Must be a numerical or tensor attribute on the unit.
        mode: One of `min` or `max`. The save file is overwritten based on the max or min of the monitored metric.
"""

monitored_metric: str
mode: Literal["min", "max"] = "min"


@total_ordering
class CheckpointPath:
"""
Expand Down