diff --git a/docs/benchmark.rst b/docs/benchmark.rst index 87e2036..760d600 100644 --- a/docs/benchmark.rst +++ b/docs/benchmark.rst @@ -469,17 +469,17 @@ scale slightly differently depending on the hardware used. +---------------+--------------------+------------------------+ | CPD | Quadratic | ~1e-8*n^2 | +---------------+--------------------+------------------------+ -| DECOMP | Quadratic | ~1e-8*n^2 | +| DECOMP | Linear | ~1e-4*n | +---------------+--------------------+------------------------+ -| DSN | Quadratic | ~1e-8*n^2 | +| DSN | Linear | ~1e-4*n | +---------------+--------------------+------------------------+ -| EB | Linearithmic | ~1-06*n*log(n) | +| EB | Linearithmic | ~1e-6*n*log(n) | +---------------+--------------------+------------------------+ -| FGD | Quadratic | ~1e-8*n^2 | +| FGD | Linearithmic | ~1e-5*n*log(n) | +---------------+--------------------+------------------------+ | FILTER | Quadratic | ~1e-11*n^2 | +---------------+--------------------+------------------------+ -| FWFM | Quadratic | ~1e-8*n^2 | +| FWFM | Linearithmic | ~1e-5*n*log(n) | +---------------+--------------------+------------------------+ | GAMGMM | Quadratic | ~1e-6*n^2 | +---------------+--------------------+------------------------+ @@ -489,7 +489,7 @@ scale slightly differently depending on the hardware used. +---------------+--------------------+------------------------+ | IQR | Linear | ~1e-8*n | +---------------+--------------------+------------------------+ -| KARCH | Quadratic | ~1e-8*n^2 | +| KARCH | Linearithmic | ~1e-5*n*log(n) | +---------------+--------------------+------------------------+ | MAD | Linear | ~1e-8*n | +---------------+--------------------+------------------------+ @@ -497,9 +497,9 @@ scale slightly differently depending on the hardware used. +---------------+--------------------+------------------------+ | META | Cubic | ~1e-12*n^3 | +---------------+--------------------+------------------------+ -| MIXMOD | Linear | ~1e-4*n | +| MIXMOD | Linear | ~1e-3*n | +---------------+--------------------+------------------------+ -| MOLL | Quadratic | ~1e-10*n^2 | +| MOLL | Linearithmic | ~1e-7*n*log(n) | +---------------+--------------------+------------------------+ | MTT | Quadratic | ~1e-10*n^2 | +---------------+--------------------+------------------------+ @@ -511,9 +511,9 @@ scale slightly differently depending on the hardware used. +---------------+--------------------+------------------------+ | VAE | Linear | ~1e-3*n | +---------------+--------------------+------------------------+ -| WIND | Quadratic | ~1e-8*n^2 | +| WIND | Linear | ~1e-4*n | +---------------+--------------------+------------------------+ -| YJ | Quadratic | ~1e-8*n^2 | +| YJ | Linear | ~1e-4*n | +---------------+--------------------+------------------------+ | ZSCORE | Linear | ~1e-8*n | +---------------+--------------------+------------------------+ diff --git a/pythresh/test/test_fastkde.py b/pythresh/test/test_fastkde.py new file mode 100644 index 0000000..f98c8e2 --- /dev/null +++ b/pythresh/test/test_fastkde.py @@ -0,0 +1,68 @@ +import sys +import unittest +from itertools import product +from os.path import dirname as up + +# noinspection PyProtectedMember +import numpy as np +from numpy.testing import assert_equal +from pyod.models.iforest import IForest +from pyod.models.knn import KNN +from pyod.models.pca import PCA +from pyod.utils.data import generate_data + +from pythresh.thresholds.dsn import DSN + +# temporary solution for relative imports in case pythresh is not installed +# if pythresh is installed, no need to use the following line + +path = up(up(up(__file__))) +sys.path.append(path) + +# Test implementation of the fastkde interpolation method + + +class TestFastKDE(unittest.TestCase): + def setUp(self): + self.n_train = 10000 + self.n_test = 100 + self.contamination = 0.1 + self.X_train, self.X_test, self.y_train, self.y_test = generate_data( + n_train=self.n_train, n_test=self.n_test, + contamination=self.contamination, random_state=42) + + clf = KNN() + clf.fit(self.X_train) + + scores = clf.decision_scores_ + + clfs = [KNN(), PCA(), IForest()] + + multiple_scores = [ + clf.fit(self.X_train).decision_scores_ for clf in clfs] + multiple_scores = np.vstack(multiple_scores).T + + self.all_scores = [scores, multiple_scores] + + self.metrics = ['JS', 'MAH'] + + def test_prediction_labels(self): + + params = product(self.all_scores, self.metrics) + + for scores, metric in params: + + self.thres = DSN(metric=metric) + pred_labels = self.thres.eval(scores) + assert (self.thres.thresh_ is not None) + assert (self.thres.dscores_ is not None) + + assert (self.thres.dscores_.min() == 0) + assert (self.thres.dscores_.max() == 1) + + assert_equal(pred_labels.shape, self.y_train.shape) + + if (not np.all(pred_labels == 0)) & (not np.all(pred_labels == 1)): + + assert (pred_labels.min() == 0) + assert (pred_labels.max() == 1) diff --git a/pythresh/thresholds/fwfm.py b/pythresh/thresholds/fwfm.py index 00a84b1..6f92521 100644 --- a/pythresh/thresholds/fwfm.py +++ b/pythresh/thresholds/fwfm.py @@ -75,7 +75,7 @@ def eval(self, decision): base_width = peak_widths(val, peaks, rel_height=0.99)[0] # Normalize and set limit - limit = base_width/len(val) if len(base_width) > 0 else 1.1 + limit = base_width[0]/len(val) if len(base_width) > 0 else 1.1 self.thresh_ = limit