Skip to content

Commit

Permalink
Introduce stateful metric protocol (pytorch#894)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#894

Differential Revision: D62395083
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Sep 11, 2024
1 parent 665dd50 commit 1bcb5be
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion torchtnt/utils/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...
class MultiStateful:
"""
Wrapper for multiple stateful objects. Necessary because we might have multiple nn.Modules or multiple optimizers,
but save/load_checkpoint APIs may only accepts one stateful object.
but save/load_checkpoint APIs may only accept one stateful object.
Stores state_dict as a dict of state_dicts.
"""
Expand All @@ -55,3 +55,20 @@ def state_dict(self) -> Dict[str, Any]:
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Restore every wrapped stateful object from ``state_dict``.

    Each key in ``state_dict`` names an entry in ``self.stateful_objs``;
    that object's own ``load_state_dict`` is invoked with the
    corresponding saved sub-state.
    """
    for name, saved_state in state_dict.items():
        self.stateful_objs[name].load_state_dict(saved_state)


@runtime_checkable
class MetricStateful(Protocol):
    """Structural (duck-typed) interface for checkpointable metric objects.

    Mirrors the common API exposed by the major metric libraries
    (e.g. torcheval, torchmetrics): a metric accumulates observations via
    ``update``, produces a result via ``compute``, and can be snapshotted
    and restored through ``state_dict`` / ``load_state_dict``. Being
    ``runtime_checkable``, any object providing these four methods passes
    an ``isinstance`` check against this protocol.
    """

    def update(self, *args: Any, **kwargs: Any) -> None: ...

    # pyre-ignore[3]: the computed value's type is implementation-defined
    def compute(self) -> Any: ...

    def state_dict(self) -> Dict[str, Any]: ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...

0 comments on commit 1bcb5be

Please sign in to comment.