From ad3ff8690fdf0d64fa92460c3b99a208cc1074bf Mon Sep 17 00:00:00 2001 From: Gal Rotem Date: Thu, 25 Apr 2024 13:31:39 -0700 Subject: [PATCH] state helper - active phase state Reviewed By: diego-urgell Differential Revision: D56496429 fbshipit-source-id: ab6c3c69fc624a73cf3095f01c970450c107e02a --- torchtnt/framework/state.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/torchtnt/framework/state.py b/torchtnt/framework/state.py index 7af4036180..f6838fa8b9 100644 --- a/torchtnt/framework/state.py +++ b/torchtnt/framework/state.py @@ -12,6 +12,8 @@ from enum import auto, Enum from typing import Generic, Iterable, Optional, TypeVar +from pyre_extensions import none_throws + from torchtnt.utils.timer import BoundedTimer, TimerProtocol _logger: logging.Logger = logging.getLogger(__name__) @@ -199,3 +201,14 @@ def stop(self) -> None: """Signal to the loop to end after the current step completes.""" _logger.warning("Received signal to stop") self._should_stop = True + + def active_phase_state(self) -> TPhaseState: + """Returns the current active phase state.""" + if self._active_phase == ActivePhase.TRAIN: + return none_throws(self._train_state) + elif self._active_phase == ActivePhase.EVALUATE: + return none_throws(self._eval_state) + elif self._active_phase == ActivePhase.PREDICT: + return none_throws(self._predict_state) + else: + raise ValueError(f"Invalid active phase: {self._active_phase}")