Skip to content

Commit

Permalink
state helper - active phase state
Browse files Browse the repository at this point in the history
Reviewed By: diego-urgell

Differential Revision: D56496429

fbshipit-source-id: ab6c3c69fc624a73cf3095f01c970450c107e02a
  • Loading branch information
galrotem authored and facebook-github-bot committed Apr 25, 2024
1 parent e7b9e64 commit ad3ff86
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions torchtnt/framework/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from enum import auto, Enum
from typing import Generic, Iterable, Optional, TypeVar

from pyre_extensions import none_throws

from torchtnt.utils.timer import BoundedTimer, TimerProtocol

_logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -199,3 +201,14 @@ def stop(self) -> None:
"""Signal to the loop to end after the current step completes."""
_logger.warning("Received signal to stop")
self._should_stop = True

def active_phase_state(self) -> TPhaseState:
"""Returns the current active phase state."""
if self._active_phase == ActivePhase.TRAIN:
return none_throws(self._train_state)
elif self._active_phase == ActivePhase.EVALUATE:
return none_throws(self._eval_state)
elif self._active_phase == ActivePhase.PREDICT:
return none_throws(self._predict_state)
else:
raise ValueError(f"Invalid active phase: {self._active_phase}")

0 comments on commit ad3ff86

Please sign in to comment.