From 2ef31b8a1ace665a0220593489dfbe96cffd55a0 Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Tue, 25 Jun 2024 15:33:02 -0700 Subject: [PATCH] Add base AnomalyEvaluator class Summary: ### This Stack Based on [this RFC](https://docs.google.com/document/d/1K1KQ886dynMRejR0ySH1fctOjS7gxaCS8AB1L_PHxU4/edit?usp=sharing), we are adding a new logger that warns about anomalous values in metrics, and optionally executes a callback function with potential side effects. This could be useful for users to realize sooner that something has gone wrong during training. ### This Diff To provide flexibility when detecting anomalous metric values, instead of assuming and hardcoding a predefined check (like a threshold), let's create an interface that can be overriden to implement custom checks. Differential Revision: D58564201 --- torchtnt/utils/anomaly_evaluation.py | 51 ++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 torchtnt/utils/anomaly_evaluation.py diff --git a/torchtnt/utils/anomaly_evaluation.py b/torchtnt/utils/anomaly_evaluation.py new file mode 100644 index 0000000000..97fed6cb5d --- /dev/null +++ b/torchtnt/utils/anomaly_evaluation.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +import logging +from abc import ABC, abstractmethod + +_logger: logging.Logger = logging.getLogger(__name__) + + +class MetricAnomalyEvaluator(ABC): + """ + Abstract base class for metric anomaly evaluators. An evaluator specifies the logic to determine that + a particular metric value is anomalous. To implement a custom method, create a subclass and implement + the following methods: + - :py:meth:`~torchtnt.utils.loggers.metric_anomaly_logger.MetricAnomalyEvaluator.update` should receive + the metric value and update the internal state. This is specially useful for algorithms that require + storing some previous values, moving averages, etc. + - :py:meth:`~torchtnt.utils.loggers.metric_anomaly_logger.MetricAnomalyEvaluator.is_anomaly` determines + whether the current metric state is anomalous. + + Likely there are some warm-up steps before the metric is stable and can be checked against anomalies, so + the separation of state update and actual detection provides this flexibility. + """ + + @abstractmethod + def update(self, value: float) -> None: + """ + Update the internal state with the given metric value. This should not determine anomalies itself, but + only aggregate the current value according to the anomaly detection algorithm. + + Note:: If no aggregation is required, this method can store the value directly, to be used in `is_anomaly`. + + Args: + value: Metric value + """ + pass + + @abstractmethod + def is_anomaly(self) -> bool: + """ + Determine whether the current metric state is anomalous. This should be overridden with custom logic related to + an anomaly detection algorithm. + """ + pass