diff --git a/tests/utils/loggers/test_tensorboard.py b/tests/utils/loggers/test_tensorboard.py index 709c437cf0..d9eec49885 100644 --- a/tests/utils/loggers/test_tensorboard.py +++ b/tests/utils/loggers/test_tensorboard.py @@ -11,20 +11,53 @@ import tempfile import unittest -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch from tensorboard.backend.event_processing.event_accumulator import EventAccumulator +from torchtnt.utils.anomaly_evaluation import ThresholdEvaluator +from torchtnt.utils.loggers.anomaly_logger import TrackedMetric from torchtnt.utils.loggers.tensorboard import TensorBoardLogger class TensorBoardLoggerTest(unittest.TestCase): - def test_log(self: TensorBoardLoggerTest) -> None: + + @patch( + "torchtnt.utils.loggers.anomaly_logger.AnomalyLogger.on_anomaly_detected", + ) + def test_log( + self: TensorBoardLoggerTest, mock_on_anomaly_detected: MagicMock + ) -> None: with tempfile.TemporaryDirectory() as log_dir: - logger = TensorBoardLogger(path=log_dir) - for i in range(5): - logger.log("test_log", float(i) ** 2, i) - logger.close() + logger = TensorBoardLogger( + path=log_dir, + tracked_metrics=[ + TrackedMetric( + name="test_log", + anomaly_evaluators=[ + ThresholdEvaluator(min_val=25), + ], + evaluate_every_n_steps=2, + warmup_steps=2, + ) + ], + ) + warning_container = [] + with patch( + "torchtnt.utils.loggers.anomaly_logger.logging.Logger.warning", + side_effect=warning_container.append, + ): + for i in range(5): + logger.log("test_log", float(i) ** 2, i) + logger.close() + + self.assertEqual( + warning_container, + [ + "Found anomaly in metric: test_log, with value: 16.0, using evaluator: ThresholdEvaluator" + ], + ) + mock_on_anomaly_detected.assert_called_with("test_log", 16.0, 4) acc = EventAccumulator(log_dir) acc.Reload() diff --git a/torchtnt/utils/loggers/tensorboard.py b/torchtnt/utils/loggers/tensorboard.py index 3c26da30e9..a54f230566 100644 --- a/torchtnt/utils/loggers/tensorboard.py +++ b/torchtnt/utils/loggers/tensorboard.py @@ -11,16 +11,17 @@ import atexit import logging -from typing import Any, Dict, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, Union from torch.utils.tensorboard import SummaryWriter from torchtnt.utils.distributed import get_global_rank -from torchtnt.utils.loggers.logger import MetricLogger, Scalar +from torchtnt.utils.loggers.anomaly_logger import AnomalyLogger, TrackedMetric +from torchtnt.utils.loggers.logger import Scalar logger: logging.Logger = logging.getLogger(__name__) -class TensorBoardLogger(MetricLogger): +class TensorBoardLogger(AnomalyLogger): """ Simple logger for TensorBoard. @@ -28,6 +29,10 @@ class TensorBoardLogger(MetricLogger): will be written to. If the environment variable `RANK` is defined, logger will only log if RANK = 0. + Metrics may be tracked for anomaly detection if they are configured in the + optional `tracked_metrics` argument. See :class:`torchtnt.utils.loggers.AnomalyLogger` + for more details. + Note: If using this logger with distributed training: @@ -38,6 +43,7 @@ class TensorBoardLogger(MetricLogger): Args: path (str): path to write logs to + tracked_metrics: Optional list of TrackedMetric objects to track for anomaly detection. *args: Extra positional arguments to pass to SummaryWriter **kwargs: Extra keyword arguments to pass to SummaryWriter @@ -49,7 +55,14 @@ class TensorBoardLogger(MetricLogger): logger.close() """ - def __init__(self: TensorBoardLogger, path: str, *args: Any, **kwargs: Any) -> None: + def __init__( + self: TensorBoardLogger, + path: str, + tracked_metrics: Optional[List[TrackedMetric]] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(tracked_metrics) self._writer: Optional[SummaryWriter] = None self._path: str = path self._rank: int = get_global_rank() @@ -100,6 +113,8 @@ def log(self: TensorBoardLogger, name: str, data: Scalar, step: int) -> None: if self._writer: self._writer.add_scalar(name, data, global_step=step, new_style=True) + super().log(name, data, step) + def log_text(self: TensorBoardLogger, name: str, data: str, step: int) -> None: """Add text data to summary.