Merge pull request #50 from salesforce/new_features
New features
yangwenz authored Nov 17, 2022
2 parents 72989c2 + d3ebd0e commit b1ce148
Showing 16 changed files with 534 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -63,6 +63,7 @@ We will continue improving this library to make it more comprehensive in the future
| Partial dependence plots | Black box | Global | | ✅ | | | |
| Accumulated local effects | Black box | Global | | ✅ | | | |
| Sensitivity analysis | Black box | Global | | ✅ | | | |
| Permutation explanation | Black box | Global | | ✅ | | | |
| Feature visualization | Torch or TF | Global | | | ✅ | | |
| Feature maps | Torch or TF | Local | | | ✅ | | |
| LIME | Black box | Local | | ✅ | ✅ | ✅ | |
Expand Down
16 changes: 16 additions & 0 deletions docs/omnixai.explainers.tabular.agnostic.rst
@@ -60,3 +60,19 @@ omnixai.explainers.tabular.agnostic.L2X.l2x module
:members:
:undoc-members:
:show-inheritance:

omnixai.explainers.tabular.agnostic.permutation module
------------------------------------------------------

.. automodule:: omnixai.explainers.tabular.agnostic.permutation
:members:
:undoc-members:
:show-inheritance:

omnixai.explainers.tabular.agnostic.shap_global module
------------------------------------------------------

.. automodule:: omnixai.explainers.tabular.agnostic.shap_global
:members:
:undoc-members:
:show-inheritance:
6 changes: 6 additions & 0 deletions omnixai/explainers/tabular/__init__.py
@@ -11,8 +11,11 @@
from .agnostic.ale import ALE
from .agnostic.sensitivity import SensitivityAnalysisTabular
from .agnostic.L2X.l2x import L2XTabular
from .agnostic.permutation import PermutationImportance
from .agnostic.shap_global import GlobalShapTabular
from .counterfactual.mace.mace import MACEExplainer
from .counterfactual.ce import CounterfactualExplainer
from .counterfactual.knn import KNNCounterfactualExplainer
from .specific.ig import IntegratedGradientTabular
from .specific.linear import LinearRegression
from .specific.linear import LogisticRegression
@@ -29,8 +32,11 @@
"ALE",
"SensitivityAnalysisTabular",
"L2XTabular",
"PermutationImportance",
"GlobalShapTabular",
"MACEExplainer",
"CounterfactualExplainer",
"KNNCounterfactualExplainer",
"LinearRegression",
"LogisticRegression",
"TreeRegressor",
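Since each explainer class declares an `alias` ("permutation" and "shap_global" below), the two new explainers should also be reachable by name through OmniXAI's high-level `TabularExplainer` factory. A hedged sketch following the factory's documented argument pattern; the setup names are placeholders, not part of this diff:

```python
from omnixai.explainers.tabular import TabularExplainer

# Hypothetical setup: `train_data` is an omnixai Tabular with its target
# column set, `model` is a fitted estimator, and `preprocess` maps a
# Tabular batch to the model's inputs.
explainer = TabularExplainer(
    explainers=["permutation", "shap_global"],  # aliases registered above
    mode="classification",
    data=train_data,
    model=model,
    preprocess=preprocess,
)
```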
113 changes: 113 additions & 0 deletions omnixai/explainers/tabular/agnostic/permutation.py
@@ -0,0 +1,113 @@
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
"""
The permutation feature importance explanation for tabular data.
"""
import numpy as np
import pandas as pd
from typing import Callable, Union
from sklearn.metrics import log_loss
from sklearn.inspection import permutation_importance

from ..base import ExplainerBase, TabularExplainerMixin
from ....data.tabular import Tabular
from ....explanations.tabular.feature_importance import GlobalFeatureImportance


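# A do-nothing stand-in: sklearn's permutation_importance requires an
# estimator object, but all predictions here go through `predict_function`.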
class _Estimator:
def fit(self):
pass


class PermutationImportance(ExplainerBase, TabularExplainerMixin):
"""
The permutation feature importance explanations for tabular data. The permutation feature
importance is defined to be the decrease in a model score when a single feature value
is randomly shuffled.
"""

explanation_type = "global"
alias = ["permutation"]

def __init__(self, training_data: Tabular, predict_function, mode="classification", **kwargs):
"""
:param training_data: The training dataset for training the machine learning model.
:param predict_function: The prediction function corresponding to the model to explain.
When the model is for classification, the outputs of the ``predict_function``
are the class probabilities. When the model is for regression, the outputs of
the ``predict_function`` are the estimated values.
:param mode: The task type, e.g., `classification` or `regression`.
"""
super().__init__()
assert isinstance(training_data, Tabular), \
"training_data should be an instance of Tabular."
assert mode in ["classification", "regression"], \
"`mode` can only be `classification` or `regression`."

self.categorical_columns = training_data.categorical_columns
self.predict_function = predict_function
self.mode = mode

def _build_score_function(self, score_func=None):
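        # Build the `scoring` callable expected by sklearn's
        # permutation_importance: it receives (estimator, x, y) and must
        # return a score where higher is better, hence the negated losses.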
if score_func is not None:
def _score(estimator, x, y):
z = self.predict_function(
Tabular(x, categorical_columns=self.categorical_columns)
)
return score_func(y, z)
elif self.mode == "classification":
def _score(estimator, x, y):
z = self.predict_function(
Tabular(x, categorical_columns=self.categorical_columns)
)
return -log_loss(y, z)
        else:
            def _score(estimator, x, y):
                z = self.predict_function(
                    Tabular(x, categorical_columns=self.categorical_columns)
                )
                # Flatten in case the model returns shape (n, 1); otherwise
                # broadcasting against y of shape (n,) would silently produce
                # an (n, n) matrix instead of per-sample errors.
                return -np.mean((np.asarray(z).ravel() - y) ** 2)
return _score

def explain(
self,
X: Tabular,
y: Union[np.ndarray, pd.DataFrame],
n_repeats: int = 30,
score_func: Callable = None
) -> GlobalFeatureImportance:
"""
Generate permutation feature importance scores.
:param X: Data on which permutation importance will be computed.
:param y: Targets or labels.
:param n_repeats: The number of times a feature is randomly shuffled.
        :param score_func: A score function taking ground-truth targets and predictions,
            where larger values mean better predictions, e.g.,
            ``lambda y_true, y_pred: -sklearn.metrics.log_loss(y_true, y_pred)``.
:return: The permutation feature importance explanations.
"""
assert X is not None and y is not None, \
"The test data `X` and target `y` cannot be None."
y = y.values if isinstance(y, pd.DataFrame) else np.array(y)
if y.ndim > 1:
y = y.flatten()
assert X.shape[0] == len(y), \
"The numbers of samples in `X` and `y` are different."
X = X.remove_target_column()

        results = permutation_importance(
            estimator=_Estimator(),
            X=X.to_pd(copy=False),
            y=y,
            n_repeats=n_repeats,
            scoring=self._build_score_function(score_func)
        )
explanations = GlobalFeatureImportance()
explanations.add(
feature_names=list(X.columns),
importance_scores=results["importances_mean"]
)
return explanations
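A minimal usage sketch for the new explainer. The toy data, model, and lambda below are illustrative assumptions, not part of this commit; it also presumes an all-numeric dataset so the prediction function can consume the `Tabular` directly:

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

from omnixai.data.tabular import Tabular
from omnixai.explainers.tabular import PermutationImportance

# Toy data: the label depends on "x1" only, so "x1" should dominate.
rng = np.random.RandomState(0)
df = pd.DataFrame({"x1": rng.rand(500), "x2": rng.rand(500)})
df["label"] = (df["x1"] > 0.5).astype(int)

model = RandomForestClassifier().fit(df[["x1", "x2"]], df["label"])
train = Tabular(df, target_column="label")

explainer = PermutationImportance(
    training_data=train,
    # explain() passes a Tabular of (possibly permuted) features here.
    predict_function=lambda x: model.predict_proba(x.to_pd()),
    mode="classification",
)
explanations = explainer.explain(X=train, y=df["label"].values, n_repeats=10)
```

A custom higher-is-better `score_func` (e.g., `lambda y_true, y_pred: -log_loss(y_true, y_pred)`) replaces the defaults of negative log-loss for classification and negative mean squared error for regression.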
28 changes: 13 additions & 15 deletions omnixai/explainers/tabular/agnostic/shap.py
@@ -13,15 +13,15 @@

from ..base import TabularExplainer
from ....data.tabular import Tabular
from ....explanations.tabular.feature_importance import FeatureImportance


class ShapTabular(TabularExplainer):
"""
The SHAP explainer for tabular data.
If using this explainer, please cite the original work: https://github.com/slundberg/shap.
"""

explanation_type = "local"
alias = ["shap"]

@@ -47,19 +47,22 @@ def __init__(
Please refer to the doc of `shap.KernelExplainer`.
"""
super().__init__(training_data=training_data, predict_function=predict_function, mode=mode, **kwargs)
self.link = kwargs.get("link", None)
if self.link is None:
self.link = "logit" if self.mode == "classification" else "identity"

self.ignored_features = set(ignored_features) if ignored_features is not None else set()
if self.target_column is not None:
assert self.target_column not in self.ignored_features, \
f"The target column {self.target_column} cannot be in the ignored feature list."
self.valid_indices = [i for i, f in enumerate(self.feature_columns) if f not in self.ignored_features]

if "nsamples" not in kwargs:
kwargs["nsamples"] = 100
self.background_data = shap.sample(self.data, nsamples=kwargs["nsamples"])
self.background_data = shap.sample(self.data, nsamples=kwargs.get("nsamples", 100))
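        # Build the KernelExplainer once; explain() reuses it unless some
        # features are ignored, in which case a per-instance explainer is
        # built over the remaining features.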
self.explainer = shap.KernelExplainer(self.predict_fn, self.background_data, link=self.link, **kwargs)

def explain(self, X, y=None, **kwargs) -> FeatureImportance:
"""
Generates the local SHAP explanations for the input instances.
:param X: A batch of input instances. When ``X`` is `pd.DataFrame`
or `np.ndarray`, ``X`` will be converted into `Tabular` automatically.
@@ -68,7 +71,7 @@ def explain(self, X, y=None, **kwargs) -> FeatureImportance:
when ``y = None``.
:param kwargs: Additional parameters for `shap.KernelExplainer.shap_values`,
e.g., ``nsamples`` -- the number of times to re-evaluate the model when explaining each prediction.
:return: The feature importance explanations.
"""
X = self._to_tabular(X).remove_target_column()
explanations = FeatureImportance(self.mode)
@@ -90,12 +93,7 @@ def explain(self, X, y=None, **kwargs) -> FeatureImportance:
y = None

if len(self.ignored_features) == 0:
shap_values = self.explainer.shap_values(instances, **kwargs)
for i, instance in enumerate(instances):
df = X.iloc(i).to_pd()
feature_values = \
@@ -120,12 +118,12 @@ def _predict(_x):
_y = np.tile(instance, (_x.shape[0], 1))
_y[:, self.valid_indices] = _x
return self.predict_fn(_y)

predict_function = _predict
test_x = instance[self.valid_indices]

explainer = shap.KernelExplainer(
predict_function, self.background_data[:, self.valid_indices],
link="logit" if self.mode == "classification" else "identity", **kwargs
link=self.link, **kwargs
)
shap_values = explainer.shap_values(np.expand_dims(test_x, axis=0), **kwargs)

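With the explainer now built once in the constructor, a typical call is sketched below under the same illustrative assumptions as the permutation example above (toy all-numeric data and model; `ignored_features` exercises the per-instance fallback branch):

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

from omnixai.data.tabular import Tabular
from omnixai.explainers.tabular import ShapTabular

rng = np.random.RandomState(0)
df = pd.DataFrame({"x1": rng.rand(200), "x2": rng.rand(200)})
df["label"] = (df["x1"] > 0.5).astype(int)

model = RandomForestClassifier().fit(df[["x1", "x2"]], df["label"])
train = Tabular(df, target_column="label")

explainer = ShapTabular(
    training_data=train,
    predict_function=lambda x: model.predict_proba(x.to_pd()),
    mode="classification",
    nsamples=100,             # size of the sampled SHAP background data
    ignored_features=["x2"],  # triggers the per-instance explainer branch
)
# A pd.DataFrame of feature columns is converted to Tabular automatically.
local_explanations = explainer.explain(df[["x1", "x2"]].head(5))
```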
122 changes: 122 additions & 0 deletions omnixai/explainers/tabular/agnostic/shap_global.py
@@ -0,0 +1,122 @@
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
"""
The SHAP explainer for global feature importance.
"""
import shap
import numpy as np
from typing import Callable, List

from ..base import TabularExplainer
from ....data.tabular import Tabular
from ....explanations.tabular.feature_importance import GlobalFeatureImportance


class GlobalShapTabular(TabularExplainer):
"""
The SHAP explainer for global feature importance.
If using this explainer, please cite the original work: https://github.com/slundberg/shap.
"""

explanation_type = "global"
alias = ["shap_global"]

def __init__(
self,
training_data: Tabular,
predict_function: Callable,
mode: str = "classification",
ignored_features: List = None,
**kwargs
):
"""
:param training_data: The data used to initialize a SHAP explainer. ``training_data``
can be the training dataset for training the machine learning model. If the training
dataset is large, please set parameter ``nsamples``, e.g., ``nsamples = 100``.
:param predict_function: The prediction function corresponding to the model to explain.
When the model is for classification, the outputs of the ``predict_function``
are the class probabilities. When the model is for regression, the outputs of
the ``predict_function`` are the estimated values.
:param mode: The task type, e.g., `classification` or `regression`.
:param ignored_features: The features ignored in computing feature importance scores.
:param kwargs: Additional parameters to initialize `shap.KernelExplainer`, e.g., ``nsamples``.
Please refer to the doc of `shap.KernelExplainer`.
"""
super().__init__(training_data=training_data, predict_function=predict_function, mode=mode, **kwargs)
self.ignored_features = set(ignored_features) if ignored_features is not None else set()
if self.target_column is not None:
assert self.target_column not in self.ignored_features, \
f"The target column {self.target_column} cannot be in the ignored feature list."
self.valid_indices = [i for i, f in enumerate(self.feature_columns) if f not in self.ignored_features]

if "nsamples" not in kwargs:
kwargs["nsamples"] = 100
self.background_data = shap.sample(self.data, nsamples=kwargs["nsamples"])
self.sampled_data = shap.sample(self.data, nsamples=kwargs["nsamples"])

def _explain_global(self, X, **kwargs) -> GlobalFeatureImportance:
if "nsamples" not in kwargs:
kwargs["nsamples"] = 100
instances = self.sampled_data if X is None else \
self.transformer.transform(X.remove_target_column())

explanations = GlobalFeatureImportance()
explainer = shap.KernelExplainer(
self.predict_fn, self.background_data,
link="logit" if self.mode == "classification" else "identity", **kwargs
)
shap_values = explainer.shap_values(instances, **kwargs)

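        # For classification, KernelExplainer returns a list with one array of
        # SHAP values per class; average their absolute values into a single
        # (n_samples, n_features) matrix before aggregating.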
if self.mode == "classification":
values = 0
for v in shap_values:
values += np.abs(v)
values /= len(shap_values)
shap_values = values

importance_scores = np.mean(np.abs(shap_values), axis=0)
explanations.add(
feature_names=self.feature_columns,
importance_scores=importance_scores,
sort=True
)
return explanations

def explain(
self,
X: Tabular = None,
**kwargs
):
"""
Generates the global SHAP explanations.
        :param X: The data used to compute the global SHAP values, i.e., the mean of the
            absolute SHAP value for each feature. If `X` is None, a set of training samples
            will be used.
:param kwargs: Additional parameters for `shap.KernelExplainer.shap_values`,
e.g., ``nsamples`` -- the number of times to re-evaluate the model when explaining each prediction.
:return: The global feature importance explanations.
"""
return self._explain_global(X=X, **kwargs)

def save(
self,
directory: str,
filename: str = None,
**kwargs
):
"""
Saves the initialized explainer.
:param directory: The folder for the dumped explainer.
:param filename: The filename (the explainer class name if it is None).
"""
super().save(
directory=directory,
filename=filename,
ignored_attributes=["data"],
**kwargs
)
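A sketch for the global SHAP explainer under the same illustrative assumptions as the examples above; omitting `X` makes it score a sample of the training data:

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

from omnixai.data.tabular import Tabular
from omnixai.explainers.tabular import GlobalShapTabular

rng = np.random.RandomState(0)
df = pd.DataFrame({"x1": rng.rand(200), "x2": rng.rand(200)})
df["label"] = (df["x1"] > 0.5).astype(int)

model = RandomForestClassifier().fit(df[["x1", "x2"]], df["label"])
train = Tabular(df, target_column="label")

explainer = GlobalShapTabular(
    training_data=train,
    predict_function=lambda x: model.predict_proba(x.to_pd()),
    mode="classification",
    nsamples=100,
)
# With X=None, importance is the mean |SHAP value| over sampled training rows.
global_explanations = explainer.explain()
```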