Skip to content

Commit

Permalink
Removed Geomstats dependancy from KARCH
Browse files Browse the repository at this point in the history
  • Loading branch information
KulikDM committed Aug 9, 2024
1 parent a7e23cd commit 278c2ec
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 25 deletions.
1 change: 0 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ Or with **pip**:

- pyclustering (used in the CLUST thresholder)
- ruptures (used in the CPD thresholder)
- geomstats (used in the KARCH thresholder)
- scikit-lego (used in the META thresholder)
- joblib>=0.14.1 (used in the META thresholder and RANK)
- pandas (used in the META thresholder)
Expand Down
1 change: 0 additions & 1 deletion docs/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ Or with **pip**:

- pyclustering (used in the CLUST thresholder)
- ruptures (used in the CPD thresholder)
- geomstats (used in the KARCH thresholder)
- scikit-lego (used in the META thresholder)
- joblib>=0.14.1 (used in the META thresholder and RANK)
- pandas (used in the META thresholder)
Expand Down
1 change: 0 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

docutils
geomstats
joblib
matplotlib
numpy
Expand Down
38 changes: 18 additions & 20 deletions pythresh/thresholds/karch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import inspect

import numpy as np
from geomstats.geometry.euclidean import Euclidean
from geomstats.learning.frechet_mean import FrechetMean

from .base import BaseThresholder
from .thresh_utility import check_scores, cut, gen_kde, normalize
Expand Down Expand Up @@ -86,18 +82,6 @@ def eval(self, decision):

self.dscores_ = decision

# Create euclidean manifold and find Karcher mean
manifold = Euclidean(dim=self.ndim)

arg_map = {'old': {'metric': manifold.metric},
'new': {'space': manifold}}

arg_dict = (arg_map['new'] if 'space' in
inspect.signature(FrechetMean).parameters
else arg_map['old'])

estimator = FrechetMean(**arg_dict)

if self.method == 'complex':

# Create kde of scores
Expand All @@ -108,16 +92,30 @@ def eval(self, decision):
try:
# find kde and score dot product and solve the
vals = np.dot(val_data, val_norm)
estimator.fit(vals)
fmean = self._frechet_mean(vals)

except ValueError:
estimator.fit(decision.reshape(1, -1))
fmean = self._frechet_mean(decision.reshape(1, -1))
else:
estimator.fit(decision.reshape(1, -1))
fmean = self._frechet_mean(decision.reshape(1, -1))

# Get the mean of each dimension's Karcher mean
limit = np.mean(estimator.estimate_) + np.std(decision)
limit = np.mean(fmean) + np.std(decision)

self.thresh_ = limit

return cut(decision, limit)

# Adapted from https://github.com/geomstats/geomstats/blob/main/geomstats/learning/frechet_mean.py
def _frechet_mean(self, points, weights=None):
"""Compute the Frechet mean in a Euclidean space."""
if weights is None:
n_points = np.shape(points)[0]
weights = np.ones(n_points)

sum_weights = np.sum(weights)

weighted_points = np.einsum('n,n...->n...', weights, points)

mean = np.sum(weighted_points, axis=0) / sum_weights
return mean
2 changes: 0 additions & 2 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#geomstats
https://github.com/geomstats/geomstats/archive/main.zip
#pyclustering
https://github.com/KulikDM/pyclustering/archive/Warning-Fix.zip
joblib>=0.14.1
Expand Down

0 comments on commit 278c2ec

Please sign in to comment.