From b471a7c4f2cb34707eb0e35b27ff2e178ea0aea0 Mon Sep 17 00:00:00 2001 From: ywz Date: Mon, 14 Nov 2022 15:06:04 +0800 Subject: [PATCH 01/15] Add global Shap values --- omnixai/explainers/tabular/agnostic/shap.py | 65 +++++++++++++++---- .../shap/test_shap_tabular_global.py | 33 ++++++++++ 2 files changed, 84 insertions(+), 14 deletions(-) create mode 100644 omnixai/tests/explainers/shap/test_shap_tabular_global.py diff --git a/omnixai/explainers/tabular/agnostic/shap.py b/omnixai/explainers/tabular/agnostic/shap.py index 8afb78d1..e29abb36 100644 --- a/omnixai/explainers/tabular/agnostic/shap.py +++ b/omnixai/explainers/tabular/agnostic/shap.py @@ -13,7 +13,8 @@ from ..base import TabularExplainer from ....data.tabular import Tabular -from ....explanations.tabular.feature_importance import FeatureImportance +from ....explanations.tabular.feature_importance import \ + FeatureImportance, GlobalFeatureImportance class ShapTabular(TabularExplainer): @@ -22,7 +23,7 @@ class ShapTabular(TabularExplainer): If using this explainer, please cite the original work: https://github.com/slundberg/shap. """ - explanation_type = "local" + explanation_type = "both" alias = ["shap"] def __init__( @@ -57,19 +58,34 @@ def __init__( kwargs["nsamples"] = 100 self.background_data = shap.sample(self.data, nsamples=kwargs["nsamples"]) - def explain(self, X, y=None, **kwargs) -> FeatureImportance: - """ - Generates the feature-importance explanations for the input instances. + def _explain_global(self, **kwargs) -> GlobalFeatureImportance: + if "nsamples" not in kwargs: + kwargs["nsamples"] = 100 - :param X: A batch of input instances. When ``X`` is `pd.DataFrame` - or `np.ndarray`, ``X`` will be converted into `Tabular` automatically. - :param y: A batch of labels to explain. For regression, ``y`` is ignored. - For classification, the top predicted label of each instance will be explained - when ``y = None``. - :param kwargs: Additional parameters for `shap.KernelExplainer.shap_values`, - e.g., ``nsamples`` -- the number of times to re-evaluate the model when explaining each prediction. - :return: The feature-importance explanations for all the input instances. - """ + instances = self.background_data + explanations = GlobalFeatureImportance() + explainer = shap.KernelExplainer( + self.predict_fn, self.background_data, + link="logit" if self.mode == "classification" else "identity", **kwargs + ) + shap_values = explainer.shap_values(instances, **kwargs) + + if self.mode == "classification": + values = 0 + for v in shap_values: + values += np.abs(v) + values /= len(shap_values) + shap_values = values + + importance_scores = np.mean(np.abs(shap_values), axis=0) + explanations.add( + feature_names=self.feature_columns, + importance_scores=importance_scores, + sort=True + ) + return explanations + + def _explain_local(self, X, y=None, **kwargs) -> FeatureImportance: X = self._to_tabular(X).remove_target_column() explanations = FeatureImportance(self.mode) instances = self.transformer.transform(X) @@ -120,6 +136,7 @@ def _predict(_x): _y = np.tile(instance, (_x.shape[0], 1)) _y[:, self.valid_indices] = _x return self.predict_fn(_y) + predict_function = _predict test_x = instance[self.valid_indices] @@ -148,6 +165,26 @@ def _predict(_x): ) return explanations + def explain(self, X=None, y=None, **kwargs): + """ + Generates the local SHAP explanations for the input instances or global SHAP explanations. + + :param X: A batch of input instances. 
When ``X`` is `pd.DataFrame` + or `np.ndarray`, ``X`` will be converted into `Tabular` automatically. If X is None, + it will compute global SHAP values, i.e., the mean of the absolute SHAP value for + each feature. + :param y: A batch of labels to explain. For regression, ``y`` is ignored. + For classification, the top predicted label of each instance will be explained + when ``y = None``. + :param kwargs: Additional parameters for `shap.KernelExplainer.shap_values`, + e.g., ``nsamples`` -- the number of times to re-evaluate the model when explaining each prediction. + :return: The feature-importance explanations. + """ + if X is not None: + return self._explain_local(X=X, y=y, **kwargs) + else: + return self._explain_global(**kwargs) + def save( self, directory: str, diff --git a/omnixai/tests/explainers/shap/test_shap_tabular_global.py b/omnixai/tests/explainers/shap/test_shap_tabular_global.py new file mode 100644 index 00000000..9b1d3938 --- /dev/null +++ b/omnixai/tests/explainers/shap/test_shap_tabular_global.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +import os +import unittest +import pprint +from omnixai.utils.misc import set_random_seed +from omnixai.explainers.tabular import ShapTabular +from omnixai.tests.explainers.tasks import TabularClassification + + +class TestShapTabular(unittest.TestCase): + def test(self): + base_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") + task = TabularClassification(base_folder).train_adult(num_training_samples=2000) + predict_function = lambda z: task.model.predict_proba(task.transform.transform(z)) + + set_random_seed() + explainer = ShapTabular( + training_data=task.train_data, + predict_function=predict_function, + ignored_features=None, + nsamples=100 + ) + explanations = explainer.explain() + pprint.pprint(explanations.get_explanations()) + + +if __name__ == "__main__": + unittest.main() From 31aafb7c74060810bdce36ea4368cc9ab8055c7a Mon Sep 17 00:00:00 2001 From: ywz Date: Mon, 14 Nov 2022 15:06:25 +0800 Subject: [PATCH 02/15] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e553bc10..eaa34d15 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name="omnixai", - version="1.2.2", + version="1.2.3", author="Wenzhuo Yang, Hung Le, Tanmay Shivprasad Laud, Silvio Savarese, Steven C.H. 
Hoi",
     description="OmniXAI: An Explainable AI Toolbox",
     long_description=open("README.md", "r", encoding="utf-8").read(),

From fc924ae862de0d8124d6de7af9fd715e0d791b0b Mon Sep 17 00:00:00 2001
From: ywz
Date: Mon, 14 Nov 2022 20:33:59 +0800
Subject: [PATCH 03/15] Add a KNN-based counterfactual explainer

---
 omnixai/explainers/tabular/__init__.py       |   2 +
 .../explainers/tabular/counterfactual/knn.py | 112 ++++++++++++++++++
 .../tabular/counterfactual/mace/retrieval.py |  18 ++-
 omnixai/tests/explainers/knn/test_ce_knn.py  |  39 ++++++
 .../tests/visualization/dashboard_tabular.py |   2 +-
 5 files changed, 168 insertions(+), 5 deletions(-)
 create mode 100644 omnixai/explainers/tabular/counterfactual/knn.py
 create mode 100644 omnixai/tests/explainers/knn/test_ce_knn.py

diff --git a/omnixai/explainers/tabular/__init__.py b/omnixai/explainers/tabular/__init__.py
index 0a1dcdb6..31737c46 100644
--- a/omnixai/explainers/tabular/__init__.py
+++ b/omnixai/explainers/tabular/__init__.py
@@ -13,6 +13,7 @@
 from .agnostic.L2X.l2x import L2XTabular
 from .counterfactual.mace.mace import MACEExplainer
 from .counterfactual.ce import CounterfactualExplainer
+from .counterfactual.knn import KNNCounterfactualExplainer
 from .specific.ig import IntegratedGradientTabular
 from .specific.linear import LinearRegression
 from .specific.linear import LogisticRegression
@@ -31,6 +32,7 @@
     "L2XTabular",
     "MACEExplainer",
     "CounterfactualExplainer",
+    "KNNCounterfactualExplainer",
     "LinearRegression",
     "LogisticRegression",
     "TreeRegressor",
diff --git a/omnixai/explainers/tabular/counterfactual/knn.py b/omnixai/explainers/tabular/counterfactual/knn.py
new file mode 100644
index 00000000..0c07500c
--- /dev/null
+++ b/omnixai/explainers/tabular/counterfactual/knn.py
@@ -0,0 +1,112 @@
+#
+# Copyright (c) 2022 salesforce.com, inc.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+#
+"""
+The KNN-based counterfactual explainer for tabular data.
+"""
+import numpy as np
+import pandas as pd
+from typing import List, Callable, Union
+
+from ..base import ExplainerBase
+from ...tabular.base import TabularExplainerMixin
+from ....data.tabular import Tabular
+from ....explanations.tabular.counterfactual import CFExplanation
+
+from .mace.retrieval import CFRetrieval
+from .mace.diversify import DiversityModule
+
+
+class KNNCounterfactualExplainer(ExplainerBase, TabularExplainerMixin):
+    """
+    The counterfactual explainer for tabular data based on KNN search. Given a query instance,
+    it finds the instances in the training dataset that are close to the query and have the desired label.
+    """
+    explanation_type = "local"
+    alias = ["ce_knn", "knn_ce"]
+
+    def __init__(
+        self,
+        training_data: Tabular,
+        predict_function: Callable,
+        mode: str = "classification",
+        **kwargs,
+    ):
+        """
+        :param training_data: The data used to initialize the explainer. ``training_data``
+            can be the training dataset for training the machine learning model.
+        :param predict_function: The prediction function corresponding to the model to explain.
+            The model should be a classifier; the outputs of the ``predict_function``
+            are the class probabilities.
+        :param mode: The task type, which can only be `classification`.
+        """
+        super().__init__()
+        assert mode == "classification", "KNNCounterfactualExplainer supports classification tasks only."
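+        # This explainer reuses two MACE components: CFRetrieval finds the
+        # nearest neighbors of the query within the desired class, and
+        # DiversityModule selects a diverse subset of them as counterfactuals.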
+ self.predict_function = predict_function + + self.categorical_columns = training_data.categorical_columns + self.target_column = training_data.target_column + self.original_feature_columns = training_data.columns + + self.recall = CFRetrieval(training_data, predict_function, None, **kwargs) + self.diversity = DiversityModule(training_data) + + def explain( + self, + X: Tabular, + y: Union[List, np.ndarray] = None, + max_number_examples: int = 5, + **kwargs + ) -> CFExplanation: + """ + Generates counterfactual explanations. + + :param X: A batch of input instances. When ``X`` is `pd.DataFrame` + or `np.ndarray`, ``X`` will be converted into `Tabular` automatically. + :param y: A batch of the desired labels, which should be different from the predicted labels of ``X``. + If ``y = None``, the desired labels will be the labels different from the predicted labels of ``X``. + :param max_number_examples: The maximum number of the generated counterfactual + examples per class for each input instance. + :return: A CFExplanation object containing the generated explanations. + """ + if y is not None: + assert len(y) == X.shape[0], ( + f"The length of `y` should equal the number of instances in `X`, " f"got {len(y)} != {X.shape[0]}" + ) + + X = self._to_tabular(X).remove_target_column() + scores = self.predict_function(X) + labels = np.argmax(scores, axis=1) + num_classes = scores.shape[1] + + explanations = CFExplanation() + for i in range(X.shape[0]): + x = X.iloc(i) + label = int(labels[i]) + if y is None or y[i] == label: + desired_labels = [z for z in range(num_classes) if z != label] + else: + desired_labels = [int(y[i])] + + all_cfs = [] + for desired_label in desired_labels: + df, _ = self.recall.get_nn_samples(x, desired_label) + examples = Tabular(df, categorical_columns=x.categorical_columns) + cfs = self.diversity.get_diverse_cfs( + self.predict_function, x, examples, + oracle_function=lambda _s: int(desired_label == np.argmax(_s)), + desired_label=desired_label, k=max_number_examples + ) + cfs_df = cfs.to_pd() + if x.continuous_columns: + cfs_df = cfs_df.astype({c: float for c in x.continuous_columns}) + cfs_df["label"] = desired_label + all_cfs.append(cfs_df) + + instance_df = x.to_pd() + instance_df["label"] = label + explanations.add(query=instance_df, cfs=pd.concat(all_cfs) if len(all_cfs) > 0 else None) + return explanations diff --git a/omnixai/explainers/tabular/counterfactual/mace/retrieval.py b/omnixai/explainers/tabular/counterfactual/mace/retrieval.py index 3692e28f..6cc8a108 100644 --- a/omnixai/explainers/tabular/counterfactual/mace/retrieval.py +++ b/omnixai/explainers/tabular/counterfactual/mace/retrieval.py @@ -152,13 +152,13 @@ def _pick_top_columns(self, x: Tabular, candidate_features: Dict, desired_label: columns = [c for c, _ in column_scores][:top_k] return {f: v for f, v in candidate_features.items() if f in columns} - def get_cf_features(self, instance: Tabular, desired_label: int) -> (Dict, np.ndarray): + def get_nn_samples(self, instance: Tabular, desired_label: int) -> (pd.DataFrame, np.ndarray): """ - Finds candidate features for generating counterfactual examples. + Finds nearest neighbor samples in a desired class. :param instance: The query instance. :param desired_label: The desired label. - :return: The candidate features and the indices of the nearest neighbors. + :return: The nearest neighbor samples and the corresponding indices. """ assert isinstance(instance, Tabular), "Input ``instance`` should be an instance of Tabular." 
assert instance.shape[0] == 1, "Input ``instance`` can only contain one instance." @@ -171,9 +171,19 @@ def get_cf_features(self, instance: Tabular, desired_label: int) -> (Dict, np.nd ) ) indices = self._knn_query(query, desired_label, self.num_neighbors)[0] + y = self.subset.iloc(indices).to_pd(copy=False) + return y, indices + def get_cf_features(self, instance: Tabular, desired_label: int) -> (Dict, np.ndarray): + """ + Finds candidate features for generating counterfactual examples. + + :param instance: The query instance. + :param desired_label: The desired label. + :return: The candidate features and the indices of the nearest neighbors. + """ x = instance.to_pd(copy=False) - y = self.subset.iloc(indices).to_pd(copy=False) + y, indices = self.get_nn_samples(instance, desired_label) cate_candidates, cont_candidates = {}, {} # Categorical feature difference diff --git a/omnixai/tests/explainers/knn/test_ce_knn.py b/omnixai/tests/explainers/knn/test_ce_knn.py new file mode 100644 index 00000000..208e9954 --- /dev/null +++ b/omnixai/tests/explainers/knn/test_ce_knn.py @@ -0,0 +1,39 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +import os +import unittest +import pandas as pd +from omnixai.explainers.tabular import KNNCounterfactualExplainer +from omnixai.tests.explainers.tasks import TabularClassification + +pd.set_option("display.max_columns", None) + + +class TestKNNCE(unittest.TestCase): + def setUp(self): + base_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") + task = TabularClassification(base_folder).train_adult(num_training_samples=2000) + self.data = task.data + self.predict_function = lambda z: task.model.predict_proba(task.transform.transform(z)) + self.test_instances = task.test_data.iloc(list(range(5))).remove_target_column() + + def test_explain(self): + explainer = KNNCounterfactualExplainer( + training_data=self.data, + predict_function=self.predict_function, + ) + explanations = explainer.explain(self.test_instances) + for explanation in explanations.get_explanations(): + print("Query instance:") + print(explanation["query"]) + print("Counterfactual examples:") + print(explanation["counterfactual"]) + print("-----------------") + + +if __name__ == "__main__": + unittest.main() diff --git a/omnixai/tests/visualization/dashboard_tabular.py b/omnixai/tests/visualization/dashboard_tabular.py index 36a31ca7..c36c0e42 100644 --- a/omnixai/tests/visualization/dashboard_tabular.py +++ b/omnixai/tests/visualization/dashboard_tabular.py @@ -98,7 +98,7 @@ def test(self): prediction_explanations = explainer.explain() explainers = TabularExplainer( - explainers=["lime", "shap", "mace", "pdp", "ale"], + explainers=["lime", "shap", "mace", "knn_ce", "pdp", "ale"], mode="classification", data=self.tabular_data, model=self.model, From 21eed794904bcabb095008bd446195f8b8b0560f Mon Sep 17 00:00:00 2001 From: ywz Date: Wed, 16 Nov 2022 21:17:48 +0800 Subject: [PATCH 04/15] Add permutation feature importance --- omnixai/explainers/tabular/__init__.py | 2 + .../tabular/agnostic/permutation.py | 114 ++++++++++++++++++ .../test_permutation_classification.py | 25 ++++ omnixai/tests/explainers/tasks.py | 1 + 4 files changed, 142 insertions(+) create mode 100644 omnixai/explainers/tabular/agnostic/permutation.py create mode 100644 
omnixai/tests/explainers/permutation/test_permutation_classification.py diff --git a/omnixai/explainers/tabular/__init__.py b/omnixai/explainers/tabular/__init__.py index 31737c46..70057073 100644 --- a/omnixai/explainers/tabular/__init__.py +++ b/omnixai/explainers/tabular/__init__.py @@ -11,6 +11,7 @@ from .agnostic.ale import ALE from .agnostic.sensitivity import SensitivityAnalysisTabular from .agnostic.L2X.l2x import L2XTabular +from .agnostic.permutation import PermutationImportance from .counterfactual.mace.mace import MACEExplainer from .counterfactual.ce import CounterfactualExplainer from .counterfactual.knn import KNNCounterfactualExplainer @@ -30,6 +31,7 @@ "ALE", "SensitivityAnalysisTabular", "L2XTabular", + "PermutationImportance", "MACEExplainer", "CounterfactualExplainer", "KNNCounterfactualExplainer", diff --git a/omnixai/explainers/tabular/agnostic/permutation.py b/omnixai/explainers/tabular/agnostic/permutation.py new file mode 100644 index 00000000..44270689 --- /dev/null +++ b/omnixai/explainers/tabular/agnostic/permutation.py @@ -0,0 +1,114 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +The permutation feature importance explanation for tabular data. +""" +import numpy as np +import pandas as pd +from typing import Callable +from sklearn.metrics import log_loss +from sklearn.inspection import permutation_importance + +from ..base import ExplainerBase, TabularExplainerMixin +from ....data.tabular import Tabular +from ....explanations.tabular.feature_importance import GlobalFeatureImportance + + +class PermutationImportance(ExplainerBase, TabularExplainerMixin): + """ + The permutation feature importance explanations for tabular data. The permutation feature + importance is defined to be the decrease in a model score when a single feature value + is randomly shuffled. + """ + + explanation_type = "global" + alias = ["permutation"] + + def __init__(self, training_data: Tabular, predict_function, mode="classification", **kwargs): + """ + :param training_data: The training dataset for training the machine learning model. + :param predict_function: The prediction function corresponding to the model to explain. + When the model is for classification, the outputs of the ``predict_function`` + are the class probabilities. When the model is for regression, the outputs of + the ``predict_function`` are the estimated values. + :param mode: The task type, e.g., `classification` or `regression`. + """ + super().__init__() + assert isinstance(training_data, Tabular), \ + "training_data should be an instance of Tabular." + assert mode in ["classification", "regression"], \ + "`mode` can only be `classification` or `regression`." 
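+        # Only the column schema, the prediction function and the task type are
+        # stored; the explainer does not retain the training data itself.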
+ + self.categorical_columns = training_data.categorical_columns + self.predict_function = predict_function + self.mode = mode + + def _build_score_function(self, score_func=None): + if score_func is not None: + def _score(estimator, x, y): + z = self.predict_function( + Tabular(x, categorical_columns=self.categorical_columns) + ) + return score_func(z, y) + elif self.mode == "classification": + def _score(estimator, x, y): + z = self.predict_function( + Tabular(x, categorical_columns=self.categorical_columns) + ) + return log_loss(y, z) + else: + def _score(estimator, x, y): + z = self.predict_function( + Tabular(x, categorical_columns=self.categorical_columns) + ) + return np.mean((z - y) ** 2) + return _score + + def explain( + self, + X, + y, + n_repeats: int = 30, + score_func: Callable = None + ) -> GlobalFeatureImportance: + """ + Generate permutation feature importance scores. + + :param X: Data on which permutation importance will be computed. + :param y: Targets or labels. + :param n_repeats: The number of times a feature is randomly shuffled. + :param score_func: The score/loss function measuring the difference between + predictions and ground-truth targets. + :return: The permutation feature importance explanations. + """ + if isinstance(y, (list, tuple)): + y = np.array(y) + elif isinstance(y, pd.DataFrame): + y = y.values + elif isinstance(y, np.ndarray): + y = y + else: + raise ValueError(f"The type of `y` is {type(y)}, which is not supported." + f"`y` should be a list, a numpy array or a pandas dataframe.") + if y.ndim > 1: + y = y.flatten() + + assert X.shape[0] == len(y), \ + "The numbers of samples in `X` and `y` are different." + X = X.remove_target_column() + + class _Estimator: + def fit(self): + pass + + results = permutation_importance( + estimator=_Estimator(), + X=X.to_pd(copy=False), + y=y, + scoring=self._build_score_function(score_func) + ) + print(results) diff --git a/omnixai/tests/explainers/permutation/test_permutation_classification.py b/omnixai/tests/explainers/permutation/test_permutation_classification.py new file mode 100644 index 00000000..648a96cb --- /dev/null +++ b/omnixai/tests/explainers/permutation/test_permutation_classification.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +import os +import unittest +from omnixai.utils.misc import set_random_seed +from omnixai.explainers.tabular import PermutationImportance +from omnixai.tests.explainers.tasks import TabularClassification + + +class TestPermutation(unittest.TestCase): + def test_1(self): + set_random_seed() + base_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") + task = TabularClassification(base_folder).train_adult(num_training_samples=2000) + predict_function = lambda z: task.model.predict_proba(task.transform.transform(z)) + explainer = PermutationImportance(training_data=task.train_data, predict_function=predict_function) + explanations = explainer.explain(X=task.test_data, y=task.test_targets) + + +if __name__ == "__main__": + unittest.main() diff --git a/omnixai/tests/explainers/tasks.py b/omnixai/tests/explainers/tasks.py index 6248c918..fa69aeb0 100644 --- a/omnixai/tests/explainers/tasks.py +++ b/omnixai/tests/explainers/tasks.py @@ -108,6 +108,7 @@ def train_adult(self, num_training_samples=None): data=tabular_data, train_data=transformer.invert(train), test_data=transformer.invert(test), + test_targets=labels_test ) def train_iris(self): From dc98ca68c688ee335ad0ac523e97730e0bd7b475 Mon Sep 17 00:00:00 2001 From: ywz Date: Wed, 16 Nov 2022 21:21:41 +0800 Subject: [PATCH 05/15] Add permutation feature importance --- omnixai/explainers/tabular/agnostic/permutation.py | 13 +++++++++---- .../permutation/test_permutation_classification.py | 3 ++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/omnixai/explainers/tabular/agnostic/permutation.py b/omnixai/explainers/tabular/agnostic/permutation.py index 44270689..6cd36cca 100644 --- a/omnixai/explainers/tabular/agnostic/permutation.py +++ b/omnixai/explainers/tabular/agnostic/permutation.py @@ -59,13 +59,13 @@ def _score(estimator, x, y): z = self.predict_function( Tabular(x, categorical_columns=self.categorical_columns) ) - return log_loss(y, z) + return -log_loss(y, z) else: def _score(estimator, x, y): z = self.predict_function( Tabular(x, categorical_columns=self.categorical_columns) ) - return np.mean((z - y) ** 2) + return -np.mean((z - y) ** 2) return _score def explain( @@ -81,7 +81,7 @@ def explain( :param X: Data on which permutation importance will be computed. :param y: Targets or labels. :param n_repeats: The number of times a feature is randomly shuffled. - :param score_func: The score/loss function measuring the difference between + :param score_func: The score function measuring the difference between predictions and ground-truth targets. :return: The permutation feature importance explanations. 
""" @@ -105,10 +105,15 @@ class _Estimator: def fit(self): pass + explanations = GlobalFeatureImportance() results = permutation_importance( estimator=_Estimator(), X=X.to_pd(copy=False), y=y, scoring=self._build_score_function(score_func) ) - print(results) + explanations.add( + feature_names=list(X.columns), + importance_scores=results["importances_mean"] + ) + return explanations diff --git a/omnixai/tests/explainers/permutation/test_permutation_classification.py b/omnixai/tests/explainers/permutation/test_permutation_classification.py index 648a96cb..70d0a793 100644 --- a/omnixai/tests/explainers/permutation/test_permutation_classification.py +++ b/omnixai/tests/explainers/permutation/test_permutation_classification.py @@ -12,13 +12,14 @@ class TestPermutation(unittest.TestCase): - def test_1(self): + def test(self): set_random_seed() base_folder = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") task = TabularClassification(base_folder).train_adult(num_training_samples=2000) predict_function = lambda z: task.model.predict_proba(task.transform.transform(z)) explainer = PermutationImportance(training_data=task.train_data, predict_function=predict_function) explanations = explainer.explain(X=task.test_data, y=task.test_targets) + explanations.ipython_plot() if __name__ == "__main__": From 56cf461e393ceb6a2167a91a7afbaf0a6f51ef41 Mon Sep 17 00:00:00 2001 From: ywz Date: Wed, 16 Nov 2022 21:24:46 +0800 Subject: [PATCH 06/15] Add permutation feature importance --- .../tabular/agnostic/permutation.py | 2 +- .../test_permutation_classification.py | 2 +- .../test_permutation_regression.py | 26 +++++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 omnixai/tests/explainers/permutation/test_permutation_regression.py diff --git a/omnixai/explainers/tabular/agnostic/permutation.py b/omnixai/explainers/tabular/agnostic/permutation.py index 6cd36cca..eb824197 100644 --- a/omnixai/explainers/tabular/agnostic/permutation.py +++ b/omnixai/explainers/tabular/agnostic/permutation.py @@ -105,13 +105,13 @@ class _Estimator: def fit(self): pass - explanations = GlobalFeatureImportance() results = permutation_importance( estimator=_Estimator(), X=X.to_pd(copy=False), y=y, scoring=self._build_score_function(score_func) ) + explanations = GlobalFeatureImportance() explanations.add( feature_names=list(X.columns), importance_scores=results["importances_mean"] diff --git a/omnixai/tests/explainers/permutation/test_permutation_classification.py b/omnixai/tests/explainers/permutation/test_permutation_classification.py index 70d0a793..76db3828 100644 --- a/omnixai/tests/explainers/permutation/test_permutation_classification.py +++ b/omnixai/tests/explainers/permutation/test_permutation_classification.py @@ -19,7 +19,7 @@ def test(self): predict_function = lambda z: task.model.predict_proba(task.transform.transform(z)) explainer = PermutationImportance(training_data=task.train_data, predict_function=predict_function) explanations = explainer.explain(X=task.test_data, y=task.test_targets) - explanations.ipython_plot() + explanations.plotly_plot() if __name__ == "__main__": diff --git a/omnixai/tests/explainers/permutation/test_permutation_regression.py b/omnixai/tests/explainers/permutation/test_permutation_regression.py new file mode 100644 index 00000000..dc815269 --- /dev/null +++ b/omnixai/tests/explainers/permutation/test_permutation_regression.py @@ -0,0 +1,26 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause
+# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+#
+import unittest
+from omnixai.utils.misc import set_random_seed
+from omnixai.explainers.tabular import PermutationImportance
+from omnixai.tests.explainers.tasks import TabularRegression
+
+
+class TestPermutationRegression(unittest.TestCase):
+    def test(self):
+        set_random_seed()
+        task = TabularRegression().train_boston()
+        predict_function = lambda z: task.model.predict(task.transform.transform(z))
+        explainer = PermutationImportance(
+            training_data=task.train_data, predict_function=predict_function, mode="regression"
+        )
+        explanations = explainer.explain(X=task.test_data, y=task.test_targets)
+        explanations.plotly_plot()
+
+
+if __name__ == "__main__":
+    unittest.main()

From 110ed3110e98cff730b325c0e4b8ee225c8b7775 Mon Sep 17 00:00:00 2001
From: ywz
Date: Wed, 16 Nov 2022 21:26:28 +0800
Subject: [PATCH 07/15] Add permutation feature importance

---
 omnixai/explainers/tabular/agnostic/permutation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/omnixai/explainers/tabular/agnostic/permutation.py b/omnixai/explainers/tabular/agnostic/permutation.py
index eb824197..4f7dfbb8 100644
--- a/omnixai/explainers/tabular/agnostic/permutation.py
+++ b/omnixai/explainers/tabular/agnostic/permutation.py
@@ -53,7 +53,7 @@ def _score(estimator, x, y):
                 z = self.predict_function(
                     Tabular(x, categorical_columns=self.categorical_columns)
                 )
-                return score_func(z, y)
+                return score_func(y, z)
         elif self.mode == "classification":
             def _score(estimator, x, y):
                 z = self.predict_function(
@@ -82,7 +82,7 @@ def explain(
         :param y: Targets or labels.
         :param n_repeats: The number of times a feature is randomly shuffled.
         :param score_func: The score function measuring the difference between
-            predictions and ground-truth targets.
+            ground-truth targets and predictions, e.g., -sklearn.metrics.log_loss(y_true, y_pred).
         :return: The permutation feature importance explanations.

From fd91b3370c2a74b527fd70cd2b12e1cfbdff37df Mon Sep 17 00:00:00 2001
From: ywz
Date: Wed, 16 Nov 2022 21:29:01 +0800
Subject: [PATCH 08/15] Add permutation feature importance

---
 omnixai/explainers/tabular/agnostic/permutation.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/omnixai/explainers/tabular/agnostic/permutation.py b/omnixai/explainers/tabular/agnostic/permutation.py
index 4f7dfbb8..b96c0e8a 100644
--- a/omnixai/explainers/tabular/agnostic/permutation.py
+++ b/omnixai/explainers/tabular/agnostic/permutation.py
@@ -9,7 +9,7 @@
 """
 import numpy as np
 import pandas as pd
-from typing import Callable
+from typing import Callable, Union
 from sklearn.metrics import log_loss
 from sklearn.inspection import permutation_importance
 
@@ -70,8 +70,8 @@ def explain(
         self,
-        X,
-        y,
+        X: Tabular,
+        y: Union[np.ndarray, pd.DataFrame],
         n_repeats: int = 30,
         score_func: Callable = None
     ) -> GlobalFeatureImportance:
@@ -85,6 +85,8 @@ def explain(
             ground-truth targets and predictions, e.g., -sklearn.metrics.log_loss(y_true, y_pred).
         :return: The permutation feature importance explanations.
         """
+        assert X is not None and y is not None, \
+            "The test data `X` and target `y` cannot be None."
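+        # `y` may be a list, a numpy array or a pandas dataframe; the code
+        # below normalizes it into a flat numpy array before scoring.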
if isinstance(y, (list, tuple)): y = np.array(y) elif isinstance(y, pd.DataFrame): From ccc6875b85ac7658a318812373de2cb6bec9a350 Mon Sep 17 00:00:00 2001 From: ywz Date: Wed, 16 Nov 2022 21:30:53 +0800 Subject: [PATCH 09/15] Add permutation feature importance --- omnixai/explainers/tabular/agnostic/permutation.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/omnixai/explainers/tabular/agnostic/permutation.py b/omnixai/explainers/tabular/agnostic/permutation.py index b96c0e8a..acf2019f 100644 --- a/omnixai/explainers/tabular/agnostic/permutation.py +++ b/omnixai/explainers/tabular/agnostic/permutation.py @@ -87,18 +87,9 @@ def explain( """ assert X is not None and y is not None, \ "The test data `X` and target `y` cannot be None." - if isinstance(y, (list, tuple)): - y = np.array(y) - elif isinstance(y, pd.DataFrame): - y = y.values - elif isinstance(y, np.ndarray): - y = y - else: - raise ValueError(f"The type of `y` is {type(y)}, which is not supported." - f"`y` should be a list, a numpy array or a pandas dataframe.") + y = y.values if isinstance(y, pd.DataFrame) else np.array(y) if y.ndim > 1: y = y.flatten() - assert X.shape[0] == len(y), \ "The numbers of samples in `X` and `y` are different." X = X.remove_target_column() From 0db40a394cd87be209382e41fb41c6016dd3ae58 Mon Sep 17 00:00:00 2001 From: ywz Date: Wed, 16 Nov 2022 21:33:16 +0800 Subject: [PATCH 10/15] Add permutation feature importance --- omnixai/explainers/tabular/agnostic/permutation.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/omnixai/explainers/tabular/agnostic/permutation.py b/omnixai/explainers/tabular/agnostic/permutation.py index acf2019f..a629559a 100644 --- a/omnixai/explainers/tabular/agnostic/permutation.py +++ b/omnixai/explainers/tabular/agnostic/permutation.py @@ -18,6 +18,11 @@ from ....explanations.tabular.feature_importance import GlobalFeatureImportance +class _Estimator: + def fit(self): + pass + + class PermutationImportance(ExplainerBase, TabularExplainerMixin): """ The permutation feature importance explanations for tabular data. The permutation feature @@ -94,10 +99,6 @@ def explain( "The numbers of samples in `X` and `y` are different." 
         X = X.remove_target_column()
 
-        class _Estimator:
-            def fit(self):
-                pass
-
         results = permutation_importance(
             estimator=_Estimator(),
             X=X.to_pd(copy=False),
             y=y,
+            n_repeats=n_repeats,
             scoring=self._build_score_function(score_func)
         )

From fddf217f0090210f25b7151f7f171c4aa7be2636 Mon Sep 17 00:00:00 2001
From: ywz
Date: Thu, 17 Nov 2022 11:21:25 +0800
Subject: [PATCH 11/15] Add global Shap values

---
 omnixai/explainers/tabular/__init__.py       |   2 +
 omnixai/explainers/tabular/agnostic/shap.py  |  63 ++-------
 .../tabular/agnostic/shap_global.py          | 122 ++++++++++++++++++
 .../shap/test_shap_tabular_global.py         |   4 +-
 .../tests/visualization/dashboard_tabular.py |   2 +-
 5 files changed, 141 insertions(+), 52 deletions(-)
 create mode 100644 omnixai/explainers/tabular/agnostic/shap_global.py

diff --git a/omnixai/explainers/tabular/__init__.py b/omnixai/explainers/tabular/__init__.py
index 70057073..4381ad65 100644
--- a/omnixai/explainers/tabular/__init__.py
+++ b/omnixai/explainers/tabular/__init__.py
@@ -12,6 +12,7 @@
 from .agnostic.sensitivity import SensitivityAnalysisTabular
 from .agnostic.L2X.l2x import L2XTabular
 from .agnostic.permutation import PermutationImportance
+from .agnostic.shap_global import GlobalShapTabular
 from .counterfactual.mace.mace import MACEExplainer
 from .counterfactual.ce import CounterfactualExplainer
 from .counterfactual.knn import KNNCounterfactualExplainer
@@ -32,6 +33,7 @@
     "SensitivityAnalysisTabular",
     "L2XTabular",
     "PermutationImportance",
+    "GlobalShapTabular",
     "MACEExplainer",
     "CounterfactualExplainer",
     "KNNCounterfactualExplainer",
diff --git a/omnixai/explainers/tabular/agnostic/shap.py b/omnixai/explainers/tabular/agnostic/shap.py
index e29abb36..535273f2 100644
--- a/omnixai/explainers/tabular/agnostic/shap.py
+++ b/omnixai/explainers/tabular/agnostic/shap.py
@@ -14,7 +14,7 @@
 from ..base import TabularExplainer
 from ....data.tabular import Tabular
 from ....explanations.tabular.feature_importance import \
-    FeatureImportance, GlobalFeatureImportance
+    FeatureImportance
 
 
 class ShapTabular(TabularExplainer):
@@ -23,7 +23,7 @@ class ShapTabular(TabularExplainer):
 
     If using this explainer, please cite the original work: https://github.com/slundberg/shap.
     """
-    explanation_type = "both"
+    explanation_type = "local"
     alias = ["shap"]
 
     def __init__(
@@ -58,34 +58,19 @@ def __init__(
             kwargs["nsamples"] = 100
         self.background_data = shap.sample(self.data, nsamples=kwargs["nsamples"])
 
-    def _explain_global(self, **kwargs) -> GlobalFeatureImportance:
-        if "nsamples" not in kwargs:
-            kwargs["nsamples"] = 100
-
-        instances = self.background_data
-        explanations = GlobalFeatureImportance()
-        explainer = shap.KernelExplainer(
-            self.predict_fn, self.background_data,
-            link="logit" if self.mode == "classification" else "identity", **kwargs
-        )
-        shap_values = explainer.shap_values(instances, **kwargs)
-
-        if self.mode == "classification":
-            values = 0
-            for v in shap_values:
-                values += np.abs(v)
-            values /= len(shap_values)
-            shap_values = values
-
-        importance_scores = np.mean(np.abs(shap_values), axis=0)
-        explanations.add(
-            feature_names=self.feature_columns,
-            importance_scores=importance_scores,
-            sort=True
-        )
-        return explanations
+    def explain(self, X, y=None, **kwargs) -> FeatureImportance:
+        """
+        Generates the local SHAP explanations for the input instances.
 
-    def _explain_local(self, X, y=None, **kwargs) -> FeatureImportance:
+        :param X: A batch of input instances. When ``X`` is `pd.DataFrame`
+            or `np.ndarray`, ``X`` will be converted into `Tabular` automatically.
+        :param y: A batch of labels to explain. For regression, ``y`` is ignored.
+ For classification, the top predicted label of each instance will be explained + when ``y = None``. + :param kwargs: Additional parameters for `shap.KernelExplainer.shap_values`, + e.g., ``nsamples`` -- the number of times to re-evaluate the model when explaining each prediction. + :return: The feature importance explanations. + """ X = self._to_tabular(X).remove_target_column() explanations = FeatureImportance(self.mode) instances = self.transformer.transform(X) @@ -165,26 +150,6 @@ def _predict(_x): ) return explanations - def explain(self, X=None, y=None, **kwargs): - """ - Generates the local SHAP explanations for the input instances or global SHAP explanations. - - :param X: A batch of input instances. When ``X`` is `pd.DataFrame` - or `np.ndarray`, ``X`` will be converted into `Tabular` automatically. If X is None, - it will compute global SHAP values, i.e., the mean of the absolute SHAP value for - each feature. - :param y: A batch of labels to explain. For regression, ``y`` is ignored. - For classification, the top predicted label of each instance will be explained - when ``y = None``. - :param kwargs: Additional parameters for `shap.KernelExplainer.shap_values`, - e.g., ``nsamples`` -- the number of times to re-evaluate the model when explaining each prediction. - :return: The feature-importance explanations. - """ - if X is not None: - return self._explain_local(X=X, y=y, **kwargs) - else: - return self._explain_global(**kwargs) - def save( self, directory: str, diff --git a/omnixai/explainers/tabular/agnostic/shap_global.py b/omnixai/explainers/tabular/agnostic/shap_global.py new file mode 100644 index 00000000..8ff9ab69 --- /dev/null +++ b/omnixai/explainers/tabular/agnostic/shap_global.py @@ -0,0 +1,122 @@ +# +# Copyright (c) 2022 salesforce.com, inc. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# +""" +The SHAP explainer for global feature importance. +""" +import shap +import numpy as np +from typing import Callable, List + +from ..base import TabularExplainer +from ....data.tabular import Tabular +from ....explanations.tabular.feature_importance import GlobalFeatureImportance + + +class GlobalShapTabular(TabularExplainer): + """ + The SHAP explainer for global feature importance. + If using this explainer, please cite the original work: https://github.com/slundberg/shap. + """ + + explanation_type = "both" + alias = ["shap_global"] + + def __init__( + self, + training_data: Tabular, + predict_function: Callable, + mode: str = "classification", + ignored_features: List = None, + **kwargs + ): + """ + :param training_data: The data used to initialize a SHAP explainer. ``training_data`` + can be the training dataset for training the machine learning model. If the training + dataset is large, please set parameter ``nsamples``, e.g., ``nsamples = 100``. + :param predict_function: The prediction function corresponding to the model to explain. + When the model is for classification, the outputs of the ``predict_function`` + are the class probabilities. When the model is for regression, the outputs of + the ``predict_function`` are the estimated values. + :param mode: The task type, e.g., `classification` or `regression`. + :param ignored_features: The features ignored in computing feature importance scores. + :param kwargs: Additional parameters to initialize `shap.KernelExplainer`, e.g., ``nsamples``. 
+ Please refer to the doc of `shap.KernelExplainer`. + """ + super().__init__(training_data=training_data, predict_function=predict_function, mode=mode, **kwargs) + self.ignored_features = set(ignored_features) if ignored_features is not None else set() + if self.target_column is not None: + assert self.target_column not in self.ignored_features, \ + f"The target column {self.target_column} cannot be in the ignored feature list." + self.valid_indices = [i for i, f in enumerate(self.feature_columns) if f not in self.ignored_features] + + if "nsamples" not in kwargs: + kwargs["nsamples"] = 100 + self.background_data = shap.sample(self.data, nsamples=kwargs["nsamples"]) + self.sampled_data = shap.sample(self.data, nsamples=kwargs["nsamples"]) + + def _explain_global(self, X, **kwargs) -> GlobalFeatureImportance: + if "nsamples" not in kwargs: + kwargs["nsamples"] = 100 + instances = self.sampled_data if X is None else \ + self.transformer.transform(X.remove_target_column()) + + explanations = GlobalFeatureImportance() + explainer = shap.KernelExplainer( + self.predict_fn, self.background_data, + link="logit" if self.mode == "classification" else "identity", **kwargs + ) + shap_values = explainer.shap_values(instances, **kwargs) + + if self.mode == "classification": + values = 0 + for v in shap_values: + values += np.abs(v) + values /= len(shap_values) + shap_values = values + + importance_scores = np.mean(np.abs(shap_values), axis=0) + explanations.add( + feature_names=self.feature_columns, + importance_scores=importance_scores, + sort=True + ) + return explanations + + def explain( + self, + X: Tabular = None, + **kwargs + ): + """ + Generates the global SHAP explanations. + + :param X: The data will be used to compute global SHAP values, i.e., the mean of the absolute + SHAP value for each feature. If `X` is None, a set of training samples will be used. + :param kwargs: Additional parameters for `shap.KernelExplainer.shap_values`, + e.g., ``nsamples`` -- the number of times to re-evaluate the model when explaining each prediction. + :return: The global feature importance explanations. + """ + return self._explain_global(X=X, **kwargs) + + def save( + self, + directory: str, + filename: str = None, + **kwargs + ): + """ + Saves the initialized explainer. + + :param directory: The folder for the dumped explainer. + :param filename: The filename (the explainer class name if it is None). 
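+        :param kwargs: Additional parameters passed to the base class ``save`` method.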
+ """ + super().save( + directory=directory, + filename=filename, + ignored_attributes=["data"], + **kwargs + ) diff --git a/omnixai/tests/explainers/shap/test_shap_tabular_global.py b/omnixai/tests/explainers/shap/test_shap_tabular_global.py index 9b1d3938..caa3bd85 100644 --- a/omnixai/tests/explainers/shap/test_shap_tabular_global.py +++ b/omnixai/tests/explainers/shap/test_shap_tabular_global.py @@ -8,7 +8,7 @@ import unittest import pprint from omnixai.utils.misc import set_random_seed -from omnixai.explainers.tabular import ShapTabular +from omnixai.explainers.tabular import GlobalShapTabular from omnixai.tests.explainers.tasks import TabularClassification @@ -19,7 +19,7 @@ def test(self): predict_function = lambda z: task.model.predict_proba(task.transform.transform(z)) set_random_seed() - explainer = ShapTabular( + explainer = GlobalShapTabular( training_data=task.train_data, predict_function=predict_function, ignored_features=None, diff --git a/omnixai/tests/visualization/dashboard_tabular.py b/omnixai/tests/visualization/dashboard_tabular.py index c36c0e42..7bf626fd 100644 --- a/omnixai/tests/visualization/dashboard_tabular.py +++ b/omnixai/tests/visualization/dashboard_tabular.py @@ -98,7 +98,7 @@ def test(self): prediction_explanations = explainer.explain() explainers = TabularExplainer( - explainers=["lime", "shap", "mace", "knn_ce", "pdp", "ale"], + explainers=["lime", "shap", "mace", "knn_ce", "pdp", "ale", "shap_global"], mode="classification", data=self.tabular_data, model=self.model, From b35dacfee9d2dc76018cd26e0ed88819ce8590c5 Mon Sep 17 00:00:00 2001 From: ywz Date: Thu, 17 Nov 2022 11:28:35 +0800 Subject: [PATCH 12/15] Add global Shap values --- omnixai/explainers/tabular/agnostic/shap_global.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omnixai/explainers/tabular/agnostic/shap_global.py b/omnixai/explainers/tabular/agnostic/shap_global.py index 8ff9ab69..2ee94b64 100644 --- a/omnixai/explainers/tabular/agnostic/shap_global.py +++ b/omnixai/explainers/tabular/agnostic/shap_global.py @@ -22,7 +22,7 @@ class GlobalShapTabular(TabularExplainer): If using this explainer, please cite the original work: https://github.com/slundberg/shap. 
""" - explanation_type = "both" + explanation_type = "global" alias = ["shap_global"] def __init__( From ce94c1556c387f316116fb87f8fc45176affc57c Mon Sep 17 00:00:00 2001 From: ywz Date: Thu, 17 Nov 2022 15:58:51 +0800 Subject: [PATCH 13/15] Revise dashboard figures --- omnixai/explanations/base.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/omnixai/explanations/base.py b/omnixai/explanations/base.py index e1748232..c8f4f36f 100644 --- a/omnixai/explanations/base.py +++ b/omnixai/explanations/base.py @@ -137,7 +137,16 @@ def to_html_div(self, id=None): elif isinstance(self.component, dash_table.DataTable): return html.Div([self.component], id=id) elif isinstance(self.component, plotly.graph_objs.Figure): - return html.Div([dcc.Graph(figure=self.component, id=id)], id=f"div_{id}") + height = self.component.layout.height + if height is None or height <= 450: + return html.Div( + [dcc.Graph(figure=self.component, id=id)], + id=f"div_{id}") + else: + return html.Div( + [dcc.Graph(figure=self.component, id=id)], + id=f"div_{id}", + style={"overflowY": "scroll", "height": 450}) else: raise ValueError(f"The type of `component` ({type(self.component)}) " f"" f"is not supported by DashFigure.") From f0240794e53674f6c90374ef1c91f5aa9104ba03 Mon Sep 17 00:00:00 2001 From: ywz Date: Thu, 17 Nov 2022 16:16:07 +0800 Subject: [PATCH 14/15] Revise SHAP parameters --- omnixai/explainers/tabular/agnostic/shap.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/omnixai/explainers/tabular/agnostic/shap.py b/omnixai/explainers/tabular/agnostic/shap.py index 535273f2..e283ed8f 100644 --- a/omnixai/explainers/tabular/agnostic/shap.py +++ b/omnixai/explainers/tabular/agnostic/shap.py @@ -22,7 +22,6 @@ class ShapTabular(TabularExplainer): The SHAP explainer for tabular data. If using this explainer, please cite the original work: https://github.com/slundberg/shap. """ - explanation_type = "local" alias = ["shap"] @@ -48,15 +47,18 @@ def __init__( Please refer to the doc of `shap.KernelExplainer`. """ super().__init__(training_data=training_data, predict_function=predict_function, mode=mode, **kwargs) + self.link = kwargs.get("link", None) + if self.link is None: + self.link = "logit" if self.mode == "classification" else "identity" + self.ignored_features = set(ignored_features) if ignored_features is not None else set() if self.target_column is not None: assert self.target_column not in self.ignored_features, \ f"The target column {self.target_column} cannot be in the ignored feature list." 
self.valid_indices = [i for i, f in enumerate(self.feature_columns) if f not in self.ignored_features] - if "nsamples" not in kwargs: - kwargs["nsamples"] = 100 - self.background_data = shap.sample(self.data, nsamples=kwargs["nsamples"]) + self.background_data = shap.sample(self.data, nsamples=kwargs.get("nsamples", 100)) + self.explainer = shap.KernelExplainer(self.predict_fn, self.background_data, link=self.link, **kwargs) def explain(self, X, y=None, **kwargs) -> FeatureImportance: """ @@ -91,12 +93,7 @@ def explain(self, X, y=None, **kwargs) -> FeatureImportance: y = None if len(self.ignored_features) == 0: - explainer = shap.KernelExplainer( - self.predict_fn, self.background_data, - link="logit" if self.mode == "classification" else "identity", **kwargs - ) - shap_values = explainer.shap_values(instances, **kwargs) - + shap_values = self.explainer.shap_values(instances, **kwargs) for i, instance in enumerate(instances): df = X.iloc(i).to_pd() feature_values = \ @@ -124,10 +121,9 @@ def _predict(_x): predict_function = _predict test_x = instance[self.valid_indices] - explainer = shap.KernelExplainer( predict_function, self.background_data[:, self.valid_indices], - link="logit" if self.mode == "classification" else "identity", **kwargs + link=self.link, **kwargs ) shap_values = explainer.shap_values(np.expand_dims(test_x, axis=0), **kwargs) From d3ebd0e3a4d0ddd79381aa84b8ad3119838289fd Mon Sep 17 00:00:00 2001 From: ywz Date: Thu, 17 Nov 2022 16:39:55 +0800 Subject: [PATCH 15/15] Update docs --- README.md | 1 + docs/omnixai.explainers.tabular.agnostic.rst | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/README.md b/README.md index 01beadcb..356b5856 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ We will continue improving this library to make it more comprehensive in the fut | Partial dependence plots | Black box | Global | | ✅ | | | | | Accumulated local effects | Black box | Global | | ✅ | | | | | Sensitivity analysis | Black box | Global | | ✅ | | | | +| Permutation explanation | Black box | Global | | ✅ | | | | | Feature visualization | Torch or TF | Global | | | ✅ | | | | Feature maps | Torch or TF | Local | | | ✅ | | | | LIME | Black box | Local | | ✅ | ✅ | ✅ | | diff --git a/docs/omnixai.explainers.tabular.agnostic.rst b/docs/omnixai.explainers.tabular.agnostic.rst index 3597ced6..ee8dbfca 100644 --- a/docs/omnixai.explainers.tabular.agnostic.rst +++ b/docs/omnixai.explainers.tabular.agnostic.rst @@ -60,3 +60,19 @@ omnixai.explainers.tabular.agnostic.L2X.l2x module :members: :undoc-members: :show-inheritance: + +omnixai.explainers.tabular.agnostic.permutation module +------------------------------------------------------ + +.. automodule:: omnixai.explainers.tabular.agnostic.permutation + :members: + :undoc-members: + :show-inheritance: + +omnixai.explainers.tabular.agnostic.shap_global module +------------------------------------------------------ + +.. automodule:: omnixai.explainers.tabular.agnostic.shap_global + :members: + :undoc-members: + :show-inheritance:
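
A minimal end-to-end sketch of the explainers added in this patch series,
assuming the same adult-census `TabularClassification` task used by the tests
above (the variable names are illustrative only):

    import os
    from omnixai.explainers.tabular import (
        GlobalShapTabular, KNNCounterfactualExplainer, PermutationImportance
    )
    from omnixai.tests.explainers.tasks import TabularClassification

    base_folder = os.path.dirname(os.path.abspath(__file__))
    task = TabularClassification(base_folder).train_adult(num_training_samples=2000)
    predict_function = lambda z: task.model.predict_proba(task.transform.transform(z))

    # Global SHAP values: the mean of the absolute SHAP value per feature.
    shap_global = GlobalShapTabular(
        training_data=task.train_data, predict_function=predict_function, nsamples=100)
    shap_global.explain().plotly_plot()

    # Permutation feature importance computed on held-out data.
    permutation = PermutationImportance(
        training_data=task.train_data, predict_function=predict_function)
    permutation.explain(X=task.test_data, y=task.test_targets).plotly_plot()

    # KNN-based counterfactual examples for a few test instances.
    knn_ce = KNNCounterfactualExplainer(
        training_data=task.data, predict_function=predict_function)
    cfs = knn_ce.explain(task.test_data.iloc(list(range(5))).remove_target_column())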