forked from pytorch/tnt
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement starter anomaly evaluators
Summary: ### This Stack Based on [this RFC](https://docs.google.com/document/d/1K1KQ886dynMRejR0ySH1fctOjS7gxaCS8AB1L_PHxU4/edit?usp=sharing), we are adding a new logger that warns about anomalous values in metrics, and optionally executes a callback function with potential side effects. This could be useful for users to realize sooner that something has gone wrong during training. ### This Diff To get started with anomaly detection, let's first define two evaluators: - Threshold is the most intuitive one, and checks that a metric value is within a predefined range. - IsNaN would be useful to catch fast cases where the loss is NaN because of bad inputs. Later on we can implement more interesting evaluators like outliers, changepoint detection, etc. if needed. Differential Revision: D58564199
- Loading branch information
1 parent
2ef31b8
commit c382602
Showing
3 changed files
with
98 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
import math | ||
import unittest | ||
|
||
from torchtnt.utils.anomaly_evaluation import IsNaNEvaluator, ThresholdEvaluator | ||
|
||
|
||
class TestAnomalyLogger(unittest.TestCase): | ||
|
||
def test_threshold(self) -> None: | ||
threshold = ThresholdEvaluator(min_val=0.5, max_val=0.9) | ||
self.assertFalse(threshold.is_anomaly()) | ||
|
||
threshold.update(0.4) | ||
self.assertTrue(threshold.is_anomaly()) | ||
|
||
threshold.update(0.6) | ||
self.assertFalse(threshold.is_anomaly()) | ||
|
||
threshold.update(0.95) | ||
self.assertTrue(threshold.is_anomaly()) | ||
|
||
threshold = ThresholdEvaluator(max_val=1) | ||
|
||
threshold.update(100.0) | ||
self.assertTrue(threshold.is_anomaly()) | ||
|
||
threshold.update(-500.0) | ||
self.assertFalse(threshold.is_anomaly()) | ||
|
||
def test_isnan(self) -> None: | ||
isnan = IsNaNEvaluator() | ||
self.assertFalse(isnan.is_anomaly()) | ||
|
||
isnan.update(0.4) | ||
self.assertFalse(isnan.is_anomaly()) | ||
|
||
isnan.update(math.nan) | ||
self.assertTrue(isnan.is_anomaly()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters