Skip to content

Commit

Permalink
Add anomaly detection support to TensorboardLogger (#854)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #854

### This Stack

Based on [this RFC](https://docs.google.com/document/d/1K1KQ886dynMRejR0ySH1fctOjS7gxaCS8AB1L_PHxU4/edit?usp=sharing), we are adding a new logger that warns about anomalous values in metrics, and optionally executes a callback function with potential side effects. This could be useful for users to realize sooner that something has gone wrong during training.

### This Diff

To start leveraging the AnomalyLogger as easily as possible, let's make it the base class for the Tensorboard logger instead of MetricLogger. This will have no effect unless users specify the `tracked_metrics` attribute, which is optional. However, if they do want to use it, they have to make very little changes.

Next diff will do the same for the AIXLogger

Reviewed By: JKSenthil

Differential Revision: D58593222
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Jun 26, 2024
1 parent 1117930 commit 30c7b56
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 10 deletions.
45 changes: 39 additions & 6 deletions tests/utils/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,53 @@

import tempfile
import unittest
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from torchtnt.utils.anomaly_evaluation import ThresholdEvaluator
from torchtnt.utils.loggers.anomaly_logger import TrackedMetric

from torchtnt.utils.loggers.tensorboard import TensorBoardLogger


class TensorBoardLoggerTest(unittest.TestCase):
def test_log(self: TensorBoardLoggerTest) -> None:

@patch(
"torchtnt.utils.loggers.anomaly_logger.AnomalyLogger.on_anomaly_detected",
)
def test_log(
self: TensorBoardLoggerTest, mock_on_anomaly_detected: MagicMock
) -> None:
with tempfile.TemporaryDirectory() as log_dir:
logger = TensorBoardLogger(path=log_dir)
for i in range(5):
logger.log("test_log", float(i) ** 2, i)
logger.close()
logger = TensorBoardLogger(
path=log_dir,
tracked_metrics=[
TrackedMetric(
name="test_log",
anomaly_evaluators=[
ThresholdEvaluator(min_val=25),
],
evaluate_every_n_steps=2,
warmup_steps=2,
)
],
)
warning_container = []
with patch(
"torchtnt.utils.loggers.anomaly_logger.logging.Logger.warning",
side_effect=warning_container.append,
):
for i in range(5):
logger.log("test_log", float(i) ** 2, i)
logger.close()

self.assertEqual(
warning_container,
[
"Found anomaly in metric: test_log, with value: 16.0, using evaluator: ThresholdEvaluator"
],
)
mock_on_anomaly_detected.assert_called_with("test_log", 16.0, 4)

acc = EventAccumulator(log_dir)
acc.Reload()
Expand Down
23 changes: 19 additions & 4 deletions torchtnt/utils/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,28 @@

import atexit
import logging
from typing import Any, Dict, Mapping, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union

from torch.utils.tensorboard import SummaryWriter
from torchtnt.utils.distributed import get_global_rank
from torchtnt.utils.loggers.logger import MetricLogger, Scalar
from torchtnt.utils.loggers.anomaly_logger import AnomalyLogger, TrackedMetric
from torchtnt.utils.loggers.logger import Scalar

logger: logging.Logger = logging.getLogger(__name__)


class TensorBoardLogger(MetricLogger):
class TensorBoardLogger(AnomalyLogger):
"""
Simple logger for TensorBoard.
On construction, the logger creates a new events file that logs
will be written to. If the environment variable `RANK` is defined,
logger will only log if RANK = 0.
Metrics may be tracked for anomaly detection if they are configured in the
optional `tracked_metrics` argument. See :class:`torchtnt.utils.loggers.AnomalyLogger`
for more details.
Note:
If using this logger with distributed training:
Expand All @@ -38,6 +43,7 @@ class TensorBoardLogger(MetricLogger):
Args:
path (str): path to write logs to
tracked_metrics: Optional list of TrackedMetric objects to track for anomaly detection.
*args: Extra positional arguments to pass to SummaryWriter
**kwargs: Extra keyword arguments to pass to SummaryWriter
Expand All @@ -49,7 +55,14 @@ class TensorBoardLogger(MetricLogger):
logger.close()
"""

def __init__(self: TensorBoardLogger, path: str, *args: Any, **kwargs: Any) -> None:
def __init__(
self: TensorBoardLogger,
path: str,
tracked_metrics: Optional[List[TrackedMetric]] = None,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(tracked_metrics)
self._writer: Optional[SummaryWriter] = None
self._path: str = path
self._rank: int = get_global_rank()
Expand Down Expand Up @@ -100,6 +113,8 @@ def log(self: TensorBoardLogger, name: str, data: Scalar, step: int) -> None:
if self._writer:
self._writer.add_scalar(name, data, global_step=step, new_style=True)

super().log(name, data, step)

def log_text(self: TensorBoardLogger, name: str, data: str, step: int) -> None:
"""Add text data to summary.
Expand Down

0 comments on commit 30c7b56

Please sign in to comment.