Skip to content

Commit

Permalink
feat: Add the MeanShift algorithm to the clustering model and make the…
Browse files Browse the repository at this point in the history
… special function shared across different algorithms.
  • Loading branch information
Crt1124 committed Aug 11, 2024
1 parent d3bebba commit 9f5f2b5
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 15 deletions.
16 changes: 16 additions & 0 deletions geochemistrypi/data_mining/model/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,3 +371,19 @@ def _plot_3d_surface_diagram(feature_data: pd.DataFrame, target_data: pd.DataFra
save_fig(f"3D Surface Diagram - {algorithm_name}", local_path, mlflow_path)
data = pd.concat([feature_data, target_data, y_test_predict], axis=1)
save_data(data, f"3D Surface Diagram - {algorithm_name}", local_path, mlflow_path)


class ClusteringMetricsMixin:
    """Mixin providing metric helpers shared by clustering workflow classes."""

    @staticmethod
    def _get_num_clusters(func_name: str, algorithm_name: str, trained_model: object, store_path: str) -> None:
        """Count, print, log (MLflow) and persist the number of clusters found.

        Parameters
        ----------
        func_name : str
            Display name of the special function; also used as the metric key.

        algorithm_name : str
            Name of the clustering algorithm; used in the saved file name.

        trained_model : object
            A fitted scikit-learn clustering estimator exposing ``labels_``.

        store_path : str
            Directory where the metric text file is saved.
        """
        labels = trained_model.labels_
        unique_labels = np.unique(labels)
        # DBSCAN marks outliers with the label -1; exclude it so the reported
        # value is the number of actual clusters, not clusters plus noise.
        num_clusters = len(unique_labels) - (1 if -1 in unique_labels else 0)
        print(f"-----* {func_name} *-----")
        print(f"{func_name}: {num_clusters}")
        num_clusters_dict = {func_name: num_clusters}
        mlflow.log_metrics(num_clusters_dict)
        num_clusters_str = json.dumps(num_clusters_dict, indent=4)
        save_text(num_clusters_str, f"{func_name} - {algorithm_name}", store_path)
29 changes: 18 additions & 11 deletions geochemistrypi/data_mining/model/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ..constants import MLFLOW_ARTIFACT_DATA_PATH, MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH
from ..utils.base import clear_output, save_data, save_fig, save_text
from ._base import WorkflowBase
from ._base import ClusteringMetricsMixin, WorkflowBase
from .func.algo_clustering._affinitypropagation import affinitypropagation_manual_hyper_parameters
from .func.algo_clustering._agglomerative import agglomerative_manual_hyper_parameters
from .func.algo_clustering._common import plot_silhouette_diagram, plot_silhouette_value_diagram, scatter2d, scatter3d, score
Expand Down Expand Up @@ -274,7 +274,7 @@ def __init__(
might change in the future for a better heuristic.
References
----------------------------------------
----------
Scikit-learn API: sklearn.cluster.KMeans
https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html
"""
Expand Down Expand Up @@ -336,11 +336,11 @@ def special_components(self, **kwargs: Union[Dict, np.ndarray, int]) -> None:
)


class DBSCANClustering(ClusteringWorkflowBase):
class DBSCANClustering(ClusteringMetricsMixin, ClusteringWorkflowBase):
"""The automation workflow of using DBSCAN algorithm to make insightful products."""

name = "DBSCAN"
special_function = ["Virtualization of Result in 2D Graph"]
special_function = ["Num of Clusters"]

def __init__(
self,
Expand Down Expand Up @@ -389,7 +389,7 @@ def __init__(
The number of parallel jobs to run. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors. See Glossary for more details.
References
----------------------------------------
----------
Scikit-learn API: sklearn.cluster.DBSCAN
https://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html
"""
Expand Down Expand Up @@ -426,7 +426,14 @@ def manual_hyper_parameters(cls) -> Dict:
return hyper_parameters

def special_components(self, **kwargs: Union[Dict, np.ndarray, int]) -> None:
    """Invoke all special application functions for this algorithm by Scikit-learn framework."""
    # Metrics are written under the path configured in the environment;
    # NOTE(review): os.getenv returns None when the variable is unset — confirm
    # the workflow always sets GEOPI_OUTPUT_METRICS_PATH before this runs.
    GEOPI_OUTPUT_METRICS_PATH = os.getenv("GEOPI_OUTPUT_METRICS_PATH")
    # Delegates to ClusteringMetricsMixin._get_num_clusters, which counts the
    # distinct labels in self.model.labels_ and logs/saves the result.
    # NOTE(review): this DBSCAN workflow reuses MeanShiftSpecialFunction's
    # "Num of Clusters" enum value — presumably intentional since the function
    # is shared, but a DBSCAN-specific enum would be cleaner; confirm.
    self._get_num_clusters(
        func_name=MeanShiftSpecialFunction.NUM_CLUSTERS.value,
        algorithm_name=self.naming,
        trained_model=self.model,
        store_path=GEOPI_OUTPUT_METRICS_PATH,
    )


class Agglomerative(ClusteringWorkflowBase):
Expand Down Expand Up @@ -617,7 +624,7 @@ def __init__(
this parameter was previously hardcoded as 0.
References
----------------------------------------
----------
Scikit-learn API: sklearn.cluster.AffinityPropagation
https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AffinityPropagation
"""
Expand Down Expand Up @@ -659,10 +666,10 @@ def special_components(self, **kwargs: Union[Dict, np.ndarray, int]) -> None:
"""Invoke all special application functions for this algorithms by Scikit-learn framework."""


class MeanShiftClustering(ClusteringWorkflowBase):
class MeanShiftClustering(ClusteringMetricsMixin, ClusteringWorkflowBase):
name = "MeanShift"

special_function = [func.value for func in MeanShiftSpecialFunction]
special_function = ["Num of Clusters"]

def __init__(
self,
Expand Down Expand Up @@ -730,7 +737,7 @@ def __init__(
.. versionadded:: 0.22
References
----------------------------------------
----------
Scikit-learn API: sklearn.cluster.MeanShift
https://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift
"""
Expand Down Expand Up @@ -760,7 +767,7 @@ def special_components(self, **kwargs: Union[Dict, np.ndarray, int]) -> None:
"""Invoke all special application functions for this algorithm by Scikit-learn framework."""
GEOPI_OUTPUT_METRICS_PATH = os.getenv("GEOPI_OUTPUT_METRICS_PATH")
self._get_num_clusters(
func_name=MeanShiftSpecialFunction.NUM_CLUSTERS,
func_name=MeanShiftSpecialFunction.NUM_CLUSTERS.value,
algorithm_name=self.naming,
trained_model=self.model,
store_path=GEOPI_OUTPUT_METRICS_PATH,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ class KMeansSpecialFunction(Enum):


class MeanShiftSpecialFunction(Enum):
    """Display names of the special functions offered by the Mean Shift workflow."""

    # Metric key used when reporting the number of clusters found; also reused
    # by other clustering workflows that share the cluster-count function.
    NUM_CLUSTERS = "Num of Clusters"
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,30 @@ def meanshift_manual_hyper_parameters() -> Dict:
"""
print("Bandwidth: The bandwidth of the kernel used in the algorithm. This parameter can greatly influence the results.")
print("If you do not have a specific value in mind, you can leave this as 0, and the algorithm will estimate it automatically.")
bandwidth_input = num_input(SECTION[2], "Enter Bandwidth (or None for automatic estimation): ")
print("A good starting point could be around 0.5 to 1.5, depending on your data's scale.")
bandwidth_input = num_input(SECTION[2], "Enter Bandwidth (or 0 for automatic estimation): ")
bandwidth = None if bandwidth_input == 0 else bandwidth_input

print("Cluster All: By default, only points at least as close to a cluster center as the given bandwidth are assigned to that cluster.")
print("Setting this to False will prevent points from being assigned to any cluster if they are too far away. Leave it True if you want all data points to be part of some cluster.")
print("For most use cases, 'True' is recommended to ensure all points are clustered.")
cluster_all = str_input(["True", "False"], SECTION[2])

print("Bin Seeding: If true, initial kernel locations are binned points, speeding up the algorithm with fewer seeds. Default is False.")

print("Setting this to True can be useful for large datasets to speed up computation. Consider using True if your dataset is large.")
bin_seeding = str_input(["True", "False"], SECTION[2])

print("Min Bin Frequency: To speed up the algorithm, accept only those bins with at least min_bin_freq points as seeds.")
print("A typical value is 1, but you might increase this for very large datasets to reduce the number of seeds.")
min_bin_freq = num_input(SECTION[2], "Enter Min Bin Frequency (default is 1): ")

print("Number of Jobs: The number of jobs to use for the computation. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors.")
print("Number of Jobs: The number of jobs to use for the computation. 1 means using all processors.")
print("If you are unsure, use 1 to utilize all available processors.")
n_jobs = num_input(SECTION[2], "Enter Number of Jobs (or None): ")
n_jobs = -1 if n_jobs == 1 else int(n_jobs)

print("Max Iterations: Maximum number of iterations, per seed point before the clustering operation terminates (for that seed point), if has not converged yet.")
print("The default value is 300, which is sufficient for most use cases. You might increase this for very complex data.")
max_iter = num_input(SECTION[2], "Enter Max Iterations (default is 300): ")

hyper_parameters = {
Expand Down

0 comments on commit 9f5f2b5

Please sign in to comment.