From c38260247232a69c3724086905f270632fb434ac Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Tue, 25 Jun 2024 15:33:02 -0700 Subject: [PATCH] Implement starter anomaly evaluators Summary: ### This Stack Based on [this RFC](https://docs.google.com/document/d/1K1KQ886dynMRejR0ySH1fctOjS7gxaCS8AB1L_PHxU4/edit?usp=sharing), we are adding a new logger that warns about anomalous values in metrics, and optionally executes a callback function with potential side effects. This could be useful for users to realize sooner that something has gone wrong during training. ### This Diff To get started with anomaly detection, let's first define two evaluators: - Threshold is the most intuitive one, and checks that a metric value is within a predefined range. - IsNaN would be useful to catch fast cases where the loss is NaN because of bad inputs. Later on we can implement more interesting evaluators like outliers, changepoint detection, etc. if needed. Differential Revision: D58564199 --- tests/utils/test_anomaly_evaluation.py | 47 +++++++++++++++++++++++++ torchtnt/utils/__init__.py | 3 ++ torchtnt/utils/anomaly_evaluation.py | 48 ++++++++++++++++++++++++++ 3 files changed, 98 insertions(+) create mode 100644 tests/utils/test_anomaly_evaluation.py diff --git a/tests/utils/test_anomaly_evaluation.py b/tests/utils/test_anomaly_evaluation.py new file mode 100644 index 0000000000..dec30aaed5 --- /dev/null +++ b/tests/utils/test_anomaly_evaluation.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import math +import unittest + +from torchtnt.utils.anomaly_evaluation import IsNaNEvaluator, ThresholdEvaluator + + +class TestAnomalyLogger(unittest.TestCase): + + def test_threshold(self) -> None: + threshold = ThresholdEvaluator(min_val=0.5, max_val=0.9) + self.assertFalse(threshold.is_anomaly()) + + threshold.update(0.4) + self.assertTrue(threshold.is_anomaly()) + + threshold.update(0.6) + self.assertFalse(threshold.is_anomaly()) + + threshold.update(0.95) + self.assertTrue(threshold.is_anomaly()) + + threshold = ThresholdEvaluator(max_val=1) + + threshold.update(100.0) + self.assertTrue(threshold.is_anomaly()) + + threshold.update(-500.0) + self.assertFalse(threshold.is_anomaly()) + + def test_isnan(self) -> None: + isnan = IsNaNEvaluator() + self.assertFalse(isnan.is_anomaly()) + + isnan.update(0.4) + self.assertFalse(isnan.is_anomaly()) + + isnan.update(math.nan) + self.assertTrue(isnan.is_anomaly()) diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py index 397ec67d26..fb8098c360 100644 --- a/torchtnt/utils/__init__.py +++ b/torchtnt/utils/__init__.py @@ -6,6 +6,7 @@ # pyre-strict +from .anomaly_evaluation import IsNaNEvaluator, ThresholdEvaluator from .checkpoint import ( BestCheckpointConfig, CheckpointManager, @@ -88,6 +89,8 @@ ) __all__ = [ + "IsNaNEvaluator", + "ThresholdEvaluator", "CheckpointPath", "MetricData", "get_best_checkpoint_path", diff --git a/torchtnt/utils/anomaly_evaluation.py b/torchtnt/utils/anomaly_evaluation.py index 97fed6cb5d..08b2d4a195 100644 --- a/torchtnt/utils/anomaly_evaluation.py +++ b/torchtnt/utils/anomaly_evaluation.py @@ -9,7 +9,9 @@ import logging +import math from abc import ABC, abstractmethod +from math import inf _logger: logging.Logger = logging.getLogger(__name__) @@ -49,3 +51,49 @@ def is_anomaly(self) -> bool: an anomaly detection algorithm. """ pass + + +class ThresholdEvaluator(MetricAnomalyEvaluator): + """ + Evaluates whether a metric value is anomalous based on a predefined threshold. + """ + + def __init__( + self, + *, + min_val: float = -inf, + max_val: float = inf, + ) -> None: + """ + Args: + min_val: Minimum allowed value. Default value is -inf. + max_val: Maximum allowed value. Default value is inf. + warmup_steps: Number of steps to ignore before evaluating anomalies. Default value is 0. + evaluate_every_n_steps: Step interval to wait in between anomaly evaluations. Default value is 1. + """ + self.min_val = min_val + self.max_val = max_val + self.curr_val: float = min_val + + def update(self, value: float) -> None: + self.curr_val = value + + def is_anomaly(self) -> bool: + return not self.min_val <= self.curr_val <= self.max_val + + +class IsNaNEvaluator(MetricAnomalyEvaluator): + """ + Evaluates whether a metric value is NaN. + """ + + def __init__( + self, + ) -> None: + self.curr_val: float = 0 + + def update(self, value: float) -> None: + self.curr_val = value + + def is_anomaly(self) -> bool: + return math.isnan(self.curr_val)