diff --git a/torchtnt/framework/unit.py b/torchtnt/framework/unit.py
index 1f58aba32a..8a1d7ff1e4 100644
--- a/torchtnt/framework/unit.py
+++ b/torchtnt/framework/unit.py
@@ -22,7 +22,7 @@
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper
 from torchtnt.utils.progress import Progress
-from torchtnt.utils.stateful import Stateful
+from torchtnt.utils.stateful import MetricStateful, Stateful
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
@@ -51,6 +51,7 @@ def __init__(self) -> None:
         self._optimizers: Dict[str, torch.optim.Optimizer] = {}
         self._lr_schedulers: Dict[str, TLRScheduler] = {}
         self._progress: Dict[str, Progress] = {}
+        self._metrics: Dict[str, MetricStateful] = {}
         # catch-all for miscellaneous statefuls
         self._misc_statefuls: Dict[str, Any] = {}
         # TODO: include other known statefuls
@@ -67,6 +68,7 @@ def app_state(self) -> Dict[str, Any]:
             **self.tracked_lr_schedulers(),
             **self.tracked_progress(),
             **self.tracked_misc_statefuls(),
+            **self.tracked_metrics(),
         }
         return app_state
 
@@ -84,6 +86,9 @@ def tracked_lr_schedulers(
     def tracked_progress(self) -> Dict[str, Progress]:
         return self._progress
 
+    def tracked_metrics(self) -> Dict[str, MetricStateful]:
+        return self._metrics
+
     def tracked_misc_statefuls(self) -> Dict[str, Any]:
         return self._misc_statefuls
 
@@ -104,6 +109,10 @@ def __getattr__(self, name: str) -> object:
             _progress = self.__dict__["_progress"]
             if name in _progress:
                 return _progress[name]
+        if "_metrics" in self.__dict__:
+            _metrics = self.__dict__["_metrics"]
+            if name in _metrics:
+                return _metrics[name]
         if "_misc_statefuls" in self.__dict__:
             _misc_statefuls = self.__dict__["_misc_statefuls"]
             if name in _misc_statefuls:
@@ -128,12 +137,16 @@ def _update_attr(
             self._optimizers,
             self._lr_schedulers,
             self._progress,
+            self._metrics,
             self._misc_statefuls,
         )
         tracked_objects[name] = value
 
     def __setattr__(self, name: str, value: object) -> None:
-        if isinstance(value, torch.nn.Module):
+        # Check first for metrics since some libraries subclass nn.Module as well
+        if isinstance(value, MetricStateful):
+            self._update_attr(name, value, self.__dict__.get("_metrics"))
+        elif isinstance(value, torch.nn.Module):
             self._update_attr(name, value, self.__dict__.get("_modules"))
         elif isinstance(value, torch.optim.Optimizer):
             self._update_attr(name, value, self.__dict__.get("_optimizers"))
@@ -163,6 +176,7 @@ def __setattr__(self, name: str, value: object) -> None:
                 self._modules,
                 self._optimizers,
                 self._lr_schedulers,
+                self._metrics,
                 self._misc_statefuls,
             )
             super().__setattr__(name, value)
@@ -176,6 +190,8 @@ def __delattr__(self, name: str) -> None:
             del self._lr_schedulers[name]
         elif name in self._progress:
             del self._progress[name]
+        elif name in self._metrics:
+            del self._metrics[name]
         elif name in self._misc_statefuls:
             del self._misc_statefuls[name]
         else:
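
For context, a minimal usage sketch of the new tracking path. It assumes a metric object that satisfies the MetricStateful protocol (update/compute/state_dict/load_state_dict); torcheval.metrics.Mean is used purely for illustration, and MyTrainUnit is a hypothetical unit, not part of this diff.

    # Minimal sketch (assumptions: torcheval's Mean satisfies MetricStateful;
    # MyTrainUnit is a hypothetical unit written for illustration).
    import torch
    from torcheval.metrics import Mean
    from torchtnt.framework.state import State
    from torchtnt.framework.unit import TrainUnit

    class MyTrainUnit(TrainUnit[torch.Tensor]):
        def __init__(self) -> None:
            super().__init__()
            # __setattr__ now routes this assignment into self._metrics:
            # the MetricStateful check runs before the nn.Module check, so
            # metrics from libraries whose metric classes subclass nn.Module
            # no longer land in self._modules.
            self.loss_mean = Mean()

        def train_step(self, state: State, data: torch.Tensor) -> None:
            self.loss_mean.update(data)

    unit = MyTrainUnit()
    assert "loss_mean" in unit.tracked_metrics()
    assert "loss_mean" in unit.app_state()  # included in checkpointable app state

The ordering in __setattr__ is the key design point: because some metric libraries derive their metric classes from nn.Module (as the diff's comment notes), checking MetricStateful first ensures such objects are saved and restored through the metric's own state_dict/load_state_dict rather than being captured as plain modules.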