Skip to content

Commit

Permalink
Merge pull request #37 from salesforce/revise_docs
Browse files Browse the repository at this point in the history
Revise docs
  • Loading branch information
yangwenz authored Sep 8, 2022
2 parents 4551fe3 + d9f924d commit 3164258
Show file tree
Hide file tree
Showing 12 changed files with 833 additions and 58 deletions.
108 changes: 79 additions & 29 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,31 +47,32 @@ OmniXAI includes a rich family of explanation methods integrated in a unified in
supports multiple data types (tabular data, images, texts, time-series), multiple types of ML models
(traditional ML in Scikit-learn and deep learning models in PyTorch/TensorFlow), and a range of diverse explaination
methods including "model-specific" and "model-agnostic" methods (such as feature-attribution explanation,
counterfactual explanation, gradient-based explanation, etc). For practitioners, OmniXAI provides an easy-to-use
counterfactual explanation, gradient-based explanation, feature visualization, etc). For practitioners, OmniXAI provides an easy-to-use
unified interface to generate the explanations for their applications by only writing a few lines of
codes, and also a GUI dashboard for visualization for obtaining more insights about decisions.

The following table shows the supported explanation methods and features in our library.
We will continue improving this library to make it more comprehensive in the future, e.g., supporting more
explanation methods for vision, NLP and time-series tasks.

| Method | Model Type | Explanation Type | EDA | Tabular | Image | Text | Timeseries |
:-----------------------:| :---: | :---: |:---:| :---: | :---: | :---: | :---:
| Feature analysis | NA | Global || | | | |
| Feature selection | NA | Global || | | | |
| Prediction metrics | Black box | Global | |||||
| Partial dependence plots | Black box | Global | || | | |
| Accumulated local effects | Black box | Global | || | | |
| Sensitivity analysis | Black box | Global | || | | |
| LIME | Black box | Local | |||| |
| SHAP | Black box* | Local | |||||
| Integrated gradient | Torch or TF | Local | |||| |
| Counterfactual | Black box* | Local | |||||
| Contrastive explanation | Torch or TF | Local | | || | |
| Grad-CAM, Grad-CAM++ | Torch or TF | Local | | || | |
| Learning to explain | Black box | Local | |||| |
| Linear models | Linear models | Global and Local | || | | |
| Tree models | Tree models | Global and Local | || | | |
We will continue improving this library to make it more comprehensive in the future.

| Method | Model Type | Explanation Type | EDA | Tabular | Image | Text | Timeseries |
:-------------------------:|:-------------:|:----------------:|:---:|:-------:|:-----:| :---: | :---:
| Feature analysis | NA | Global || | | | |
| Feature selection | NA | Global || | | | |
| Prediction metrics | Black box | Global | |||||
| Partial dependence plots | Black box | Global | || | | |
| Accumulated local effects | Black box | Global | || | | |
| Sensitivity analysis | Black box | Global | || | | |
| Feature visualization | Torch or TF | Global | | || | |
| LIME | Black box | Local | |||| |
| SHAP | Black box* | Local | |||||
| Integrated gradient | Torch or TF | Local | |||| |
| Counterfactual | Black box* | Local | |||||
| Contrastive explanation | Torch or TF | Local | | || | |
| Grad-CAM, Grad-CAM++ | Torch or TF | Local | | || | |
| Learning to explain | Black box | Local | |||| |
| Linear models | Linear models | Global and Local | || | | |
| Tree models | Tree models | Global and Local | || | | |
| Feature maps | Torch or TF | Local | | || | |

*SHAP* accepts black box models for tabular data, PyTorch/Tensorflow models for image data, transformer models
for text data. *Counterfactual* accepts black box models for tabular, text and time-series data, and PyTorch/Tensorflow models for
Expand Down Expand Up @@ -109,22 +110,29 @@ Some examples:
4. [Text classification](https://github.com/salesforce/OmniXAI/blob/main/tutorials/nlp_imdb.ipynb)
5. [Time-series anomaly detection](https://github.com/salesforce/OmniXAI/blob/main/tutorials/timeseries.ipynb)
6. [Vision-language tasks](https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/gradcam_vlm.ipynb)
7. [Ranking tasks](https://github.com/salesforce/OmniXAI/blob/main/tutorials/tabular/ranking.ipynb)
8. [Feature visualization](https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/feature_visualization_torch.ipynb)
9. [Check feature maps](https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/feature_map_torch.ipynb)

To get started, we recommend the linked tutorials in [tutorials](https://opensource.salesforce.com/OmniXAI/latest/tutorials.html).
In general, we recommend using `TabularExplainer`, `VisionExplainer`,
`NLPExplainer` and `TimeseriesExplainer` for tabular, vision, NLP and time-series tasks, respectively, and using
`DataAnalyzer` and `PredictionAnalyzer` for feature analysis and prediction result analysis.
To generate explanations, one only needs to specify
These classes act as the factories of the individual explainers supported in OmniXAI, providing a simpler
interface to generate multiple explanations. To generate explanations, you only need to specify

- **The ML model to explain**: e.g., a scikit-learn model, a tensorflow model, a pytorch model or a black-box prediction function.
- **The pre-processing function**: i.e., converting raw input features into the model inputs.
- **The post-processing function (optional)**: e.g., converting the model outputs into class probabilities.
- **The explainers to apply**: e.g., SHAP, MACE, Grad-CAM.

Besides using these classes, you can also create a single explainer defined in the `omnixai.explainers` package, e.g.,
`ShapTabular`, `GradCAM`, `IntegratedGradient` or `FeatureVisualizer`.

Let's take the income prediction task as an example.
The [dataset](https://archive.ics.uci.edu/ml/datasets/adult) used in this example is for income prediction.
We recommend using data class `Tabular` to represent a tabular dataset. To create a `Tabular` instance given a pandas
dataframe, one needs to specify the dataframe, the categorical feature names (if exists) and the target/label
dataframe, you need to specify the dataframe, the categorical feature names (if exists) and the target/label
column name (if exists).

```python
Expand Down Expand Up @@ -152,8 +160,8 @@ for a `Tabular` instance. `TabularTransform` is a special transform designed for
By default, it converts categorical features into one-hot encoding, and keeps continuous-valued features.
The method ``transform`` of `TabularTransform` transforms a `Tabular` instance to a numpy array.
If the `Tabular` instance has a target/label column, the last column of the numpy array
will be the target/label. One can also apply any customized preprocessing functions instead of using `TabularTransform`.
After data preprocessing, we train a XGBoost classifier for this task.
will be the target/label. You can apply any customized preprocessing functions instead of using `TabularTransform`.
After data preprocessing, let's train a XGBoost classifier for this task.

```python
from omnixai.preprocessing.tabular import TabularTransform
Expand All @@ -172,7 +180,7 @@ train_data = transformer.invert(train)
test_data = transformer.invert(test)
```

To initialize `TabularExplainer`, we need to set the following parameters:
To initialize `TabularExplainer`, the following parameters need to be set:

- ``explainers``: The names of the explainers to apply, e.g., ["lime", "shap", "mace", "pdp"].
- ``data``: The data used to initialize explainers. ``data`` is the training dataset for training the
Expand All @@ -185,8 +193,8 @@ To initialize `TabularExplainer`, we need to set the following parameters:
- ``mode``: The task type, e.g., "classification" or "regression".

The preprocessing function takes a `Tabular` instance as its input and outputs the processed features that
the ML model consumes. In this example, we simply call ``transformer.transform``. If one uses some customized transforms
on pandas dataframes, the preprocess function has format: `lambda z: some_transform(z.to_pd())`. If the output of ``model``
the ML model consumes. In this example, we simply call ``transformer.transform``. If you use some customized transforms
on pandas dataframes, the preprocess function has this format: `lambda z: some_transform(z.to_pd())`. If the output of ``model``
is not a numpy array, ``postprocess`` needs to be set to convert it into a numpy array.

```python
Expand Down Expand Up @@ -222,7 +230,7 @@ global_explanations = explainers.explain_global(
```

Similarly, we create a `PredictionAnalyzer` for computing performance metrics for this classification task.
To initialize `PredictionAnalyzer`, we set the following parameters:
To initialize `PredictionAnalyzer`, the following parameters need to be set:

- `mode`: The task type, e.g., "classification" or "regression".
- `test_data`: The test dataset, which should be a `Tabular` instance.
Expand Down Expand Up @@ -265,6 +273,48 @@ dashboard.show() # Launch the dashboard
After opening the Dash app in the browser, we will see a dashboard showing the explanations:
![alt text](https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo.gif)

For vision tasks, the same interface is used to create explainers and generate explanations.
Let's take an image classification model as an example.

```python
from omnixai.explainers.vision import VisionExplainer
from omnixai.visualization.dashboard import Dashboard

explainer = VisionExplainer(
explainers=["gradcam", "lime", "ig", "ce", "feature_visualization"],
mode="classification",
model=model, # An image classification model, e.g., ResNet50
preprocess=preprocess, # The preprocessing function
postprocess=postprocess, # The postprocessing function
params={
# Set the target layer for GradCAM
"gradcam": {"target_layer": model.layer4[-1]},
# Set the objective for feature visualization
"feature_visualization":
{"objectives": [{"layer": model.layer4[-3], "type": "channel", "index": list(range(6))}]}
},
)
# Generate explanations of GradCAM, LIME, IG and CE
local_explanations = explainer.explain(test_img)
# Generate explanations of feature visualization
global_explanations = explainer.explain_global()
# Launch the dashboard
dashboard = Dashboard(
instances=test_img,
local_explanations=local_explanations,
global_explanations=global_explanations
)
dashboard.show()
```

The following figure shows the dashboard of these explanations:
![alt text](https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo_vision.gif)

For NLP tasks and time-series forecasting/anomaly detection, OmniXAI also provides the same interface
to generate and visualize explanations. This figure shows a dashboard example of text classification
and time-series anomaly detection:
![alt text](https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo_nlp_ts.gif)

## How to Contribute

We welcome the contribution from the open-source community to improve the library!
Expand Down
Binary file modified docs/_static/demo.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/demo_nlp_ts.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/demo_vision.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
59 changes: 31 additions & 28 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Prediction metrics Black box Global
PDP Black box Global ✓
ALE Black box Global ✓
Sensitivity analysis Black box Global ✓
Feature visualization Torch or TF Global ✓
LIME Black box Local ✓ ✓ ✓
SHAP Black box* Local ✓ ✓ ✓ ✓
Integrated gradient Torch or TF Local ✓ ✓ ✓
Expand All @@ -61,6 +62,7 @@ Grad-CAM, Grad-CAM++ Torch or TF Local
Learning to explain Black box Local ✓ ✓ ✓
Linear models Linear models Global and Local ✓
Tree models Tree models Global and Local ✓
Feature maps Torch or TF Local ✓
======================= ==================== ================ ============= ======= ======= ======= ==========

*SHAP* accepts black box models for tabular data, PyTorch/Tensorflow models for image data, transformer models
Expand All @@ -73,34 +75,35 @@ Comparison with Competitors
The following table shows the comparison between our toolkit/library and other existing XAI toolkits/libraries
in literature:

========== ==================== ======= =========== ====== ==== ====== ===== ========
Data Type Method OmniXAI InterpretML AIX360 Eli5 Captum Alibi explainX
========== ==================== ======= =========== ====== ==== ====== ===== ========
Tabular LIME ✓ ✓ ✓ ✘ ✓ ✘ ✘
\ SHAP ✓ ✓ ✓ ✘ ✓ ✓ ✓
\ PDP ✓ ✓ ✘ ✘ ✘ ✘ ✘
\ ALE ✓ ✘ ✘ ✘ ✘ ✓ ✘
\ Sensitivity ✓ ✓ ✘ ✘ ✘ ✘ ✘
\ Integrated gradient ✓ ✘ ✘ ✘ ✓ ✓ ✘
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✓ ✘
\ Linear models ✓ ✓ ✓ ✓ ✘ ✓ ✓
\ Tree models ✓ ✓ ✓ ✓ ✘ ✓ ✓
\ L2X ✓ ✘ ✘ ✘ ✘ ✘ ✘
Image LIME ✓ ✘ ✘ ✘ ✓ ✘ ✘
\ SHAP ✓ ✘ ✘ ✘ ✓ ✘ ✘
\ Integrated gradient ✓ ✘ ✘ ✘ ✓ ✓ ✘
\ Grad-CAM, Grad-CAM++ ✓ ✘ ✘ ✓ ✓ ✘ ✘
\ Contrastive ✓ ✘ ✓ ✘ ✘ ✓ ✘
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✓ ✘
\ L2X ✓ ✘ ✘ ✘ ✘ ✘ ✘
Text LIME ✓ ✘ ✘ ✓ ✓ ✘ ✘
\ SHAP ✓ ✘ ✘ ✘ ✓ ✘ ✘
\ Integrated gradient ✓ ✘ ✘ ✘ ✓ ✓ ✘
\ L2X ✓ ✘ ✘ ✘ ✘ ✘ ✘
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✘ ✘
Timeseries SHAP ✓ ✘ ✘ ✘ ✘ ✘ ✘
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✘ ✘
========== ==================== ======= =========== ====== ==== ====== ===== ========
========== ===================== ======= =========== ====== ==== ====== ===== ========
Data Type Method OmniXAI InterpretML AIX360 Eli5 Captum Alibi explainX
========== ===================== ======= =========== ====== ==== ====== ===== ========
Tabular LIME ✓ ✓ ✓ ✘ ✓ ✘ ✘
\ SHAP ✓ ✓ ✓ ✘ ✓ ✓ ✓
\ PDP ✓ ✓ ✘ ✘ ✘ ✘ ✘
\ ALE ✓ ✘ ✘ ✘ ✘ ✓ ✘
\ Sensitivity ✓ ✓ ✘ ✘ ✘ ✘ ✘
\ Integrated gradient ✓ ✘ ✘ ✘ ✓ ✓ ✘
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✓ ✘
\ Linear models ✓ ✓ ✓ ✓ ✘ ✓ ✓
\ Tree models ✓ ✓ ✓ ✓ ✘ ✓ ✓
\ L2X ✓ ✘ ✘ ✘ ✘ ✘ ✘
Image LIME ✓ ✘ ✘ ✘ ✓ ✘ ✘
\ SHAP ✓ ✘ ✘ ✘ ✓ ✘ ✘
\ Integrated gradient ✓ ✘ ✘ ✘ ✓ ✓ ✘
\ Grad-CAM, Grad-CAM++ ✓ ✘ ✘ ✓ ✓ ✘ ✘
\ Contrastive ✓ ✘ ✓ ✘ ✘ ✓ ✘
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✓ ✘
\ L2X ✓ ✘ ✘ ✘ ✘ ✘ ✘
\ Feature visualization ✓ ✘ ✘ ✘ ✘ ✘ ✘
Text LIME ✓ ✘ ✘ ✓ ✓ ✘ ✘
\ SHAP ✓ ✘ ✘ ✘ ✓ ✘ ✘
\ Integrated gradient ✓ ✘ ✘ ✘ ✓ ✓ ✘
\ L2X ✓ ✘ ✘ ✘ ✘ ✘ ✘
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✘ ✘
Timeseries SHAP ✓ ✘ ✘ ✘ ✘ ✘ ✘
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✘ ✘
========== ===================== ======= =========== ====== ==== ====== ===== ========

Installation
############
Expand Down
Loading

0 comments on commit 3164258

Please sign in to comment.