Skip to content

Commit

Permalink
Added test to verify that cluster estimator without a predict() metho…
Browse files Browse the repository at this point in the history
…d are implementing fit_predict(). Added condition to make sure Spectral Clustering metric is not being set to
  • Loading branch information
lwgray committed Jun 30, 2023
1 parent 6be136a commit 36b6e8d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
21 changes: 21 additions & 0 deletions tests/test_cluster/test_silhouette.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
import sys
import pytest
import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.cluster import SpectralClustering, AgglomerativeClustering

from unittest import mock
from tests.base import VisualTestCase
Expand Down Expand Up @@ -203,3 +205,22 @@ def test_with_fitted(self):
oz = SilhouetteVisualizer(model, is_fitted=False)
oz.fit(X, y)
mockfit.assert_called_once_with(X, y)


@pytest.mark.parametrize(
"model",
[SpectralClustering, AgglomerativeClustering],
)
def test_clusterer_without_predict(self, model):
"""
Assert that clustering estimators that don't implement
a predict() method utilize fit_predict()
"""
X = np.array([[1, 2], [1, 4], [1, 0],
... [4, 2], [4, 4], [4, 0]])
try:
visualizer = SilhouetteVisualizer(model(n_clusters=2))
visualizer.fit(X)
visualizer.finalize()
except AttributeError:
self.fail("could not use fit or fit_predict methods")
2 changes: 1 addition & 1 deletion yellowbrick/cluster/silhouette.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def fit(self, X, y=None, **kwargs):

if hasattr(self.estimator, "metric"):
metric = self.estimator.metric
elif hasattr(self.estimator, "affinity"):
elif hasattr(self.estimator, "affinity") and self.estimator.__class__.__name__ != 'SpectralClustering':
metric = self.estimator.affinity
else:
metric = "euclidean"
Expand Down

0 comments on commit 36b6e8d

Please sign in to comment.