-
-
Notifications
You must be signed in to change notification settings - Fork 557
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SilhouetteVisualizer
add support for more estimators
#1294
Changes from 3 commits
7b485b6
b1b2e9f
e67f05a
e6ce343
6be136a
36b6e8d
1150cc3
c7490d5
f111a3f
e50a829
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -135,18 +135,30 @@ def fit(self, X, y=None, **kwargs): | |
# NOTE: Probably this would be better in score, but the standard score | ||
# is a little different and I'm not sure how it's used. | ||
|
||
if not check_fitted(self.estimator, is_fitted_by=self.is_fitted): | ||
# Fit the wrapped estimator | ||
self.estimator.fit(X, y, **kwargs) | ||
# if estimator is fitted AND has attribute | ||
if check_fitted(self.estimator, is_fitted_by=self.is_fitted) and hasattr(self.estimator, "predict"): | ||
labels = self.estimator.predict(X) | ||
else: # if estimator is NOT fitted, OR estimator does NOT implement predict() | ||
labels = self.estimator.fit_predict(X, y, **kwargs) | ||
|
||
# Get the properties of the dataset | ||
self.n_samples_ = X.shape[0] | ||
self.n_clusters_ = self.estimator.n_clusters | ||
|
||
if hasattr(self.estimator, "n_clusters"): | ||
self.n_clusters_ = self.estimator.n_clusters | ||
else: | ||
self.n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0) | ||
bbengfort marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if hasattr(self.estimator, "metric"): | ||
metric = self.estimator.metric | ||
elif hasattr(self.estimator, "affinity"): | ||
metric = self.estimator.affinity | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @lwgray this is where the error is occurring for SpectralClustering - @stergion what model prompted you to add this metric selector? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it was for the AffinityPropagation, AgglomerativeClustering, FeatureAgglomeration. Since sklearn version 1.2, AffinityPropagation was not updated, it still uses |
||
else: | ||
metric = "euclidean" | ||
|
||
# Compute the scores of the cluster | ||
labels = self.estimator.predict(X) | ||
self.silhouette_score_ = silhouette_score(X, labels) | ||
self.silhouette_samples_ = silhouette_samples(X, labels) | ||
self.silhouette_score_ = silhouette_score(X, labels, metric=metric) | ||
self.silhouette_samples_ = silhouette_samples(X, labels, metric=metric) | ||
|
||
# Draw the silhouette figure | ||
self.draw(labels) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great 👍 way to cover fit_predict here.