Skip to content

Commit

Permalink
Merge pull request #74 from salesforce/revise_knn
Browse files Browse the repository at this point in the history
MACE allows generating CF examples without using KNN search
  • Loading branch information
yangwenz authored Feb 8, 2023
2 parents f373e93 + d7e252a commit 4656953
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
10 changes: 8 additions & 2 deletions omnixai/explainers/tabular/counterfactual/mace/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ....tabular.base import TabularExplainerMixin
from .....data.tabular import Tabular

from .retrieval import CFRetrieval
from .retrieval import CFRetrieval, SimpleCFRetrieval
from .rl import RL
from .gld import GLD
from .greedy import Greedy
Expand All @@ -42,6 +42,7 @@ def __init__(
mode: str = "classification",
ignored_features: List = None,
method: str = "gld",
use_knn: bool = True,
**kwargs,
):
"""
Expand All @@ -54,6 +55,7 @@ def __init__(
are the class probabilities.
:param mode: The task type can be `classification` only.
:param ignored_features: The features ignored in generating counterfactual examples.
:param use_knn: Whether to use KNN search to find candidate features for generating counterfactual examples.
:param kwargs: Additional parameters used in `CFRetrieval` and `GLD`. For more information, please
refer to the classes `mace.retrieval.CFRetrieval` and `mace.gld.GLD`.
"""
Expand All @@ -68,7 +70,11 @@ def __init__(
self.target_column = training_data.target_column
self.original_feature_columns = training_data.columns

self.recall = CFRetrieval(training_data, predict_function, ignored_features, **kwargs)
if use_knn:
self.recall = CFRetrieval(training_data, predict_function, ignored_features, **kwargs)
else:
self.recall = SimpleCFRetrieval(training_data, ignored_features, **kwargs)

self.diversity = DiversityModule(training_data)
self.refinement = BinarySearchRefinement(training_data)
if method == "gld":
Expand Down
44 changes: 44 additions & 0 deletions omnixai/explainers/tabular/counterfactual/mace/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import List, Dict, Callable, Union

from .....data.tabular import Tabular
from .....preprocessing.base import Identity
from .....preprocessing.encode import OneHot, KBins
from .....preprocessing.pipeline import Pipeline
from .....preprocessing.tabular import TabularTransform
Expand Down Expand Up @@ -217,3 +218,46 @@ def get_cf_features(self, instance: Tabular, desired_label: int) -> (Dict, np.nd
if self.column_top_k > 0:
res = self._pick_top_columns(instance, res, desired_label, self.column_top_k)
return res, indices


class SimpleCFRetrieval:
"""
The class for extracting all the feature values in a dataset.
"""

def __init__(
self,
training_data: Tabular,
ignored_features: List = None,
num_cont_bins: int = 10,
**kwargs
):
"""
:param training_data: The training data.
:param ignored_features: The features ignored in generating counterfactual examples.
:param num_cont_bins: The number of bins for discretizing continuous-valued features.
:param kwargs: Other parameters.
"""
assert isinstance(training_data, Tabular), "`training_data` should be an instance of Tabular."
self.ignored_features = ignored_features if ignored_features is not None else []
subset = training_data.remove_target_column()

transformer = TabularTransform(
cate_transform=Identity(), cont_transform=KBins(n_bins=num_cont_bins)
).fit(subset)
df = transformer.invert(transformer.transform(subset)).to_pd(copy=False)

self.features = {}
for col in df.columns:
if col not in self.ignored_features:
self.features[col] = sorted(set(df[col].unique()))

def get_cf_features(self, instance: Tabular, desired_label: int) -> (Dict, None):
"""
Finds candidate features for generating counterfactual examples.
:param instance: The query instance.
:param desired_label: The desired label.
:return: The candidate features
"""
return self.features, None

0 comments on commit 4656953

Please sign in to comment.