Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SilhouetteVisualizer add support for more estimators #1294

Merged
merged 10 commits into from
Jul 5, 2023
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
50 changes: 31 additions & 19 deletions tests/test_cluster/test_silhouette.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
import sys
import pytest
import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.cluster import SpectralClustering, AgglomerativeClustering

from unittest import mock
from tests.base import VisualTestCase

from yellowbrick.datasets import load_nfl
from yellowbrick.cluster.silhouette import SilhouetteVisualizer, silhouette_visualizer


Expand All @@ -53,7 +54,6 @@ def test_integrated_kmeans_silhouette(self):
n_samples=1000, n_features=12, centers=8, shuffle=False, random_state=0
)


fig = plt.figure()
ax = fig.add_subplot()

Expand All @@ -62,7 +62,6 @@ def test_integrated_kmeans_silhouette(self):
visualizer.finalize()

self.assert_images_similar(visualizer, remove_legend=True)


@pytest.mark.xfail(sys.platform == "win32", reason="images not close on windows")
def test_integrated_mini_batch_kmeans_silhouette(self):
Expand All @@ -84,7 +83,6 @@ def test_integrated_mini_batch_kmeans_silhouette(self):
visualizer.finalize()

self.assert_images_similar(visualizer, remove_legend=True)


@pytest.mark.skip(reason="no negative silhouette example available yet")
def test_negative_silhouette_score(self):
Expand All @@ -103,7 +101,6 @@ def test_colormap_silhouette(self):
n_samples=1000, n_features=12, centers=8, shuffle=False, random_state=0
)


fig = plt.figure()
ax = fig.add_subplot()

Expand Down Expand Up @@ -138,7 +135,7 @@ def test_colors_silhouette(self):
visualizer.finalize()

self.assert_images_similar(visualizer, remove_legend=True)

def test_colormap_as_colors_silhouette(self):
"""
Test no exceptions for modifying the colors in a silhouette visualizer
Expand All @@ -162,7 +159,7 @@ def test_colormap_as_colors_silhouette(self):
3.2 if sys.platform == "win32" else 0.01
) # Fails on AppVeyor with RMS 3.143
self.assert_images_similar(visualizer, remove_legend=True, tol=tol)

def test_quick_method(self):
"""
Test the quick method producing a valid visualization
Expand All @@ -177,29 +174,44 @@ def test_quick_method(self):

self.assert_images_similar(oz)

@pytest.mark.xfail(
reason="""third test fails with AssertionError: Expected fit
to be called once. Called 0 times."""
)
def test_with_fitted(self):
"""
Test that visualizer properly handles an already-fitted model
"""
X, y = load_nfl(return_dataset=True).to_numpy()

model = MiniBatchKMeans().fit(X, y)
X, y = make_blobs(
n_samples=100, n_features=5, centers=3, shuffle=False, random_state=112
)
model = MiniBatchKMeans().fit(X)
labels = model.predict(X)

with mock.patch.object(model, "fit") as mockfit:
oz = SilhouetteVisualizer(model)
oz.fit(X, y)
oz.fit(X)
mockfit.assert_not_called()

with mock.patch.object(model, "fit") as mockfit:
oz = SilhouetteVisualizer(model, is_fitted=True)
oz.fit(X, y)
oz.fit(X)
mockfit.assert_not_called()

with mock.patch.object(model, "fit") as mockfit:
with mock.patch.object(model, "fit_predict", return_value=labels) as mockfit:
oz = SilhouetteVisualizer(model, is_fitted=False)
oz.fit(X, y)
mockfit.assert_called_once_with(X, y)
oz.fit(X)
mockfit.assert_called_once_with(X, None)

@pytest.mark.parametrize(
"model",
[SpectralClustering, AgglomerativeClustering],
)
def test_clusterer_without_predict(self, model):
"""
Assert that clustering estimators that don't implement
a predict() method utilize fit_predict()
"""
X = np.array([[1, 2], [1, 4], [1, 0], [4, 2], [4, 4], [4, 0]])
try:
visualizer = SilhouetteVisualizer(model(n_clusters=2))
visualizer.fit(X)
visualizer.finalize()
except AttributeError:
self.fail("could not use fit or fit_predict methods")
93 changes: 82 additions & 11 deletions yellowbrick/cluster/silhouette.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,35 @@

from sklearn.metrics import silhouette_score, silhouette_samples

try:
from sklearn.metrics.pairwise import _VALID_METRICS
except ImportError:
_VALID_METRICS = [

Check warning on line 29 in yellowbrick/cluster/silhouette.py

View check run for this annotation

Codecov / codecov/patch

yellowbrick/cluster/silhouette.py#L28-L29

Added lines #L28 - L29 were not covered by tests
"cityblock",
"cosine",
"euclidean",
"l1",
"l2",
"manhattan",
"braycurtis",
"canberra",
"chebyshev",
"correlation",
"dice",
"hamming",
"jaccard",
"kulsinski",
"mahalanobis",
"minkowski",
"rogerstanimoto",
"russellrao",
"seuclidean",
"sokalmichener",
"sokalsneath",
"sqeuclidean",
"yule",
]

from yellowbrick.utils import check_fitted
from yellowbrick.style import resolve_colors
from yellowbrick.cluster.base import ClusteringScoreVisualizer
Expand Down Expand Up @@ -113,7 +142,6 @@
"""

def __init__(self, estimator, ax=None, colors=None, is_fitted="auto", **kwargs):

# Initialize the visualizer bases
super(SilhouetteVisualizer, self).__init__(
estimator, ax=ax, is_fitted=is_fitted, **kwargs
Expand All @@ -130,23 +158,47 @@
def fit(self, X, y=None, **kwargs):
"""
Fits the model and generates the silhouette visualization.

Unlike other visualizers that use the score() method to draw the results, this
visualizer errs on visualizing on fit since this is when the clusters are
computed. This means that a predict call is required in fit (or a fit_predict)
in order to produce the visualization.
"""
# TODO: decide to use this method or the score method to draw.
# NOTE: Probably this would be better in score, but the standard score
# is a little different and I'm not sure how it's used.

# If the estimator is not fitted, fit it; then call predict to get the labels
# for computing the silhoutte score on. If the estimator is already fitted, then
# attempt to predict the labels, but if the estimator is stateless, fit and
# predict on the data specified. At the end of this block, no matter the fitted
# state of the estimator and the method, we should have cluster labels for X.
if not check_fitted(self.estimator, is_fitted_by=self.is_fitted):
# Fit the wrapped estimator
self.estimator.fit(X, y, **kwargs)
if hasattr(self.estimator, "fit_predict"):
labels = self.estimator.fit_predict(X, y, **kwargs)
else:
self.estimator.fit(X, y, **kwargs)
labels = self.estimator.predict(X)

Check warning on line 178 in yellowbrick/cluster/silhouette.py

View check run for this annotation

Codecov / codecov/patch

yellowbrick/cluster/silhouette.py#L177-L178

Added lines #L177 - L178 were not covered by tests
else:
if hasattr(self.estimator, "predict"):
labels = self.estimator.predict(X)
else:
labels = self.estimator.fit_predict(X, y, **kwargs)

Check warning on line 183 in yellowbrick/cluster/silhouette.py

View check run for this annotation

Codecov / codecov/patch

yellowbrick/cluster/silhouette.py#L183

Added line #L183 was not covered by tests

# Get the properties of the dataset
self.n_samples_ = X.shape[0]
self.n_clusters_ = self.estimator.n_clusters

# Compute the number of available clusters from the estimator
if hasattr(self.estimator, "n_clusters"):
self.n_clusters_ = self.estimator.n_clusters
else:
unique_labels = set(labels)
n_noise_clusters = 1 if -1 in unique_labels else 0
self.n_clusters_ = len(unique_labels) - n_noise_clusters

Check warning on line 194 in yellowbrick/cluster/silhouette.py

View check run for this annotation

Codecov / codecov/patch

yellowbrick/cluster/silhouette.py#L192-L194

Added lines #L192 - L194 were not covered by tests

# Identify the distance metric to use for silhouette scoring
metric = self._identify_silhouette_metric()

# Compute the scores of the cluster
labels = self.estimator.predict(X)
self.silhouette_score_ = silhouette_score(X, labels)
self.silhouette_samples_ = silhouette_samples(X, labels)
self.silhouette_score_ = silhouette_score(X, labels, metric=metric)
self.silhouette_samples_ = silhouette_samples(X, labels, metric=metric)

# Draw the silhouette figure
self.draw(labels)
Expand Down Expand Up @@ -185,7 +237,6 @@
# For each cluster, plot the silhouette scores
self.y_tick_pos_ = []
for idx in range(self.n_clusters_):

# Collect silhouette scores for samples in the current cluster .
values = self.silhouette_samples_[labels == idx]
values.sort()
Expand Down Expand Up @@ -260,6 +311,26 @@
# Show legend (Average Silhouette Score axis)
self.ax.legend(loc="best")

def _identify_silhouette_metric(self):
"""
The Silhouette metric must be one of the distance options allowed by
metrics.pairwise.pairwise_distances or a callable. This method attempts to
discover a valid distance metric from the underlying estimator or returns
"euclidean" by default.
"""
if hasattr(self.estimator, "metric"):
if callable(self.estimator.metric):
return self.estimator.metric

Check warning on line 323 in yellowbrick/cluster/silhouette.py

View check run for this annotation

Codecov / codecov/patch

yellowbrick/cluster/silhouette.py#L322-L323

Added lines #L322 - L323 were not covered by tests

if self.estimator.metric in _VALID_METRICS:
return self.estimator.metric

Check warning on line 326 in yellowbrick/cluster/silhouette.py

View check run for this annotation

Codecov / codecov/patch

yellowbrick/cluster/silhouette.py#L325-L326

Added lines #L325 - L326 were not covered by tests

if hasattr(self.estimator, "affinity"):
if self.estimator.affinity in _VALID_METRICS:
return self.estimator.affinity

return "euclidean"


##########################################################################
## Quick Method
Expand Down
Loading