From fcbedd52d896ee682d26818c736b194b13612309 Mon Sep 17 00:00:00 2001 From: sanyhe Date: Thu, 4 Jul 2024 21:53:39 +0800 Subject: [PATCH] feat: add precision-recall curve. --- .../data_mining/model/classification.py | 73 ++++++++++++++----- .../model/func/algo_classification/_common.py | 54 +++++++++++++- .../model/func/algo_classification/_enum.py | 14 ++++ geochemistrypi/data_mining/process/cluster.py | 1 + 4 files changed, 119 insertions(+), 23 deletions(-) create mode 100644 geochemistrypi/data_mining/model/func/algo_classification/_enum.py diff --git a/geochemistrypi/data_mining/model/classification.py b/geochemistrypi/data_mining/model/classification.py index 2b4e1fd0..b81edd66 100644 --- a/geochemistrypi/data_mining/model/classification.py +++ b/geochemistrypi/data_mining/model/classification.py @@ -23,8 +23,19 @@ from ..plot.statistic_plot import basic_statistic from ..utils.base import clear_output, save_data, save_fig, save_text from ._base import LinearWorkflowMixin, TreeWorkflowMixin, WorkflowBase -from .func.algo_classification._common import cross_validation, plot_2d_decision_boundary, plot_confusion_matrix, plot_precision_recall, plot_ROC, resampler, reset_label, score +from .func.algo_classification._common import ( + cross_validation, + plot_2d_decision_boundary, + plot_confusion_matrix, + plot_precision_recall, + plot_precision_recall_threshold, + plot_ROC, + resampler, + reset_label, + score, +) from .func.algo_classification._decision_tree import decision_tree_manual_hyper_parameters +from .func.algo_classification._enum import ClassificationCommonFunction from .func.algo_classification._extra_trees import extra_trees_manual_hyper_parameters from .func.algo_classification._gradient_boosting import gradient_boosting_manual_hyper_parameters from .func.algo_classification._knn import knn_manual_hyper_parameters @@ -39,17 +50,7 @@ class ClassificationWorkflowBase(WorkflowBase): """The base workflow class of classification algorithms.""" - common_function = [ - "Model Score", - "Confusion Matrix", - "Cross Validation", - "Model Prediction", - "Model Persistence", - "Precision Recall Curve", - "ROC Curve", - "Two-dimensional Decision Boundary Diagram", - "Permutation Importance Diagram", - ] + common_function = [func.value for func in ClassificationCommonFunction] def __init__(self) -> None: super().__init__() @@ -163,18 +164,30 @@ def _plot_confusion_matrix(y_test: pd.DataFrame, y_test_predict: pd.DataFrame, t save_data(data, f"Confusion Matrix - {algorithm_name}", local_path, mlflow_path, True) @staticmethod - def _plot_precision_recall(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, algorithm_name: str, local_path: str, mlflow_path: str) -> None: - print("-----* Precision Recall Curve *-----") - y_probs, precisions, recalls, thresholds = plot_precision_recall(X_test, y_test, trained_model, algorithm_name) - save_fig(f"Precision Recall Curve - {algorithm_name}", local_path, mlflow_path) + def _plot_precision_recall(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, graph_name: str, algorithm_name: str, local_path: str, mlflow_path: str) -> None: + print(f"-----* {graph_name} *-----") + y_probs, precisions, recalls, thresholds = plot_precision_recall(X_test, y_test, trained_model, graph_name, algorithm_name) + save_fig(f"{graph_name} - {algorithm_name}", local_path, mlflow_path) y_probs = pd.DataFrame(y_probs, columns=["Probabilities"]) precisions = pd.DataFrame(precisions, columns=["Precisions"]) recalls = pd.DataFrame(recalls, columns=["Recalls"]) thresholds = pd.DataFrame(thresholds, columns=["Thresholds"]) - save_data(y_probs, "Precision Recall Curve - Probabilities", local_path, mlflow_path) - save_data(precisions, "Precision Recall Curve - Precisions", local_path, mlflow_path) - save_data(recalls, "Precision Recall Curve - Recalls", local_path, mlflow_path) - save_data(thresholds, "Precision Recall Curve - Thresholds", local_path, mlflow_path) + save_data(precisions, f"{graph_name} - Precisions", local_path, mlflow_path) + save_data(recalls, f"{graph_name} - Recalls", local_path, mlflow_path) + + @staticmethod + def _plot_precision_recall_threshold(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, graph_name: str, algorithm_name: str, local_path: str, mlflow_path: str) -> None: + print(f"-----* {graph_name} *-----") + y_probs, precisions, recalls, thresholds = plot_precision_recall_threshold(X_test, y_test, trained_model, graph_name, algorithm_name) + save_fig(f"{graph_name} - {algorithm_name}", local_path, mlflow_path) + y_probs = pd.DataFrame(y_probs, columns=["Probabilities"]) + precisions = pd.DataFrame(precisions, columns=["Precisions"]) + recalls = pd.DataFrame(recalls, columns=["Recalls"]) + thresholds = pd.DataFrame(thresholds, columns=["Thresholds"]) + save_data(y_probs, f"{graph_name} - Probabilities", local_path, mlflow_path) + save_data(precisions, f"{graph_name} - Precisions", local_path, mlflow_path) + save_data(recalls, f"{graph_name} - Recalls", local_path, mlflow_path) + save_data(thresholds, f"{graph_name} - Thresholds", local_path, mlflow_path) @staticmethod def _plot_ROC(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, algorithm_name: str, local_path: str, mlflow_path: str) -> None: @@ -285,6 +298,16 @@ def common_components(self) -> None: X_test=ClassificationWorkflowBase.X_test, y_test=ClassificationWorkflowBase.y_test, trained_model=self.model, + graph_name=ClassificationCommonFunction.PRECISION_RECALL_CURVE.value, + algorithm_name=self.naming, + local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH, + mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH, + ) + self._plot_precision_recall_threshold( + X_test=ClassificationWorkflowBase.X_test, + y_test=ClassificationWorkflowBase.y_test, + trained_model=self.model, + graph_name=ClassificationCommonFunction.PRECISION_RECALL_THRESHOLD_DIAGRAM.value, algorithm_name=self.naming, local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH, mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH, @@ -356,6 +379,16 @@ def common_components(self, is_automl: bool) -> None: X_test=ClassificationWorkflowBase.X_test, y_test=ClassificationWorkflowBase.y_test, trained_model=self.auto_model, + graph_name=ClassificationCommonFunction.PRECISION_RECALL_CURVE.value, + algorithm_name=self.naming, + local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH, + mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH, + ) + self._plot_precision_recall_threshold( + X_test=ClassificationWorkflowBase.X_test, + y_test=ClassificationWorkflowBase.y_test, + trained_model=self.auto_model, + graph_name=ClassificationCommonFunction.PRECISION_RECALL_THRESHOLD_DIAGRAM.value, algorithm_name=self.naming, local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH, mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH, diff --git a/geochemistrypi/data_mining/model/func/algo_classification/_common.py b/geochemistrypi/data_mining/model/func/algo_classification/_common.py index 94d05535..8d2058cb 100644 --- a/geochemistrypi/data_mining/model/func/algo_classification/_common.py +++ b/geochemistrypi/data_mining/model/func/algo_classification/_common.py @@ -196,8 +196,8 @@ def cross_validation(trained_model: object, X_train: pd.DataFrame, y_train: pd.D return scores_result -def plot_precision_recall(X_test, y_test, trained_model: object, algorithm_name: str) -> tuple: - """Plot the precision-recall curve. +def plot_precision_recall(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, graph_name: str, algorithm_name: str) -> tuple: + """Plot the precision vs. recall diagram. Parameters ---------- @@ -210,6 +210,54 @@ def plot_precision_recall(X_test, y_test, trained_model: object, algorithm_name: trained_model : object The model trained. + graph_name : str + The name of the graph. + + algorithm_name : str + The name of the algorithm. + + Returns + ------- + y_probs : np.ndarray + The probabilities of the model. + + precisions : np.ndarray + The precision of the model. + + recalls : np.ndarray + The recall of the model. + + thresholds : np.ndarray + The thresholds of the model. + """ + # Predict probabilities for the positive class + y_probs = trained_model.predict_proba(X_test)[:, 1] + precisions, recalls, thresholds = precision_recall_curve(y_test, y_probs) + plt.figure() + plt.plot(recalls, precisions, "b-") + plt.xlabel("Recall") + plt.ylabel("Precision") + plt.title(f"{graph_name} - {algorithm_name}") + return y_probs, precisions, recalls, thresholds + + +def plot_precision_recall_threshold(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, graph_name: str, algorithm_name: str) -> tuple: + """Plot the precision-recall vs. threshold diagram. + + Parameters + ---------- + X_test : pd.DataFrame (n_samples, n_components) + The testing feature data. + + y_test : pd.DataFrame (n_samples, n_components) + The testing target values. + + trained_model : object + The model trained. + + graph_name : str + The name of the graph. + algorithm_name : str The name of the algorithm. @@ -234,7 +282,7 @@ def plot_precision_recall(X_test, y_test, trained_model: object, algorithm_name: plt.plot(thresholds, precisions[:-1], "b--", label="Precision") plt.plot(thresholds, recalls[:-1], "g-", label="Recall") plt.legend(labels=["Precision", "Recall"], loc="best") - plt.title(f"Precision Recall Curve - {algorithm_name}") + plt.title(f"{graph_name} - {algorithm_name}") return y_probs, precisions, recalls, thresholds diff --git a/geochemistrypi/data_mining/model/func/algo_classification/_enum.py b/geochemistrypi/data_mining/model/func/algo_classification/_enum.py new file mode 100644 index 00000000..2552ea0c --- /dev/null +++ b/geochemistrypi/data_mining/model/func/algo_classification/_enum.py @@ -0,0 +1,14 @@ +from enum import Enum + + +class ClassificationCommonFunction(Enum): + MODEL_SCORE = "Model Score" + CONFUSION_MATRIX = "Confusion Matrix" + CROSS_VALIDATION = "Cross Validation" + MODEL_PREDICTION = "Model Prediction" + MODEL_PERSISTENCE = "Model Persistence" + PRECISION_RECALL_CURVE = "Precision-Recall Curve" + PRECISION_RECALL_THRESHOLD_DIAGRAM = "Precision-Recall vs. Threshold Diagram" + ROC_CURVE = "ROC Curve" + TWO_DIMENSIONAL_DECISION_BOUNDARY_DIAGRAM = "Two-dimensional Decision Boundary Diagram" + PERMUTATION_IMPORTANCE_DIAGRAM = "Permutation Importance Diagram" diff --git a/geochemistrypi/data_mining/process/cluster.py b/geochemistrypi/data_mining/process/cluster.py index db87eb41..5ae50b57 100644 --- a/geochemistrypi/data_mining/process/cluster.py +++ b/geochemistrypi/data_mining/process/cluster.py @@ -69,6 +69,7 @@ def activate( # Use Scikit-learn style API to process input data self.clt_workflow.fit(X) + # TODO: Move this into common_components() self.clt_workflow.get_cluster_centers() self.clt_workflow.get_labels()