class TestAnomalyLogger(unittest.TestCase):
    """Unit tests for ``ThresholdEvaluator`` and ``IsNaNEvaluator``."""

    def test_threshold(self) -> None:
        """Values outside the configured bounds are flagged; values inside are not."""
        bounded = ThresholdEvaluator(min_val=0.5, max_val=0.9)
        # Nothing has been observed yet, so nothing should be flagged.
        self.assertFalse(bounded.is_anomaly())

        # (observed value, whether it should be flagged) for the [0.5, 0.9] range.
        bounded_cases = (
            (0.4, True),
            (0.6, False),
            (0.95, True),
        )
        for observed, flagged in bounded_cases:
            bounded.update(observed)
            check = self.assertTrue if flagged else self.assertFalse
            check(bounded.is_anomaly())

        # Only an upper bound is configured: arbitrarily small values are fine.
        upper_only = ThresholdEvaluator(max_val=1)
        upper_only_cases = (
            (100.0, True),
            (-500.0, False),
        )
        for observed, flagged in upper_only_cases:
            upper_only.update(observed)
            check = self.assertTrue if flagged else self.assertFalse
            check(upper_only.is_anomaly())

    def test_isnan(self) -> None:
        """Only a NaN observation is flagged by the NaN evaluator."""
        nan_checker = IsNaNEvaluator()
        # Freshly constructed evaluator reports no anomaly.
        self.assertFalse(nan_checker.is_anomaly())

        nan_checker.update(0.4)
        self.assertFalse(nan_checker.is_anomaly())

        nan_checker.update(math.nan)
        self.assertTrue(nan_checker.is_anomaly())
class ThresholdEvaluator(MetricAnomalyEvaluator):
    """
    Evaluates whether a metric value is anomalous based on a predefined threshold.

    The most recent value passed to ``update`` is anomalous when it lies outside
    the closed interval ``[min_val, max_val]``. A NaN value is always reported as
    anomalous, because every ordering comparison against NaN is false.
    """

    def __init__(
        self,
        *,
        min_val: float = -inf,
        max_val: float = inf,
    ) -> None:
        """
        Args:
            min_val: Minimum allowed value (inclusive). Default value is -inf.
            max_val: Maximum allowed value (inclusive). Default value is inf.
        """
        self.min_val: float = min_val
        self.max_val: float = max_val
        # Seed with the lower bound: as long as min_val <= max_val, this keeps
        # is_anomaly() False until the first real value arrives via update().
        self.curr_val: float = min_val

    def update(self, value: float) -> None:
        """Record ``value`` as the latest observation, replacing the previous one."""
        self.curr_val = value

    def is_anomaly(self) -> bool:
        """Return True when the latest value is outside ``[min_val, max_val]``."""
        # Written as a negated range check (not two "<" / ">" tests) so that
        # NaN, which fails every comparison, is reported as an anomaly.
        return not (self.min_val <= self.curr_val <= self.max_val)


class IsNaNEvaluator(MetricAnomalyEvaluator):
    """
    Evaluates whether a metric value is NaN.

    Only the most recent value passed to ``update`` is examined.
    """

    def __init__(
        self,
    ) -> None:
        # A non-NaN starting value, so nothing is flagged before the first update().
        self.curr_val: float = 0.0

    def update(self, value: float) -> None:
        """Record ``value`` as the latest observation, replacing the previous one."""
        self.curr_val = value

    def is_anomaly(self) -> bool:
        """Return True when the latest value is NaN."""
        return math.isnan(self.curr_val)