From de3aed7912cb1d1f1b6e6e77e0cb313b18b4902c Mon Sep 17 00:00:00 2001
From: Diego Urgell
Date: Tue, 25 Jun 2024 16:18:47 -0700
Subject: [PATCH] Add anomaly detection support to TensorboardLogger (#854)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/854

### This Stack

Based on [this RFC](https://docs.google.com/document/d/1K1KQ886dynMRejR0ySH1fctOjS7gxaCS8AB1L_PHxU4/edit?usp=sharing), we are adding a new logger that warns about anomalous values in metrics and optionally executes a callback function with potential side effects. This can help users realize sooner that something has gone wrong during training.

### This Diff

To make the AnomalyLogger as easy as possible to adopt, this diff makes it the base class of the TensorBoard logger in place of MetricLogger. This has no effect unless users specify the optional `tracked_metrics` argument; users who do want it need to make very few changes.

The next diff will do the same for the AIXLogger.

Reviewed By: JKSenthil

Differential Revision: D58593222
---
 tests/utils/loggers/test_tensorboard.py | 45 +++++++++++++++++++++----
 torchtnt/utils/loggers/tensorboard.py   | 23 ++++++++++---
 2 files changed, 58 insertions(+), 10 deletions(-)

diff --git a/tests/utils/loggers/test_tensorboard.py b/tests/utils/loggers/test_tensorboard.py
index 709c437cf0..d9eec49885 100644
--- a/tests/utils/loggers/test_tensorboard.py
+++ b/tests/utils/loggers/test_tensorboard.py
@@ -11,20 +11,53 @@
 import tempfile
 import unittest
 
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
 
 from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+from torchtnt.utils.anomaly_evaluation import ThresholdEvaluator
+from torchtnt.utils.loggers.anomaly_logger import TrackedMetric
 from torchtnt.utils.loggers.tensorboard import TensorBoardLogger
 
 
 class TensorBoardLoggerTest(unittest.TestCase):
-    def test_log(self: TensorBoardLoggerTest) -> None:
+
+    @patch(
+        "torchtnt.utils.loggers.anomaly_logger.AnomalyLogger.on_anomaly_detected",
+    )
+    def test_log(
+        self: TensorBoardLoggerTest, mock_on_anomaly_detected: MagicMock
+    ) -> None:
         with tempfile.TemporaryDirectory() as log_dir:
-            logger = TensorBoardLogger(path=log_dir)
-            for i in range(5):
-                logger.log("test_log", float(i) ** 2, i)
-            logger.close()
+            logger = TensorBoardLogger(
+                path=log_dir,
+                tracked_metrics=[
+                    TrackedMetric(
+                        name="test_log",
+                        anomaly_evaluators=[
+                            ThresholdEvaluator(min_val=25),
+                        ],
+                        evaluate_every_n_steps=2,
+                        warmup_steps=2,
+                    )
+                ],
+            )
+            warning_container = []
+            with patch(
+                "torchtnt.utils.loggers.anomaly_logger.logging.Logger.warning",
+                side_effect=warning_container.append,
+            ):
+                for i in range(5):
+                    logger.log("test_log", float(i) ** 2, i)
+                logger.close()
+
+            self.assertEqual(
+                warning_container,
+                [
+                    "Found anomaly in metric: test_log, with value: 16.0, using evaluator: ThresholdEvaluator"
+                ],
+            )
+            mock_on_anomaly_detected.assert_called_with("test_log", 16.0, 4)
 
             acc = EventAccumulator(log_dir)
             acc.Reload()
diff --git a/torchtnt/utils/loggers/tensorboard.py b/torchtnt/utils/loggers/tensorboard.py
index 3c26da30e9..a54f230566 100644
--- a/torchtnt/utils/loggers/tensorboard.py
+++ b/torchtnt/utils/loggers/tensorboard.py
@@ -11,16 +11,17 @@
 import atexit
 import logging
-from typing import Any, Dict, Mapping, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 from torch.utils.tensorboard import SummaryWriter
 from torchtnt.utils.distributed import get_global_rank
-from torchtnt.utils.loggers.logger import MetricLogger, Scalar
+from torchtnt.utils.loggers.anomaly_logger import AnomalyLogger, TrackedMetric
+from torchtnt.utils.loggers.logger import Scalar
 
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-class TensorBoardLogger(MetricLogger):
+class TensorBoardLogger(AnomalyLogger):
     """
     Simple logger for TensorBoard.
 
@@ -28,6 +29,10 @@ class TensorBoardLogger(MetricLogger):
     will be written to. If the environment variable `RANK` is defined,
     logger will only log if RANK = 0.
 
+    Metrics may be tracked for anomaly detection if they are configured in the
+    optional `tracked_metrics` argument. See :class:`torchtnt.utils.loggers.AnomalyLogger`
+    for more details.
+
     Note:
         If using this logger with distributed training:
 
@@ -38,6 +43,7 @@ class TensorBoardLogger(MetricLogger):
     Args:
         path (str): path to write logs to
+        tracked_metrics: Optional list of TrackedMetric objects to track for anomaly detection.
         *args: Extra positional arguments to pass to SummaryWriter
         **kwargs: Extra keyword arguments to pass to SummaryWriter
 
@@ -49,7 +55,14 @@
         logger.close()
     """
 
-    def __init__(self: TensorBoardLogger, path: str, *args: Any, **kwargs: Any) -> None:
+    def __init__(
+        self: TensorBoardLogger,
+        path: str,
+        tracked_metrics: Optional[List[TrackedMetric]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(tracked_metrics)
         self._writer: Optional[SummaryWriter] = None
         self._path: str = path
         self._rank: int = get_global_rank()
@@ -100,6 +113,8 @@ def log(self: TensorBoardLogger, name: str, data: Scalar, step: int) -> None:
         if self._writer:
             self._writer.add_scalar(name, data, global_step=step, new_style=True)
 
+        super().log(name, data, step)
+
     def log_text(self: TensorBoardLogger, name: str, data: str, step: int) -> None:
         """Add text data to summary.
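
Usage sketch (not part of the patch): with this change, enabling anomaly detection on the TensorBoard logger should only require passing `tracked_metrics` at construction; the metric name, threshold, and cadence below are illustrative, mirroring the test above.

```python
from torchtnt.utils.anomaly_evaluation import ThresholdEvaluator
from torchtnt.utils.loggers.anomaly_logger import TrackedMetric
from torchtnt.utils.loggers.tensorboard import TensorBoardLogger

# Warn (and invoke on_anomaly_detected) if "accuracy" drops below 0.5.
# Evaluation runs every 10 steps, after 2 warmup steps.
logger = TensorBoardLogger(
    path="/tmp/tb_logs",
    tracked_metrics=[
        TrackedMetric(
            name="accuracy",
            anomaly_evaluators=[ThresholdEvaluator(min_val=0.5)],
            evaluate_every_n_steps=10,
            warmup_steps=2,
        )
    ],
)

for step in range(100):
    accuracy = 1.0 / (step + 1)  # dummy metric that decays below the threshold
    logger.log("accuracy", accuracy, step)

logger.close()
```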