diff --git a/tests/test_cluster/test_silhouette.py b/tests/test_cluster/test_silhouette.py index 6f6615857..33643a967 100644 --- a/tests/test_cluster/test_silhouette.py +++ b/tests/test_cluster/test_silhouette.py @@ -20,9 +20,11 @@ import sys import pytest import matplotlib.pyplot as plt +import numpy as np from sklearn.datasets import make_blobs from sklearn.cluster import KMeans, MiniBatchKMeans +from sklearn.cluster import SpectralClustering, AgglomerativeClustering from unittest import mock from tests.base import VisualTestCase @@ -203,3 +205,22 @@ def test_with_fitted(self): oz = SilhouetteVisualizer(model, is_fitted=False) oz.fit(X, y) mockfit.assert_called_once_with(X, y) + + + @pytest.mark.parametrize( + "model", + [SpectralClustering, AgglomerativeClustering], + ) + def test_clusterer_without_predict(self, model): + """ + Assert that clustering estimators that don't implement + a predict() method utilize fit_predict() + """ + X = np.array([[1, 2], [1, 4], [1, 0], +... [4, 2], [4, 4], [4, 0]]) + try: + visualizer = SilhouetteVisualizer(model(n_clusters=2)) + visualizer.fit(X) + visualizer.finalize() + except AttributeError: + self.fail("could not use fit or fit_predict methods") diff --git a/yellowbrick/cluster/silhouette.py b/yellowbrick/cluster/silhouette.py index 3ad76645c..4103a6c81 100644 --- a/yellowbrick/cluster/silhouette.py +++ b/yellowbrick/cluster/silhouette.py @@ -151,7 +151,7 @@ def fit(self, X, y=None, **kwargs): if hasattr(self.estimator, "metric"): metric = self.estimator.metric - elif hasattr(self.estimator, "affinity"): + elif hasattr(self.estimator, "affinity") and self.estimator.__class__.__name__ != 'SpectralClustering': metric = self.estimator.affinity else: metric = "euclidean"