Merge pull request #50 from salesforce/new_features
New features
Showing 16 changed files with 534 additions and 22 deletions.
@@ -0,0 +1,113 @@
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
"""
The permutation feature importance explainer for tabular data.
"""
import numpy as np
import pandas as pd
from typing import Callable, Union
from sklearn.metrics import log_loss
from sklearn.inspection import permutation_importance

from ..base import ExplainerBase, TabularExplainerMixin
from ....data.tabular import Tabular
from ....explanations.tabular.feature_importance import GlobalFeatureImportance


class _Estimator:
    # A dummy estimator: `sklearn.inspection.permutation_importance` expects an
    # estimator object, but all scoring here is delegated to the custom score
    # function built below, so this stub never needs real `fit` logic.
    def fit(self):
        pass


class PermutationImportance(ExplainerBase, TabularExplainerMixin):
    """
    The permutation feature importance explanations for tabular data. The permutation feature
    importance is defined to be the decrease in a model score when a single feature value
    is randomly shuffled.
    """

    explanation_type = "global"
    alias = ["permutation"]

    def __init__(self, training_data: Tabular, predict_function, mode="classification", **kwargs):
        """
        :param training_data: The training dataset used to train the machine learning model.
        :param predict_function: The prediction function corresponding to the model to explain.
            When the model is for classification, the outputs of the ``predict_function``
            are the class probabilities. When the model is for regression, the outputs of
            the ``predict_function`` are the estimated values.
        :param mode: The task type, e.g., `classification` or `regression`.
        """
        super().__init__()
        assert isinstance(training_data, Tabular), \
            "training_data should be an instance of Tabular."
        assert mode in ["classification", "regression"], \
            "`mode` can only be `classification` or `regression`."

        self.categorical_columns = training_data.categorical_columns
        self.predict_function = predict_function
        self.mode = mode

    def _build_score_function(self, score_func=None):
        # Wrap the user-specified or default metric into the
        # `scoring(estimator, X, y)` signature expected by scikit-learn.
        if score_func is not None:
            def _score(estimator, x, y):
                z = self.predict_function(
                    Tabular(x, categorical_columns=self.categorical_columns)
                )
                return score_func(y, z)
        elif self.mode == "classification":
            # Default classification score: negative log loss (higher is better).
            def _score(estimator, x, y):
                z = self.predict_function(
                    Tabular(x, categorical_columns=self.categorical_columns)
                )
                return -log_loss(y, z)
        else:
            # Default regression score: negative mean squared error.
            def _score(estimator, x, y):
                z = self.predict_function(
                    Tabular(x, categorical_columns=self.categorical_columns)
                )
                return -np.mean((z - y) ** 2)
        return _score

    def explain(
            self,
            X: Tabular,
            y: Union[np.ndarray, pd.DataFrame],
            n_repeats: int = 30,
            score_func: Callable = None
    ) -> GlobalFeatureImportance:
        """
        Generates permutation feature importance scores.

        :param X: The data on which permutation importance is computed.
        :param y: The targets or labels.
        :param n_repeats: The number of times a feature is randomly shuffled.
        :param score_func: The score function measuring the difference between
            ground-truth targets and predictions, e.g., -sklearn.metrics.log_loss(y_true, y_pred).
        :return: The permutation feature importance explanations.
        """
        assert X is not None and y is not None, \
            "The test data `X` and target `y` cannot be None."
        y = y.values if isinstance(y, pd.DataFrame) else np.array(y)
        if y.ndim > 1:
            y = y.flatten()
        assert X.shape[0] == len(y), \
            "The numbers of samples in `X` and `y` are different."
        X = X.remove_target_column()

        results = permutation_importance(
            estimator=_Estimator(),
            X=X.to_pd(copy=False),
            y=y,
            scoring=self._build_score_function(score_func),
            n_repeats=n_repeats
        )
        explanations = GlobalFeatureImportance()
        explanations.add(
            feature_names=list(X.columns),
            importance_scores=results["importances_mean"]
        )
        return explanations
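For context, a minimal usage sketch of the new explainer follows. The model setup, the `Tabular` constructor arguments, and the sample data are illustrative assumptions; only `PermutationImportance`, its `explain` signature, and the `predict_function` contract (it receives a `Tabular` and returns class probabilities) come from the diff above.

# Hypothetical usage sketch -- the Tabular signature and sample data are
# assumptions for illustration, not part of this diff.
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

df = pd.DataFrame({
    "age": [23, 45, 31, 52, 28, 39],
    "income": [40, 85, 60, 95, 48, 72],
    "label": [0, 1, 0, 1, 0, 1],
})
model = RandomForestClassifier().fit(df.drop(columns=["label"]), df["label"])

train = Tabular(df, target_column="label")  # assumed Tabular signature
explainer = PermutationImportance(
    training_data=train,
    # The diff calls predict_function with a Tabular, so unwrap it via to_pd().
    predict_function=lambda x: model.predict_proba(x.to_pd()),
    mode="classification",
)
# explain() strips the target column from X itself via remove_target_column().
explanations = explainer.explain(X=train, y=df["label"].values, n_repeats=10)

A custom metric can be plugged in through `score_func`, e.g., `score_func=lambda y_true, y_pred: -log_loss(y_true, y_pred)`; note that higher scores must mean better, which is why the built-in defaults negate log loss and mean squared error.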
@@ -0,0 +1,122 @@
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
"""
The SHAP explainer for global feature importance.
"""
import shap
import numpy as np
from typing import Callable, List

from ..base import TabularExplainer
from ....data.tabular import Tabular
from ....explanations.tabular.feature_importance import GlobalFeatureImportance


class GlobalShapTabular(TabularExplainer):
    """
    The SHAP explainer for global feature importance.
    If using this explainer, please cite the original work: https://github.com/slundberg/shap.
    """

    explanation_type = "global"
    alias = ["shap_global"]

    def __init__(
            self,
            training_data: Tabular,
            predict_function: Callable,
            mode: str = "classification",
            ignored_features: List = None,
            **kwargs
    ):
        """
        :param training_data: The data used to initialize a SHAP explainer. ``training_data``
            can be the training dataset for training the machine learning model. If the training
            dataset is large, please set the parameter ``nsamples``, e.g., ``nsamples = 100``.
        :param predict_function: The prediction function corresponding to the model to explain.
            When the model is for classification, the outputs of the ``predict_function``
            are the class probabilities. When the model is for regression, the outputs of
            the ``predict_function`` are the estimated values.
        :param mode: The task type, e.g., `classification` or `regression`.
        :param ignored_features: The features ignored in computing feature importance scores.
        :param kwargs: Additional parameters to initialize `shap.KernelExplainer`, e.g., ``nsamples``.
            Please refer to the doc of `shap.KernelExplainer`.
        """
        super().__init__(training_data=training_data, predict_function=predict_function, mode=mode, **kwargs)
        self.ignored_features = set(ignored_features) if ignored_features is not None else set()
        if self.target_column is not None:
            assert self.target_column not in self.ignored_features, \
                f"The target column {self.target_column} cannot be in the ignored feature list."
        self.valid_indices = [i for i, f in enumerate(self.feature_columns) if f not in self.ignored_features]

        if "nsamples" not in kwargs:
            kwargs["nsamples"] = 100
        # Subsample the training data once for the SHAP background distribution
        # and once for the default set of instances to explain.
        self.background_data = shap.sample(self.data, nsamples=kwargs["nsamples"])
        self.sampled_data = shap.sample(self.data, nsamples=kwargs["nsamples"])

    def _explain_global(self, X, **kwargs) -> GlobalFeatureImportance:
        if "nsamples" not in kwargs:
            kwargs["nsamples"] = 100
        instances = self.sampled_data if X is None else \
            self.transformer.transform(X.remove_target_column())

        explanations = GlobalFeatureImportance()
        explainer = shap.KernelExplainer(
            self.predict_fn, self.background_data,
            link="logit" if self.mode == "classification" else "identity", **kwargs
        )
        shap_values = explainer.shap_values(instances, **kwargs)

        if self.mode == "classification":
            # For classification, `shap_values` is a list with one array per
            # class; average the absolute values across classes.
            values = 0
            for v in shap_values:
                values += np.abs(v)
            values /= len(shap_values)
            shap_values = values

        # The global importance of a feature is the mean absolute SHAP value
        # over the explained instances.
        importance_scores = np.mean(np.abs(shap_values), axis=0)
        explanations.add(
            feature_names=self.feature_columns,
            importance_scores=importance_scores,
            sort=True
        )
        return explanations

    def explain(
            self,
            X: Tabular = None,
            **kwargs
    ):
        """
        Generates the global SHAP explanations.

        :param X: The data used to compute global SHAP values, i.e., the mean of the absolute
            SHAP value of each feature. If `X` is None, a set of training samples will be used.
        :param kwargs: Additional parameters for `shap.KernelExplainer.shap_values`,
            e.g., ``nsamples`` -- the number of times to re-evaluate the model when explaining each prediction.
        :return: The global feature importance explanations.
        """
        return self._explain_global(X=X, **kwargs)

    def save(
            self,
            directory: str,
            filename: str = None,
            **kwargs
    ):
        """
        Saves the initialized explainer.

        :param directory: The folder for the dumped explainer.
        :param filename: The filename (the explainer class name if it is None).
        """
        super().save(
            directory=directory,
            filename=filename,
            ignored_attributes=["data"],
            **kwargs
        )
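A similarly hedged usage sketch for the SHAP-based global explainer, reusing the hypothetical `train` and `model` objects from the previous sketch; the constructor arguments mirror the signature above, and everything else is an assumption for illustration.

# Hypothetical usage sketch, reusing `train` and `model` from the earlier example.
explainer = GlobalShapTabular(
    training_data=train,
    predict_function=lambda x: model.predict_proba(x.to_pd()),
    mode="classification",
    ignored_features=None,
    nsamples=100,  # background/sample size passed through to shap
)
# With X=None, SHAP values are computed on a subsample of the training data,
# and the importance of each feature is its mean absolute SHAP value.
explanations = explainer.explain()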