Merge pull request #46 from salesforce/deployment
Deployment
yangwenz authored Oct 25, 2022
2 parents 2adc636 + a5af847 commit 72989c2
Showing 88 changed files with 1,835 additions and 126 deletions.
95 changes: 92 additions & 3 deletions README.md
@@ -29,9 +29,10 @@
3. [Getting Started](#getting-started)
4. [Documentation](https://opensource.salesforce.com/OmniXAI/latest/index.html)
5. [Tutorials](https://opensource.salesforce.com/OmniXAI/latest/tutorials.html)
6. [Dashboard Demo](https://sfr-omnixai-demo.herokuapp.com/)
7. [How to Contribute](https://opensource.salesforce.com/OmniXAI/latest/omnixai.html#how-to-contribute)
8. [Technical Report and Citing OmniXAI](#technical-report-and-citing-omnixai)
6. [Deployment](#deployment)
7. [Dashboard Demo](https://sfr-omnixai-demo.herokuapp.com/)
8. [How to Contribute](https://opensource.salesforce.com/OmniXAI/latest/omnixai.html#how-to-contribute)
9. [Technical Report and Citing OmniXAI](#technical-report-and-citing-omnixai)


## Introduction
@@ -319,6 +320,94 @@ to generate and visualize explanations. This figure shows a dashboard example of
and time-series anomaly detection:
![alt text](https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo_nlp_ts.gif)

## Deployment

The explainers in OmniXAI can be easily deployed via [BentoML](https://github.com/bentoml/BentoML).
BentoML is a popular open-source model serving framework that supports deployment on multiple platforms,
including AWS, GCP, and Heroku. We implemented BentoML-format interfaces for OmniXAI, so users need
only a few lines of code to deploy their selected explainers.

Let's take the income prediction task as an example. Given the trained model and the initialized explainer,
you only need to save the explainer to the local BentoML model store:

```python
from omnixai.explainers.tabular import TabularExplainer
from omnixai.deployment.bentoml.omnixai import save_model

explainer = TabularExplainer(
explainers=["lime", "shap", "mace", "pdp", "ale"],
mode="classification",
data=train_data,
model=model,
preprocess=lambda z: transformer.transform(z),
params={
"mace": {"ignored_features": ["Sex", "Race", "Relationship", "Capital Loss"]}
}
)
save_model("tabular_explainer", explainer)
```
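BentoML stores each saved model under a versioned tag in its local model store, so the newest saved
version of `tabular_explainer` can be referenced as `tabular_explainer:latest`, as done below.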

Then create a file (e.g., `service.py`) for the ML service code:

```python
from omnixai.deployment.bentoml.omnixai import init_service

svc = init_service(
model_tag="tabular_explainer:latest",
task_type="tabular",
service_name="tabular_explainer"
)
```

The `init_service` function defines two API endpoints, i.e., `/predict` for model predictions and `/explain` for
generating explanations. You can start an API server locally to test the service code above:

```bash
bentoml serve service:svc --reload
```
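By default, the BentoML server listens on port 3000, which is why the examples below send requests
to `http://0.0.0.0:3000`.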

The endpoints can then be accessed locally, e.g., with the `requests` and `requests_toolbelt` packages:

```python
import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder

data = '["39", "State-gov", "77516", "Bachelors", "13", "Never-married", ' \
'"Adm-clerical", "Not-in-family", "White", "Male", "2174", "0", "40", "United-States"]'

# Test the prediction endpoint
prediction = requests.post(
"http://0.0.0.0:3000/predict",
headers={"content-type": "application/json"},
data=data
).text

# Test the explanation endpoint
m = MultipartEncoder(
fields={
"data": data,
"params": '{"lime": {"y": [0]}}',
}
)
result = requests.post(
"http://0.0.0.0:3000/explain",
headers={"Content-Type": m.content_type},
data=m
).text

# Parse the results
from omnixai.explainers.base import AutoExplainerBase
exp = AutoExplainerBase.parse_explanations_from_json(result)
for name, explanation in exp.items():
explanation.ipython_plot()
```
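The `params` field forwards explainer-specific arguments to `explain`; for example,
`{"lime": {"y": [0]}}` above passes `y=[0]` (the target label to explain) to the LIME explainer.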

You can build Bento for deployment by following the steps shown in the
[BentoML repo](https://github.com/bentoml/BentoML#how-it-works). For more examples, please
check [Tabular](https://github.com/salesforce/OmniXAI/tree/main/omnixai/tests/deployment/bentoml/tabular),
[Vision](https://github.com/salesforce/OmniXAI/tree/main/omnixai/tests/deployment/bentoml/vision),
[NLP](https://github.com/salesforce/OmniXAI/tree/main/omnixai/tests/deployment/bentoml/nlp).
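Building a Bento requires a `bentofile.yaml` next to `service.py`. As a minimal sketch (not taken
from this repo; adjust the included files and packages to your project):

```yaml
service: "service:svc"   # module "service", Service object "svc"
include:
  - "service.py"         # files to bundle into the Bento
python:
  packages:              # runtime dependencies installed in the Bento
    - omnixai
    - bentoml
```

Running `bentoml build` then creates the Bento, and `bentoml containerize` packages it into a Docker image.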

## How to Contribute

We welcome contributions from the open-source community to improve the library!
Binary file modified docs/_static/demo_vision.gif
17 changes: 17 additions & 0 deletions docs/omnixai.deployment.rst
@@ -0,0 +1,17 @@
omnixai.deployment package
==========================

.. automodule:: omnixai.deployment
:members:
:undoc-members:
:show-inheritance:

This package is used to deploy OmniXAI explainers via BentoML.

omnixai.deployment.bentoml module
---------------------------------

.. automodule:: omnixai.deployment.bentoml
:members:
:undoc-members:
:show-inheritance:
8 changes: 8 additions & 0 deletions docs/omnixai.rst
@@ -464,3 +464,11 @@ Dashboard for Visualization
:maxdepth: 4

omnixai.visualization

Explainer Deployment
--------------------

.. toctree::
:maxdepth: 4

omnixai.deployment
6 changes: 6 additions & 0 deletions omnixai/deployment/__init__.py
@@ -0,0 +1,6 @@
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
6 changes: 6 additions & 0 deletions omnixai/deployment/bentoml/__init__.py
@@ -0,0 +1,6 @@
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
222 changes: 222 additions & 0 deletions omnixai/deployment/bentoml/omnixai.py
@@ -0,0 +1,222 @@
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
from __future__ import annotations

import json
import os.path
import typing as t
from types import ModuleType

import bentoml
from bentoml import Tag
from bentoml.models import Model
from bentoml.models import ModelContext
from bentoml.exceptions import NotFound
from bentoml.exceptions import BentoMLException
from omnixai.explainers.base import AutoExplainerBase
from omnixai.utils.misc import get_pkg_version
from omnixai.explanations.utils import DefaultJsonEncoder

MODULE_NAME = "omnixai.deployment.bentoml.omnixai"
MODEL_PATH = "saved_model"
API_VERSION = "v1"


def get(tag_like: str | Tag) -> Model:
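    """
    Fetch a model from the local BentoML model store and verify that it was saved by this module.
    """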
model = bentoml.models.get(tag_like)
if model.info.module not in (MODULE_NAME, __name__):
raise NotFound(
f"Model {model.tag} was saved with module {model.info.module}, not loading with {MODULE_NAME}."
)
return model


def load_model(
bento_model: str | Tag | Model,
) -> AutoExplainerBase:
"""
Load the OmniXAI explainer with the given tag from the local BentoML model store.
:param bento_model: Either the tag of the model to get from the store, or a BentoML `~bentoml.Model`
instance to load the model from.
:return: The dumped OmniXAI explainer, e.g., ``TabularExplainer``, ``VisionExplainer``.
"""
if not isinstance(bento_model, Model):
bento_model = get(bento_model)

if bento_model.info.module not in (MODULE_NAME, __name__):
raise BentoMLException(
f"Model {bento_model.tag} was saved with module {bento_model.info.module}, not loading with {MODULE_NAME}."
)
model_path = bento_model.path_of(MODEL_PATH)
return AutoExplainerBase.load(model_path)


def save_model(
name: str,
model: AutoExplainerBase,
*,
mode: str = "model_and_data",
    signatures: t.Dict | None = None,
labels: t.Dict[str, str] | None = None,
custom_objects: t.Dict[str, t.Any] | None = None,
external_modules: t.List[ModuleType] | None = None,
metadata: t.Dict[str, t.Any] | None = None,
) -> bentoml.Model:
"""
    Save an OmniXAI explainer to the BentoML model store.
    :param name: The name for the given model instance; it should pass the Python identifier check.
    :param model: The explainer to be saved.
    :param mode: How to save the explainers, either "model_and_data" or "individual".
        "model_and_data" saves the model, the preprocessing function, the postprocessing function
        and the dataset used for initialization, while "individual" saves each initialized explainer
        in the AutoExplainer. When no explainer needs to train a post-hoc explanation
        model (e.g., L2X) and the initialization dataset is not too large, "model_and_data"
        is the better option; otherwise, use "individual".
:param signatures: The methods to expose for running inference on the target model.
:param labels: User-defined labels for managing models, e.g. team=nlp, stage=dev.
:param custom_objects: User-defined additional python objects to be saved alongside the model.
:param external_modules: User-defined additional python modules to be saved alongside the model or custom objects.
:param metadata: Custom metadata for the given explainer.
    :return: A BentoML model.
    """
"""
if not isinstance(model, AutoExplainerBase):
raise TypeError(
f"Given model ({model}) is not an instance of ``AutoExplainerBase``."
)

context = ModelContext(
framework_name="omnixai",
framework_versions={"omnixai": get_pkg_version("omnixai")},
)

    if signatures is None:
        # Default signatures: `predict` supports batching, `explain` does not.
        signatures = {
            "predict": {"batchable": True},
            "explain": {"batchable": False},
        }

with bentoml.models.create(
name,
module=MODULE_NAME,
api_version=API_VERSION,
labels=labels,
custom_objects=custom_objects,
external_modules=external_modules,
metadata=metadata,
context=context,
signatures=signatures,
) as bento_model:
model.save(bento_model.path_of(MODEL_PATH), mode=mode)
return bento_model


def get_runnable(bento_model: Model):
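    """
    Create a BentoML Runnable class that wraps the saved OmniXAI explainer and exposes
    each saved model signature (e.g., ``predict`` and ``explain``) as a runnable method.
    """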

class OmniXAIRunnable(bentoml.Runnable):
SUPPORTED_RESOURCES = ("cpu", "nvidia.com/gpu")
SUPPORTS_CPU_MULTI_THREADING = True

def __init__(self):
super().__init__()
self.model = load_model(bento_model)

    # Define the runnable methods via a factory so that each closure captures its own `method_name`.
    def add_runnable_method(method_name, options):
def _run(self, *args, **kwargs):
results = getattr(self.model, method_name)(*args, **kwargs)
return results

OmniXAIRunnable.add_method(
_run,
name=method_name,
batchable=options.batchable,
batch_dim=options.batch_dim,
input_spec=options.input_spec,
output_spec=options.output_spec,
)

for method_name, options in bento_model.info.signatures.items():
add_runnable_method(method_name, options)

return OmniXAIRunnable


def init_service(
model_tag,
task_type,
service_name,
api_name=None,
api_doc=None,
api_route=None
):
"""
:param model_tag: The tag of the model.
:param task_type: The task type, e.g., tabular, vision or nlp.
:param service_name: The name of the service.
:param api_name: The base api name. For example, if `api_name = 'abc'`,
the name of the "predict" and "explain" apis will be "abc_predict" and "abc_explain", respectively.
:param api_doc: The api docs. It can either be a string or a tuple (list) of length 2. If it is a tuple (list),
the first element will be the doc of the "predict" api, and the second one will be the doc of the "explain" api.
    :param api_route: The base api route. For example, if `api_route = '/abc'`,
        the routes of the "predict" and "explain" apis will be "/abc/predict" and "/abc/explain", respectively.
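    :return: A BentoML ``Service`` exposing the "predict" and "explain" endpoints.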
"""
from bentoml.io import NumpyNdarray, JSON, Text, Multipart
    assert task_type in ["tabular", "vision", "nlp"], \
        f"`task_type` should be 'tabular', 'vision' or 'nlp', not {task_type}."

model = get(model_tag)
runner = model.to_runner()
svc = bentoml.Service(service_name, runners=[runner])

if api_doc is None:
predict_doc, explain_doc = None, None
elif isinstance(api_doc, str):
predict_doc, explain_doc = api_doc, api_doc
elif isinstance(api_doc, (list, tuple)):
        assert len(api_doc) == 2, \
            f"`api_doc` is a {type(api_doc)} with {len(api_doc)} element(s), " \
            f"but it must have exactly 2 elements."
predict_doc, explain_doc = api_doc[0], api_doc[1]
else:
raise TypeError(f"`api_doc` should be str, list or tuple instead of {type(api_doc)}.")

    if task_type in ("tabular", "vision"):
        predict_input_spec = NumpyNdarray()
        explain_input_spec = Multipart(data=NumpyNdarray(), params=JSON())
elif task_type == "nlp":
predict_input_spec = Text()
explain_input_spec = Multipart(data=Text(), params=JSON())
else:
raise ValueError(f"Unknown `task_type`: {task_type}")

@svc.api(
input=predict_input_spec,
output=Text(),
name=f"{api_name}_predict" if api_name is not None else None,
doc=predict_doc,
route=os.path.join(api_route, "predict") if api_route is not None else None
)
def predict(data):
result = runner.predict.run(data)
return result.to_json()

@svc.api(
input=explain_input_spec,
output=Text(),
name=f"{api_name}_explain" if api_name is not None else None,
doc=explain_doc,
route=os.path.join(api_route, "explain") if api_route is not None else None
)
def explain(data, params):
result = runner.explain.run(data, params, run_predict=False)
return json.dumps(result, cls=DefaultJsonEncoder)

return svc