Skip to content

Commit

Permalink
Implement starter anomaly evaluators
Browse files Browse the repository at this point in the history
Summary:
### This Stack

Based on [this RFC](https://docs.google.com/document/d/1K1KQ886dynMRejR0ySH1fctOjS7gxaCS8AB1L_PHxU4/edit?usp=sharing), we are adding a new logger that warns about anomalous values in metrics, and optionally executes a callback function with potential side effects. This could be useful for users to realize sooner that something has gone wrong during training.

### This Diff

To get started with anomaly detection, let's first define two evaluators:
- Threshold is the most intuitive one, and checks that a metric value is within a predefined range.
- IsNaN would be useful to catch fast cases where the loss is NaN because of bad inputs.

Later on we can implement more interesting evaluators like outliers, changepoint detection, etc. if needed.

Differential Revision: D58564199
  • Loading branch information
Diego Urgell authored and facebook-github-bot committed Jun 25, 2024
1 parent 2ef31b8 commit c382602
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/utils/test_anomaly_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import math
import unittest

from torchtnt.utils.anomaly_evaluation import IsNaNEvaluator, ThresholdEvaluator


class TestAnomalyLogger(unittest.TestCase):

def test_threshold(self) -> None:
threshold = ThresholdEvaluator(min_val=0.5, max_val=0.9)
self.assertFalse(threshold.is_anomaly())

threshold.update(0.4)
self.assertTrue(threshold.is_anomaly())

threshold.update(0.6)
self.assertFalse(threshold.is_anomaly())

threshold.update(0.95)
self.assertTrue(threshold.is_anomaly())

threshold = ThresholdEvaluator(max_val=1)

threshold.update(100.0)
self.assertTrue(threshold.is_anomaly())

threshold.update(-500.0)
self.assertFalse(threshold.is_anomaly())

def test_isnan(self) -> None:
isnan = IsNaNEvaluator()
self.assertFalse(isnan.is_anomaly())

isnan.update(0.4)
self.assertFalse(isnan.is_anomaly())

isnan.update(math.nan)
self.assertTrue(isnan.is_anomaly())
3 changes: 3 additions & 0 deletions torchtnt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

from .anomaly_evaluation import IsNaNEvaluator, ThresholdEvaluator
from .checkpoint import (
BestCheckpointConfig,
CheckpointManager,
Expand Down Expand Up @@ -88,6 +89,8 @@
)

__all__ = [
"IsNaNEvaluator",
"ThresholdEvaluator",
"CheckpointPath",
"MetricData",
"get_best_checkpoint_path",
Expand Down
48 changes: 48 additions & 0 deletions torchtnt/utils/anomaly_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@


import logging
import math
from abc import ABC, abstractmethod
from math import inf

_logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,3 +51,49 @@ def is_anomaly(self) -> bool:
an anomaly detection algorithm.
"""
pass


class ThresholdEvaluator(MetricAnomalyEvaluator):
"""
Evaluates whether a metric value is anomalous based on a predefined threshold.
"""

def __init__(
self,
*,
min_val: float = -inf,
max_val: float = inf,
) -> None:
"""
Args:
min_val: Minimum allowed value. Default value is -inf.
max_val: Maximum allowed value. Default value is inf.
warmup_steps: Number of steps to ignore before evaluating anomalies. Default value is 0.
evaluate_every_n_steps: Step interval to wait in between anomaly evaluations. Default value is 1.
"""
self.min_val = min_val
self.max_val = max_val
self.curr_val: float = min_val

def update(self, value: float) -> None:
self.curr_val = value

def is_anomaly(self) -> bool:
return not self.min_val <= self.curr_val <= self.max_val


class IsNaNEvaluator(MetricAnomalyEvaluator):
"""
Evaluates whether a metric value is NaN.
"""

def __init__(
self,
) -> None:
self.curr_val: float = 0

def update(self, value: float) -> None:
self.curr_val = value

def is_anomaly(self) -> bool:
return math.isnan(self.curr_val)

0 comments on commit c382602

Please sign in to comment.