Skip to content

Commit

Permalink
Separate out dataclass to torchsnapshot_saver_types.py (#614)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #614

This allows importing just the dataclass w/o depending on all of torchsnapshot - use case for things like CLIs

Reviewed By: galrotem

Differential Revision: D51139173

fbshipit-source-id: 472079502c625fc1ed3c073b00af46058da3330f
  • Loading branch information
gunchu authored and facebook-github-bot committed Nov 9, 2023
1 parent 206f58a commit 08dad27
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 32 deletions.
36 changes: 4 additions & 32 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import os
import re
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass
from typing import (
Any,
cast,
Expand All @@ -27,6 +26,10 @@
from pyre_extensions import none_throws

from torchtnt.framework.callback import Callback
from torchtnt.framework.callbacks.torchsnapshot_saver_types import (
KnobOptions,
RestoreOptions,
)
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import (
AppStateMixin,
Expand Down Expand Up @@ -65,37 +68,6 @@
logger: logging.Logger = logging.getLogger(__name__)


# TODO: eventually support overriding all knobs
@dataclass
class KnobOptions:
"""
Controls the knobs in TorchSnapshot.
Args:
max_per_rank_io_concurrency: Maximum number of concurrent IO operations per rank. Defaults to 16.
"""

max_per_rank_io_concurrency: Optional[int] = None


@dataclass
class RestoreOptions:
"""
Options when restoring a snapshot.
Args:
restore_train_progress: Whether to restore the training progress state.
restore_eval_progress: Whether to restore the evaluation progress state.
restore_optimizers: Whether to restore the optimizer states.
restore_lr_schedulers: Whether to restore the lr scheduler states.
"""

restore_train_progress: bool = True
restore_eval_progress: bool = True
restore_optimizers: bool = True
restore_lr_schedulers: bool = True


class TorchSnapshotSaver(Callback):
"""
A callback which periodically saves the application state during training using `TorchSnapshot <https://pytorch.org/torchsnapshot/>`_.
Expand Down
38 changes: 38 additions & 0 deletions torchtnt/framework/callbacks/torchsnapshot_saver_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional

# TODO: eventually support overriding all knobs
@dataclass
class KnobOptions:
"""
Controls the knobs in TorchSnapshot.
Args:
max_per_rank_io_concurrency: Maximum number of concurrent IO operations per rank. Defaults to 16.
"""

max_per_rank_io_concurrency: Optional[int] = None


@dataclass
class RestoreOptions:
"""
Options when restoring a snapshot.
Args:
restore_train_progress: Whether to restore the training progress state.
restore_eval_progress: Whether to restore the evaluation progress state.
restore_optimizers: Whether to restore the optimizer states.
restore_lr_schedulers: Whether to restore the lr scheduler states.
"""

restore_train_progress: bool = True
restore_eval_progress: bool = True
restore_optimizers: bool = True
restore_lr_schedulers: bool = True

0 comments on commit 08dad27

Please sign in to comment.