Skip to content

Commit

Permalink
Make AppStateMixin metric aware
Browse files Browse the repository at this point in the history
Differential Revision: D62555231
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Sep 12, 2024
1 parent 60a6360 commit 31f4a6a
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions torchtnt/framework/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper
from torchtnt.utils.progress import Progress
from torchtnt.utils.stateful import Stateful
from torchtnt.utils.stateful import MetricStateful, Stateful


_logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -51,6 +51,7 @@ def __init__(self) -> None:
self._optimizers: Dict[str, torch.optim.Optimizer] = {}
self._lr_schedulers: Dict[str, TLRScheduler] = {}
self._progress: Dict[str, Progress] = {}
self._metrics: Dict[str, MetricStateful] = {}
# catch-all for miscellaneous statefuls
self._misc_statefuls: Dict[str, Any] = {}
# TODO: include other known statefuls
Expand All @@ -67,6 +68,7 @@ def app_state(self) -> Dict[str, Any]:
**self.tracked_lr_schedulers(),
**self.tracked_progress(),
**self.tracked_misc_statefuls(),
**self.tracked_metrics(),
}
return app_state

Expand All @@ -84,6 +86,9 @@ def tracked_lr_schedulers(
def tracked_progress(self) -> Dict[str, Progress]:
return self._progress

def tracked_metrics(self) -> Dict[str, MetricStateful]:
return self._metrics

def tracked_misc_statefuls(self) -> Dict[str, Any]:
return self._misc_statefuls

Expand All @@ -104,6 +109,10 @@ def __getattr__(self, name: str) -> object:
_progress = self.__dict__["_progress"]
if name in _progress:
return _progress[name]
if "_metrics" in self.__dict__:
_metrics = self.__dict__["_metrics"]
if name in _metrics:
return _metrics[name]
if "_misc_statefuls" in self.__dict__:
_misc_statefuls = self.__dict__["_misc_statefuls"]
if name in _misc_statefuls:
Expand All @@ -128,12 +137,16 @@ def _update_attr(
self._optimizers,
self._lr_schedulers,
self._progress,
self._metrics,
self._misc_statefuls,
)
tracked_objects[name] = value

def __setattr__(self, name: str, value: object) -> None:
if isinstance(value, torch.nn.Module):
# Check first for metrics since some libraries subclass nn.Modsule as well
if isinstance(value, MetricStateful):
self._update_attr(name, value, self.__dict__.get("_metrics"))
elif isinstance(value, torch.nn.Module):
self._update_attr(name, value, self.__dict__.get("_modules"))
elif isinstance(value, torch.optim.Optimizer):
self._update_attr(name, value, self.__dict__.get("_optimizers"))
Expand Down Expand Up @@ -163,6 +176,7 @@ def __setattr__(self, name: str, value: object) -> None:
self._modules,
self._optimizers,
self._lr_schedulers,
self._metrics,
self._misc_statefuls,
)
super().__setattr__(name, value)
Expand All @@ -176,6 +190,8 @@ def __delattr__(self, name: str) -> None:
del self._lr_schedulers[name]
elif name in self._progress:
del self._progress[name]
elif name in self._metrics:
del self._metrics[name]
elif name in self._misc_statefuls:
del self._misc_statefuls[name]
else:
Expand Down

0 comments on commit 31f4a6a

Please sign in to comment.