Merge pull request #357 from ZJUEarthData/web
feat: add precision-recall curve.
SanyHe committed Jul 4, 2024
2 parents 9172907 + fcbedd5 commit 0623cea
Showing 4 changed files with 119 additions and 23 deletions.
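The core building block behind this change is scikit-learn's `precision_recall_curve`, which the new plotting helpers in this diff wrap. A minimal sketch of that primitive on toy data (the dataset, model, and variable names here are illustrative only):

```python
# Minimal sketch of the scikit-learn primitive the new helpers build on.
# Toy data and model for illustration only.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=500, n_informative=5, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Probability of the positive class, exactly as plot_precision_recall does below.
y_probs = clf.predict_proba(X_test)[:, 1]

# Precision/recall at every candidate decision threshold; the returned arrays
# have one more precision/recall entry than thresholds.
precisions, recalls, thresholds = precision_recall_curve(y_test, y_probs)
```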
73 changes: 53 additions & 20 deletions geochemistrypi/data_mining/model/classification.py
@@ -23,8 +23,19 @@
from ..plot.statistic_plot import basic_statistic
from ..utils.base import clear_output, save_data, save_fig, save_text
from ._base import LinearWorkflowMixin, TreeWorkflowMixin, WorkflowBase
from .func.algo_classification._common import cross_validation, plot_2d_decision_boundary, plot_confusion_matrix, plot_precision_recall, plot_ROC, resampler, reset_label, score
from .func.algo_classification._common import (
cross_validation,
plot_2d_decision_boundary,
plot_confusion_matrix,
plot_precision_recall,
plot_precision_recall_threshold,
plot_ROC,
resampler,
reset_label,
score,
)
from .func.algo_classification._decision_tree import decision_tree_manual_hyper_parameters
from .func.algo_classification._enum import ClassificationCommonFunction
from .func.algo_classification._extra_trees import extra_trees_manual_hyper_parameters
from .func.algo_classification._gradient_boosting import gradient_boosting_manual_hyper_parameters
from .func.algo_classification._knn import knn_manual_hyper_parameters
@@ -39,17 +50,7 @@
class ClassificationWorkflowBase(WorkflowBase):
"""The base workflow class of classification algorithms."""

common_function = [
"Model Score",
"Confusion Matrix",
"Cross Validation",
"Model Prediction",
"Model Persistence",
"Precision Recall Curve",
"ROC Curve",
"Two-dimensional Decision Boundary Diagram",
"Permutation Importance Diagram",
]
common_function = [func.value for func in ClassificationCommonFunction]

def __init__(self) -> None:
super().__init__()
@@ -163,18 +164,30 @@ def _plot_confusion_matrix(y_test: pd.DataFrame, y_test_predict: pd.DataFrame, t
save_data(data, f"Confusion Matrix - {algorithm_name}", local_path, mlflow_path, True)

@staticmethod
def _plot_precision_recall(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
print("-----* Precision Recall Curve *-----")
y_probs, precisions, recalls, thresholds = plot_precision_recall(X_test, y_test, trained_model, algorithm_name)
save_fig(f"Precision Recall Curve - {algorithm_name}", local_path, mlflow_path)
def _plot_precision_recall(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, graph_name: str, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
print(f"-----* {graph_name} *-----")
y_probs, precisions, recalls, thresholds = plot_precision_recall(X_test, y_test, trained_model, graph_name, algorithm_name)
save_fig(f"{graph_name} - {algorithm_name}", local_path, mlflow_path)
y_probs = pd.DataFrame(y_probs, columns=["Probabilities"])
precisions = pd.DataFrame(precisions, columns=["Precisions"])
recalls = pd.DataFrame(recalls, columns=["Recalls"])
thresholds = pd.DataFrame(thresholds, columns=["Thresholds"])
save_data(y_probs, "Precision Recall Curve - Probabilities", local_path, mlflow_path)
save_data(precisions, "Precision Recall Curve - Precisions", local_path, mlflow_path)
save_data(recalls, "Precision Recall Curve - Recalls", local_path, mlflow_path)
save_data(thresholds, "Precision Recall Curve - Thresholds", local_path, mlflow_path)
save_data(precisions, f"{graph_name} - Precisions", local_path, mlflow_path)
save_data(recalls, f"{graph_name} - Recalls", local_path, mlflow_path)

@staticmethod
def _plot_precision_recall_threshold(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, graph_name: str, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
print(f"-----* {graph_name} *-----")
y_probs, precisions, recalls, thresholds = plot_precision_recall_threshold(X_test, y_test, trained_model, graph_name, algorithm_name)
save_fig(f"{graph_name} - {algorithm_name}", local_path, mlflow_path)
y_probs = pd.DataFrame(y_probs, columns=["Probabilities"])
precisions = pd.DataFrame(precisions, columns=["Precisions"])
recalls = pd.DataFrame(recalls, columns=["Recalls"])
thresholds = pd.DataFrame(thresholds, columns=["Thresholds"])
save_data(y_probs, f"{graph_name} - Probabilities", local_path, mlflow_path)
save_data(precisions, f"{graph_name} - Precisions", local_path, mlflow_path)
save_data(recalls, f"{graph_name} - Recalls", local_path, mlflow_path)
save_data(thresholds, f"{graph_name} - Thresholds", local_path, mlflow_path)

@staticmethod
def _plot_ROC(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
@@ -285,6 +298,16 @@ def common_components(self) -> None:
X_test=ClassificationWorkflowBase.X_test,
y_test=ClassificationWorkflowBase.y_test,
trained_model=self.model,
graph_name=ClassificationCommonFunction.PRECISION_RECALL_CURVE.value,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
self._plot_precision_recall_threshold(
X_test=ClassificationWorkflowBase.X_test,
y_test=ClassificationWorkflowBase.y_test,
trained_model=self.model,
graph_name=ClassificationCommonFunction.PRECISION_RECALL_THRESHOLD_DIAGRAM.value,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
@@ -356,6 +379,16 @@ def common_components(self, is_automl: bool) -> None:
X_test=ClassificationWorkflowBase.X_test,
y_test=ClassificationWorkflowBase.y_test,
trained_model=self.auto_model,
graph_name=ClassificationCommonFunction.PRECISION_RECALL_CURVE.value,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)
self._plot_precision_recall_threshold(
X_test=ClassificationWorkflowBase.X_test,
y_test=ClassificationWorkflowBase.y_test,
trained_model=self.auto_model,
graph_name=ClassificationCommonFunction.PRECISION_RECALL_THRESHOLD_DIAGRAM.value,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
54 changes: 51 additions & 3 deletions geochemistrypi/data_mining/model/func/algo_classification/_common.py
@@ -196,8 +196,8 @@ def cross_validation(trained_model: object, X_train: pd.DataFrame, y_train: pd.D
return scores_result


def plot_precision_recall(X_test, y_test, trained_model: object, algorithm_name: str) -> tuple:
"""Plot the precision-recall curve.
def plot_precision_recall(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, graph_name: str, algorithm_name: str) -> tuple:
"""Plot the precision vs. recall diagram.
Parameters
----------
@@ -210,6 +210,54 @@ def plot_precision_recall(X_test, y_test, trained_model: object, algorithm_name:
trained_model : object
The model trained.
graph_name : str
The name of the graph.
algorithm_name : str
The name of the algorithm.
Returns
-------
y_probs : np.ndarray
The probabilities of the model.
precisions : np.ndarray
The precision of the model.
recalls : np.ndarray
The recall of the model.
thresholds : np.ndarray
The thresholds of the model.
"""
# Predict probabilities for the positive class
y_probs = trained_model.predict_proba(X_test)[:, 1]
precisions, recalls, thresholds = precision_recall_curve(y_test, y_probs)
plt.figure()
plt.plot(recalls, precisions, "b-")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title(f"{graph_name} - {algorithm_name}")
return y_probs, precisions, recalls, thresholds


def plot_precision_recall_threshold(X_test: pd.DataFrame, y_test: pd.DataFrame, trained_model: object, graph_name: str, algorithm_name: str) -> tuple:
"""Plot the precision-recall vs. threshold diagram.
Parameters
----------
X_test : pd.DataFrame (n_samples, n_components)
The testing feature data.
y_test : pd.DataFrame (n_samples, n_components)
The testing target values.
trained_model : object
The model trained.
graph_name : str
The name of the graph.
algorithm_name : str
The name of the algorithm.
@@ -234,7 +282,7 @@ def plot_precision_recall(X_test, y_test, trained_model: object, algorithm_name:
plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
plt.plot(thresholds, recalls[:-1], "g-", label="Recall")
plt.legend(labels=["Precision", "Recall"], loc="best")
plt.title(f"Precision Recall Curve - {algorithm_name}")
plt.title(f"{graph_name} - {algorithm_name}")
return y_probs, precisions, recalls, thresholds


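A point worth calling out in `plot_precision_recall_threshold`: scikit-learn's `precision_recall_curve` returns one more precision/recall value than thresholds (the final point corresponds to precision 1 and recall 0 and has no threshold), which is why the curve is drawn with `precisions[:-1]` and `recalls[:-1]` against `thresholds`. Below is a hedged usage sketch of the two functions added above, assuming they are imported from `_common` and that the toy classifier and file names are placeholders; the project itself routes figures and arrays through `save_fig`/`save_data` as shown in classification.py.

```python
# Illustrative call pattern only; toy data, model, and output file names are
# placeholders, not the project's actual workflow wiring.
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from geochemistrypi.data_mining.model.func.algo_classification._common import (
    plot_precision_recall,
    plot_precision_recall_threshold,
)

X, y = make_classification(n_samples=500, n_informative=5, random_state=42)
X = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])])
X_train, X_test, y_train, y_test = train_test_split(X, pd.Series(y), random_state=42)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

y_probs, precisions, recalls, thresholds = plot_precision_recall(
    X_test, y_test, clf, "Precision-Recall Curve", "Logistic Regression"
)
plt.savefig("precision_recall_curve.png")

plot_precision_recall_threshold(
    X_test, y_test, clf, "Precision-Recall vs. Threshold Diagram", "Logistic Regression"
)
plt.savefig("precision_recall_vs_threshold.png")

# Shape relationship behind the [:-1] slicing in plot_precision_recall_threshold:
assert len(precisions) == len(recalls) == len(thresholds) + 1
```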
14 changes: 14 additions & 0 deletions geochemistrypi/data_mining/model/func/algo_classification/_enum.py
@@ -0,0 +1,14 @@
from enum import Enum


class ClassificationCommonFunction(Enum):
MODEL_SCORE = "Model Score"
CONFUSION_MATRIX = "Confusion Matrix"
CROSS_VALIDATION = "Cross Validation"
MODEL_PREDICTION = "Model Prediction"
MODEL_PERSISTENCE = "Model Persistence"
PRECISION_RECALL_CURVE = "Precision-Recall Curve"
PRECISION_RECALL_THRESHOLD_DIAGRAM = "Precision-Recall vs. Threshold Diagram"
ROC_CURVE = "ROC Curve"
TWO_DIMENSIONAL_DECISION_BOUNDARY_DIAGRAM = "Two-dimensional Decision Boundary Diagram"
PERMUTATION_IMPORTANCE_DIAGRAM = "Permutation Importance Diagram"
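The new enum centralizes the display names that were previously hard-coded in `ClassificationWorkflowBase.common_function` (see the one-line comprehension in the classification.py diff above). A trimmed-down sketch of that pattern in isolation:

```python
# Trimmed-down illustration of the enum-driven list comprehension used in
# classification.py; only a few members are shown here.
from enum import Enum


class ClassificationCommonFunction(Enum):
    MODEL_SCORE = "Model Score"
    PRECISION_RECALL_CURVE = "Precision-Recall Curve"
    PRECISION_RECALL_THRESHOLD_DIAGRAM = "Precision-Recall vs. Threshold Diagram"


common_function = [func.value for func in ClassificationCommonFunction]
print(common_function)
# ['Model Score', 'Precision-Recall Curve', 'Precision-Recall vs. Threshold Diagram']
```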
1 change: 1 addition & 0 deletions geochemistrypi/data_mining/process/cluster.py
@@ -69,6 +69,7 @@ def activate(

# Use Scikit-learn style API to process input data
self.clt_workflow.fit(X)
# TODO: Move this into common_components()
self.clt_workflow.get_cluster_centers()
self.clt_workflow.get_labels()

