Skip to content

Commit

Permalink
perf: imporve meanshift-realted code and the common functions 'cluste…
Browse files Browse the repository at this point in the history
…r center' and 'cluster label'.
  • Loading branch information
SanyHe committed Aug 31, 2024
1 parent 742c2ee commit 14c5d15
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 75 deletions.
7 changes: 3 additions & 4 deletions geochemistrypi/data_mining/model/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,11 +377,10 @@ class ClusteringMetricsMixin:
"""Mixin class for clustering metrics."""

@staticmethod
def _get_num_clusters(func_name: str, algorithm_name: str, trained_model: object, store_path: str) -> None:
"""Get and log the number of clusters."""
labels = trained_model.labels_
num_clusters = len(np.unique(labels))
def _get_num_clusters(labels: pd.Series, func_name: str, algorithm_name: str, store_path: str) -> None:
"""Get and log the number of clusters. It is only used in those algorithms which don't allow to set the number of cluster in advance."""
print(f"-----* {func_name} *-----")
num_clusters = len(np.unique(labels.to_numpy()))
print(f"{func_name}: {num_clusters}")
num_clusters_dict = {f"{func_name}": num_clusters}
mlflow.log_metrics(num_clusters_dict)
Expand Down
104 changes: 59 additions & 45 deletions geochemistrypi/data_mining/model/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class ClusteringWorkflowBase(WorkflowBase):

def __init__(self):
super().__init__()
self.clustering_result = None
self.cluster_labels = None
self.cluster_centers = None
self.mode = "Clustering"

def fit(self, X: pd.DataFrame, y: Optional[pd.DataFrame] = None) -> None:
Expand All @@ -43,33 +44,43 @@ def manual_hyper_parameters(cls) -> Dict:
"""Manual hyper-parameters specification."""
return dict()

# TODO(Samson [email protected]): This function might need to be rethought.
def get_cluster_centers(self) -> np.ndarray:
@staticmethod
def _get_cluster_centers(func_name: str, trained_model: object, algorithm_name: str, local_path: str, mlflow_path: str) -> Optional[pd.DataFrame]:
"""Get the cluster centers."""
print("-----* Clustering Centers *-----")
print(getattr(self.model, "cluster_centers_", "This class don not have cluster_centers_"))
return getattr(self.model, "cluster_centers_", "This class don not have cluster_centers_")
print(f"-----* {func_name} *-----")
cluster_centers = getattr(trained_model, "cluster_centers_", None)
if cluster_centers is None:
print("This algorithm does not provide cluster centers")
else:
column_name = []
for i in range(cluster_centers.shape[1]):
column_name.append(f"Dimension {i+1}")
print(cluster_centers)

def get_labels(self):
cluster_centers = pd.DataFrame(cluster_centers, columns=column_name)
save_data(cluster_centers, f"{func_name} - {algorithm_name}", local_path, mlflow_path)
return cluster_centers

@staticmethod
def _get_cluster_labels(func_name: str, trained_model: object, algorithm_name: str, local_path: str, mlflow_path: str) -> pd.DataFrame:
"""Get the cluster labels."""
print("-----* Clustering Labels *-----")
# self.X['clustering result'] = self.model.labels_
self.clustering_result = pd.DataFrame(self.model.labels_, columns=["clustering result"])
print(self.clustering_result)
GEOPI_OUTPUT_ARTIFACTS_DATA_PATH = os.getenv("GEOPI_OUTPUT_ARTIFACTS_DATA_PATH")
save_data(self.clustering_result, f"{self.naming} Result", GEOPI_OUTPUT_ARTIFACTS_DATA_PATH, MLFLOW_ARTIFACT_DATA_PATH)
print(f"-----* {func_name} *-----")
cluster_label = pd.DataFrame(trained_model.labels_, columns=[func_name])
print(cluster_label)
save_data(cluster_label, f"{func_name} - {algorithm_name}", local_path, mlflow_path)
return cluster_label

@staticmethod
def _score(data: pd.DataFrame, labels: pd.DataFrame, func_name: str, algorithm_name: str, store_path: str) -> None:
def _score(data: pd.DataFrame, labels: pd.Series, func_name: str, algorithm_name: str, store_path: str) -> None:
"""Calculate the score of the model."""
print(f"-----* {func_name} *-----")
scores = score(data, labels)
scores_str = json.dumps(scores, indent=4)
save_text(scores_str, f"{func_name}- {algorithm_name}", store_path)
save_text(scores_str, f"{func_name} - {algorithm_name}", store_path)
mlflow.log_metrics(scores)

@staticmethod
def _scatter2d(data: pd.DataFrame, labels: pd.DataFrame, cluster_centers_: np.ndarray, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
def _scatter2d(data: pd.DataFrame, labels: pd.Series, cluster_centers_: pd.DataFrame, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
"""Plot the two-dimensional diagram of the clustering result."""
print("-----* Cluster Two-Dimensional Diagram *-----")
scatter2d(data, labels, cluster_centers_, algorithm_name)
Expand All @@ -78,7 +89,7 @@ def _scatter2d(data: pd.DataFrame, labels: pd.DataFrame, cluster_centers_: np.nd
save_data(data_with_labels, f"Cluster Two-Dimensional Diagram - {algorithm_name}", local_path, mlflow_path)

@staticmethod
def _scatter3d(data: pd.DataFrame, labels: pd.DataFrame, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
def _scatter3d(data: pd.DataFrame, labels: pd.Series, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
"""Plot the three-dimensional diagram of the clustering result."""
print("-----* Cluster Three-Dimensional Diagram *-----")
scatter3d(data, labels, algorithm_name)
Expand All @@ -87,7 +98,7 @@ def _scatter3d(data: pd.DataFrame, labels: pd.DataFrame, algorithm_name: str, lo
save_data(data_with_labels, f"Cluster Two-Dimensional Diagram - {algorithm_name}", local_path, mlflow_path)

@staticmethod
def _plot_silhouette_diagram(data: pd.DataFrame, labels: pd.DataFrame, model: object, cluster_centers_: np.ndarray, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
def _plot_silhouette_diagram(data: pd.DataFrame, labels: pd.Series, model: object, cluster_centers_: np.ndarray, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
"""Plot the silhouette diagram of the clustering result."""
print("-----* Silhouette Diagram *-----")
plot_silhouette_diagram(data, labels, cluster_centers_, model, algorithm_name)
Expand All @@ -99,7 +110,7 @@ def _plot_silhouette_diagram(data: pd.DataFrame, labels: pd.DataFrame, model: ob
save_data(cluster_center_data, "Silhouette Diagram - Cluster Centers", local_path, mlflow_path)

@staticmethod
def _plot_silhouette_value_diagram(data: pd.DataFrame, labels: pd.DataFrame, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
def _plot_silhouette_value_diagram(data: pd.DataFrame, labels: pd.Series, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
"""Plot the silhouette value diagram of the clustering result."""
print("-----* Silhouette value Diagram *-----")
plot_silhouette_value_diagram(data, labels, algorithm_name)
Expand All @@ -111,9 +122,24 @@ def common_components(self) -> None:
"""Invoke all common application functions for clustering algorithms."""
GEOPI_OUTPUT_METRICS_PATH = os.getenv("GEOPI_OUTPUT_METRICS_PATH")
GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH = os.getenv("GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH")
GEOPI_OUTPUT_ARTIFACTS_DATA_PATH = os.getenv("GEOPI_OUTPUT_ARTIFACTS_DATA_PATH")
self.cluster_centers = self._get_cluster_centers(
func_name=ClusteringCommonFunction.CLUSTER_CENTERS.value,
trained_model=self.model,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_DATA_PATH,
mlflow_path=MLFLOW_ARTIFACT_DATA_PATH,
)
self.cluster_labels = self._get_cluster_labels(
func_name=ClusteringCommonFunction.CLUSTER_LABELS.value,
trained_model=self.model,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_DATA_PATH,
mlflow_path=MLFLOW_ARTIFACT_DATA_PATH,
)
self._score(
data=self.X,
labels=self.clustering_result["clustering result"],
labels=self.cluster_labels[ClusteringCommonFunction.CLUSTER_LABELS.value],
func_name=ClusteringCommonFunction.MODEL_SCORE.value,
algorithm_name=self.naming,
store_path=GEOPI_OUTPUT_METRICS_PATH,
Expand All @@ -123,8 +149,8 @@ def common_components(self) -> None:
two_dimen_axis_index, two_dimen_data = self.choose_dimension_data(self.X, 2)
self._scatter2d(
data=two_dimen_data,
labels=self.clustering_result["clustering result"],
cluster_centers_=self.get_cluster_centers(),
labels=self.cluster_labels[ClusteringCommonFunction.CLUSTER_LABELS.value],
cluster_centers_=self.cluster_centers,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
Expand All @@ -134,7 +160,7 @@ def common_components(self) -> None:
three_dimen_axis_index, three_dimen_data = self.choose_dimension_data(self.X, 3)
self._scatter3d(
data=three_dimen_data,
labels=self.clustering_result["clustering result"],
labels=self.cluster_labels[ClusteringCommonFunction.CLUSTER_LABELS.value],
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
Expand All @@ -144,8 +170,8 @@ def common_components(self) -> None:
two_dimen_axis_index, two_dimen_data = self.choose_dimension_data(self.X, 2)
self._scatter2d(
data=two_dimen_data,
labels=self.clustering_result["clustering result"],
cluster_centers_=self.get_cluster_centers(),
labels=self.cluster_labels[ClusteringCommonFunction.CLUSTER_LABELS.value],
cluster_centers_=self.cluster_centers,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
Expand All @@ -154,16 +180,16 @@ def common_components(self) -> None:
# no need to choose
self._scatter3d(
data=self.X,
labels=self.clustering_result["clustering result"],
labels=self.cluster_labels[ClusteringCommonFunction.CLUSTER_LABELS.value],
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
elif self.X.shape[1] == 2:
self._scatter2d(
data=self.X,
labels=self.clustering_result["clustering result"],
cluster_centers_=self.get_cluster_centers(),
labels=self.cluster_labels[ClusteringCommonFunction.CLUSTER_LABELS.value],
cluster_centers_=self.cluster_centers,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
Expand All @@ -173,16 +199,16 @@ def common_components(self) -> None:

self._plot_silhouette_diagram(
data=self.X,
labels=self.clustering_result["clustering result"],
cluster_centers_=self.get_cluster_centers(),
labels=self.cluster_labels[ClusteringCommonFunction.CLUSTER_LABELS.value],
cluster_centers_=self.cluster_centers,
model=self.model,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
self._plot_silhouette_value_diagram(
data=self.X,
labels=self.clustering_result["clustering result"],
labels=self.cluster_labels[ClusteringCommonFunction.CLUSTER_LABELS.value],
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
Expand Down Expand Up @@ -429,9 +455,9 @@ def special_components(self, **kwargs: Union[Dict, np.ndarray, int]) -> None:
"""Invoke all special application functions for this algorithm by Scikit-learn framework."""
GEOPI_OUTPUT_METRICS_PATH = os.getenv("GEOPI_OUTPUT_METRICS_PATH")
self._get_num_clusters(
labels=self.cluster_labels[ClusteringCommonFunction.CLUSTER_LABELS.value],
func_name=MeanShiftSpecialFunction.NUM_CLUSTERS.value,
algorithm_name=self.naming,
trained_model=self.model,
store_path=GEOPI_OUTPUT_METRICS_PATH,
)

Expand Down Expand Up @@ -767,24 +793,12 @@ def special_components(self, **kwargs: Union[Dict, np.ndarray, int]) -> None:
"""Invoke all special application functions for this algorithm by Scikit-learn framework."""
GEOPI_OUTPUT_METRICS_PATH = os.getenv("GEOPI_OUTPUT_METRICS_PATH")
self._get_num_clusters(
labels=self.cluster_labels[ClusteringCommonFunction.CLUSTER_LABELS.value],
func_name=MeanShiftSpecialFunction.NUM_CLUSTERS.value,
algorithm_name=self.naming,
trained_model=self.model,
store_path=GEOPI_OUTPUT_METRICS_PATH,
)

@staticmethod
def _get_num_clusters(func_name: str, algorithm_name: str, trained_model: object, store_path: str) -> None:
"""Get and log the number of clusters."""
labels = trained_model.labels_
num_clusters = len(np.unique(labels))
print(f"-----* {func_name} *-----")
print(f"{func_name}: {num_clusters}")
num_clusters_dict = {f"{func_name}": num_clusters}
mlflow.log_metrics(num_clusters_dict)
num_clusters_str = json.dumps(num_clusters_dict, indent=4)
save_text(num_clusters_str, f"{func_name} - {algorithm_name}", store_path)


class SpectralClustering(ClusteringWorkflowBase):
name = "Spectral"
Expand Down
38 changes: 19 additions & 19 deletions geochemistrypi/data_mining/model/func/algo_clustering/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
from sklearn.metrics import calinski_harabasz_score, silhouette_samples, silhouette_score


def score(data: pd.DataFrame, labels: pd.DataFrame) -> Dict:
def score(data: pd.DataFrame, labels: pd.Series) -> Dict:
"""Calculate the scores of the clustering model.
Parameters
----------
data : pd.DataFrame (n_samples, n_components)
The true values.
labels : pd.DataFrame (n_samples, n_components)
labels : pd.Series (n_samples, )
Labels of each point.
Returns
Expand All @@ -38,7 +38,7 @@ def score(data: pd.DataFrame, labels: pd.DataFrame) -> Dict:
return scores


def scatter2d(data: pd.DataFrame, labels: pd.DataFrame, cluster_centers_: np.ndarray, algorithm_name: str) -> None:
def scatter2d(data: pd.DataFrame, labels: pd.Series, cluster_centers_: pd.DataFrame, algorithm_name: str) -> None:
"""
Draw the result-2D diagram for analysis.
Expand All @@ -47,10 +47,10 @@ def scatter2d(data: pd.DataFrame, labels: pd.DataFrame, cluster_centers_: np.nda
data : pd.DataFrame (n_samples, n_components)
The features of the data.
labels : pd.DataFrame (n_samples,)
labels : pd.Series (n_samples,)
Labels of each point.
cluster_centers_: np.ndarray (n_samples,)
cluster_centers_: pd.DataFrame (n_samples,)
Coordinates of cluster centers. If the algorithm stops before fully converging (see tol and max_iter), these will not be consistent with labels_.
algorithm_name : str
Expand Down Expand Up @@ -94,20 +94,20 @@ def scatter2d(data: pd.DataFrame, labels: pd.DataFrame, cluster_centers_: np.nda
plt.scatter(cluster_data.iloc[:, 0], cluster_data.iloc[:, 1], c=color, marker=marker)

# Plot the cluster centers
if not isinstance(cluster_centers_, str):
if cluster_centers_ is not None:
# Draw white circles at cluster centers
plt.scatter(cluster_centers_[:, 0], cluster_centers_[:, 1], c="white", marker="o", alpha=1, s=200, edgecolor="k")
plt.scatter(cluster_centers_.iloc[:, 0], cluster_centers_.iloc[:, 1], c="white", marker="o", alpha=1, s=200, edgecolor="k")

# Label the cluster centers
for i, c in enumerate(cluster_centers_):
for i, c in enumerate(cluster_centers_.to_numpy()):
plt.scatter(c[0], c[1], marker="$%d$" % i, alpha=1, s=50, edgecolor="k")

plt.xlabel(f"{data.columns[0]}")
plt.ylabel(f"{data.columns[1]}")
plt.title(f"Cluster Data Bi-plot - {algorithm_name}")


def scatter3d(data: pd.DataFrame, labels: pd.DataFrame, algorithm_name: str) -> None:
def scatter3d(data: pd.DataFrame, labels: pd.Series, algorithm_name: str) -> None:
"""
Draw the result-3D diagram for analysis.
Expand All @@ -116,7 +116,7 @@ def scatter3d(data: pd.DataFrame, labels: pd.DataFrame, algorithm_name: str) ->
data : pd.DataFrame (n_samples, n_components)
The features of the data.
labels : pd.DataFrame (n_samples,)
labels : pd.Series (n_samples,)
Labels of each point.
algorithm_name : str
Expand Down Expand Up @@ -177,7 +177,7 @@ def scatter3d(data: pd.DataFrame, labels: pd.DataFrame, algorithm_name: str) ->
ax2.set_title(f"Cluster Data Tri-plot - {algorithm_name}")


def plot_silhouette_diagram(data: pd.DataFrame, labels: pd.DataFrame, cluster_centers_: np.ndarray, model: object, algorithm_name: str) -> None:
def plot_silhouette_diagram(data: pd.DataFrame, labels: pd.Series, cluster_centers_: pd.DataFrame, model: object, algorithm_name: str) -> None:
"""
Draw the silhouette diagram for analysis.
Expand All @@ -186,10 +186,10 @@ def plot_silhouette_diagram(data: pd.DataFrame, labels: pd.DataFrame, cluster_ce
data : pd.DataFrame (n_samples, n_components)
The true values.
labels : pd.DataFrame (n_samples,)
labels : pd.Series (n_samples,)
Labels of each point.
cluster_centers_: np.ndarray (n_samples,)
cluster_centers_: pd.DataFrame (n_samples,)
Coordinates of cluster centers. If the algorithm stops before fully converging (see tol and max_iter), these will not be consistent with labels_.
model : sklearn algorithm model
Expand Down Expand Up @@ -286,11 +286,11 @@ def plot_silhouette_diagram(data: pd.DataFrame, labels: pd.DataFrame, cluster_ce
colors = cm.nipy_spectral(labels.astype(float) / n_clusters)
ax2.scatter(data.iloc[:, 0], data.iloc[:, 1], marker=".", s=30, lw=0, alpha=0.7, c=colors, edgecolor="k")

if not isinstance(cluster_centers_, str):
if cluster_centers_ is not None:
# Draw white circles at cluster centers
ax2.scatter(
cluster_centers_[:, 0],
cluster_centers_[:, 1],
cluster_centers_.iloc[:, 0],
cluster_centers_.iloc[:, 1],
marker="o",
c="white",
alpha=1,
Expand All @@ -299,7 +299,7 @@ def plot_silhouette_diagram(data: pd.DataFrame, labels: pd.DataFrame, cluster_ce
)

# Label the cluster centers
for i, c in enumerate(cluster_centers_):
for i, c in enumerate(cluster_centers_.to_numpy()):
ax2.scatter(c[0], c[1], marker="$%d$" % i, alpha=1, s=50, edgecolor="k")

ax2.set_title("The visualization of the clustered data.")
Expand All @@ -312,15 +312,15 @@ def plot_silhouette_diagram(data: pd.DataFrame, labels: pd.DataFrame, cluster_ce
)


def plot_silhouette_value_diagram(data, labels, algorithm_name: str):
def plot_silhouette_value_diagram(data: pd.DataFrame, labels: pd.Series, algorithm_name: str) -> None:
"""Calculate the scores of the clustering model.
Parameters
----------
data : pd.DataFrame (n_samples, n_components)
The true values.
labels : pd.DataFrame (n_samples, n_components)
labels : pd.Series (n_samples, )
Labels of each point.
algorithm_name : str
Expand Down
Loading

0 comments on commit 14c5d15

Please sign in to comment.