Feature/shap utils #391

Open
wants to merge 21 commits into
base: main
Commits
0c8e945
Optional import of shap package.
Alex6022 Sep 29, 2024
cbe2e82
1st implementation of SHAP utilities in experimental space and with p…
Alex6022 Sep 29, 2024
2597fd4
Implementation option to perform SHAP either in computational or expe…
Alex6022 Oct 1, 2024
bc1203e
SHAP package implementation in diagnostics utility, complete tests an…
Alex6022 Oct 3, 2024
b348b46
Tests for explainer utilities and generalization for all explainers i…
Alex6022 Oct 3, 2024
ae20322
Implemented plotting with non-shap attributions.
Alex6022 Oct 3, 2024
de9d1e9
Refactored diagnostics test and optimized handling of maple explainers.
Alex6022 Oct 4, 2024
e183957
Shortened plotting method names.
Alex6022 Oct 4, 2024
85fb9ba
Merge branch 'emdgroup:main' into feature/shap-utils
Alex6022 Oct 4, 2024
c389ac1
Cleanup for PR
Alex6022 Oct 4, 2024
55e723c
Renamed diangostics package, enabled optional shap import
Alex6022 Oct 23, 2024
1922467
Merge branch 'emdgroup:main' into feature/shap-utils
Alex6022 Oct 23, 2024
ee57008
Refactoring of test_diagnostics.py
Alex6022 Oct 27, 2024
11b61d1
Merge branch 'emdgroup:main' into feature/shap-utils
Alex6022 Oct 28, 2024
08a4c1e
Merge branch 'feature/shap-utils' of https://github.com/Alex6022/bayb…
Alex6022 Oct 28, 2024
50846f7
Merge branch 'feature/shap-utils' of https://github.com/Alex6022/bayb…
Alex6022 Oct 28, 2024
103a5f7
Fixed changelog merging error
Alex6022 Oct 28, 2024
ffda991
Update pyproject.toml
Scienfitz Nov 1, 2024
eaa5c38
Rework import flag
Scienfitz Nov 1, 2024
4ca9ffd
Update mypy.ini
Scienfitz Nov 1, 2024
9fddbdd
Rework tests
Scienfitz Nov 1, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]
### Added
- Added SHAP analysis within the new `diagnostics` package.
Collaborator suggested change:
Replace
- Added SHAP analysis within the new `diagnostics` package.
with
- `diagnostics` dependency group
- SHAP explanations

- `allow_missing` and `allow_extra` keyword arguments to `Objective.transform`

### Deprecations
4 changes: 3 additions & 1 deletion CONTRIBUTORS.md
@@ -25,4 +25,6 @@
- Di Jin (Merck Life Science KGaA, Darmstadt, Germany):\
  Cardinality constraints
- Julian Streibel (Merck Life Science KGaA, Darmstadt, Germany):\
  Bernoulli multi-armed bandit and Thompson sampling
  Bernoulli multi-armed bandit and Thompson sampling
- Alexander Wieczorek (Swiss Federal Institute for Materials Science and Technology, Dübendorf, Switzerland):\
  SHAP explainers for diagnostics
1 change: 1 addition & 0 deletions README.md
@@ -296,6 +296,7 @@ The available groups are:
- `mypy`: Required for static type checking.
- `onnx`: Required for using custom surrogate models in [ONNX format](https://onnx.ai).
- `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/)
- `diagnostics`: Required for feature importance ranking via [SHAP](https://shap.readthedocs.io/)
Collaborator suggested change:
Replace
- `diagnostics`: Required for feature importance ranking via [SHAP](https://shap.readthedocs.io/)
with
- `diagnostics`: Required for built-in model and campaign analysis, e.g. [SHAP](https://shap.readthedocs.io/)

- `simulation`: Enabling the [simulation](https://emdgroup.github.io/baybe/stable/_autosummary/baybe.simulation.html) module.
- `test`: Required for running the tests.
- `dev`: All of the above plus `tox` and `pip-audit`. For code contributors.
16 changes: 16 additions & 0 deletions baybe/_optional/diagnostics.py
@@ -0,0 +1,16 @@
"""Optional import for diagnostics utilities."""

from baybe.exceptions import OptionalImportError

try:
    import shap
except ModuleNotFoundError as ex:
    raise OptionalImportError(
        "Explainer functionality is unavailable because 'diagnostics' is not installed."
        " Consider installing BayBE with 'diagnostics' dependency, e.g. via "
        "`pip install baybe[diagnostics]`."
    ) from ex

__all__ = [
    "shap",
]
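
The guarded import in this file follows a pattern that can be written as a reusable helper. Below is a minimal sketch, not the PR's code: `require_optional` is a hypothetical helper, and the local `OptionalImportError` class is only a stand-in for `baybe.exceptions.OptionalImportError`.

```python
import importlib
from types import ModuleType


class OptionalImportError(ImportError):
    """Stand-in for baybe.exceptions.OptionalImportError (assumption)."""


def require_optional(module_name: str, extra: str) -> ModuleType:
    """Import an optional module or fail with an installation hint.

    `require_optional` is a hypothetical helper used for illustration only.
    """
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError as ex:
        # Chain the original error so the traceback still shows the root cause.
        raise OptionalImportError(
            f"Functionality is unavailable because '{module_name}' is not "
            f"installed. Consider installing via `pip install baybe[{extra}]`."
        ) from ex


# A module that exists is returned as-is (stdlib `json` used as a placeholder).
json_module = require_optional("json", extra="diagnostics")
```

Raising at import time, as the file above does, keeps the failure close to the first use of the optional feature instead of surfacing later as an opaque `NameError`.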
2 changes: 2 additions & 0 deletions baybe/_optional/info.py
@@ -28,6 +28,7 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404
MORDRED_INSTALLED = find_spec("mordred") is not None
ONNX_INSTALLED = find_spec("onnxruntime") is not None
POLARS_INSTALLED = find_spec("polars") is not None
SHAP_INSTALLED = find_spec("shap") is not None
PRE_COMMIT_INSTALLED = find_spec("pre_commit") is not None
PYDOCLINT_INSTALLED = find_spec("pydoclint") is not None
RDKIT_INSTALLED = find_spec("rdkit") is not None
@@ -45,6 +46,7 @@ def exclude_sys_path(path: str, /): # noqa: DOC402, DOC404

# Package combinations
CHEM_INSTALLED = MORDRED_INSTALLED and RDKIT_INSTALLED
DIAGNOSTICS_INSTALLED = SHAP_INSTALLED
LINT_INSTALLED = all(
    (
        FLAKE8_INSTALLED,
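
The availability flags in this file rely on `importlib.util.find_spec`, which locates a top-level package without importing it. A small sketch of that idea follows; `is_installed` and the flag names are hypothetical, and the package names are placeholders rather than BayBE dependencies.

```python
from importlib.util import find_spec


def is_installed(package: str) -> bool:
    """Check for a top-level package without importing it."""
    return find_spec(package) is not None


# Individual flags (package names are only examples).
JSON_INSTALLED = is_installed("json")  # stdlib, always present
MISSING_INSTALLED = is_installed("surely_missing_pkg_xyz")  # made-up name

# A group flag is true only if every member is available.
GROUP_INSTALLED = all((JSON_INSTALLED, MISSING_INSTALLED))
```

Because nothing is imported, such flags are cheap to evaluate at module load and can be used to decide whether dependent tests or features should be enabled.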
194 changes: 194 additions & 0 deletions baybe/utils/diagnostics.py
@@ -0,0 +1,194 @@
"""Diagnostics utilities."""

import numbers
import warnings

import numpy as np
import pandas as pd

from baybe import Campaign
from baybe._optional.diagnostics import shap
from baybe.utils.dataframe import to_tensor


def explainer(
    campaign: Campaign,
    explainer_class: shap.Explainer = shap.KernelExplainer,
    computational_representation: bool = False,
    **kwargs,
) -> shap.Explainer:
    """Create an explainer for the provided campaign.

    Args:
        campaign: The campaign to be explained.
        explainer_class: The explainer to be used. Default is shap.KernelExplainer.
        computational_representation: Whether to compute the Shapley values
            in computational or experimental searchspace.
            Default is False.
        **kwargs: Additional keyword arguments to be passed to the explainer.

    Returns:
        The explainer for the provided campaign.

    Raises:
        ValueError: If no measurements have been provided yet.
    """
    if campaign.measurements.empty:
        raise ValueError("No measurements have been provided yet.")

    data = campaign.measurements[[p.name for p in campaign.parameters]].copy()

    if computational_representation:
        data = campaign.searchspace.transform(data)

        def model(x):
            tensor = to_tensor(x)
            output = campaign.get_surrogate()._posterior_comp(tensor).mean

            return output.detach().numpy()
    else:

        def model(x):
            df = pd.DataFrame(x, columns=data.columns)
            output = campaign.get_surrogate().posterior(df).mean

            return output.detach().numpy()

    shap_explainer = explainer_class(model, data, **kwargs)
    return shap_explainer
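
To illustrate what an explainer built from a `model` callable and background `data` estimates, here is a minimal pure-Python sketch of exact Shapley values with background averaging. It is not shap's `KernelExplainer`, which estimates these values via weighted regression over coalitions. `shapley_values` and `toy_model` are hypothetical names, and the toy model merely stands in for the surrogate's posterior mean.

```python
from itertools import combinations
from math import factorial
from typing import Callable, Sequence

Row = Sequence[float]


def shapley_values(
    model: Callable[[list[list[float]]], list[float]],
    background: Sequence[Row],
    x: Row,
) -> list[float]:
    """Exact Shapley values of `model` at `x`, averaging over background rows."""
    n = len(x)

    def value(subset: frozenset) -> float:
        # Features in `subset` come from x, all others from each background row.
        rows = [
            [x[j] if j in subset else b[j] for j in range(n)] for b in background
        ]
        preds = model(rows)
        return sum(preds) / len(preds)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            # Shapley weight of a coalition with `size` members.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for combo in combinations(others, size):
                s = frozenset(combo)
                total += weight * (value(s | {i}) - value(s))
        phi.append(total)
    return phi


def toy_model(rows: list[list[float]]) -> list[float]:
    # Hypothetical stand-in for the surrogate's posterior mean.
    return [2.0 * r[0] + 3.0 * r[1] for r in rows]


background = [[0.0, 0.0], [2.0, 4.0]]
phi = shapley_values(toy_model, background, [3.0, 1.0])
```

The enumeration is exponential in the number of features, which is why practical explainers sample or approximate. The attributions sum to the difference between the prediction at `x` and the mean prediction over the background.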


def explanation(
    campaign: Campaign,
    data: np.ndarray = None,
    explainer_class: shap.Explainer = shap.KernelExplainer,
    computational_representation: bool = False,
    **kwargs,
) -> shap.Explanation:
    """Compute the Shapley values for the provided campaign and data.

    Args:
        campaign: The campaign to be explained.
        data: The data to be explained.
            Default is None which uses the campaign's measurements.
        explainer_class: The explainer to be used.
            Default is shap.KernelExplainer.
        computational_representation: Whether to compute the Shapley values
            in computational or experimental searchspace.
            Default is False.
        **kwargs: Additional keyword arguments to be passed to the explainer.

    Returns:
        The Shapley values for the provided campaign.

    Raises:
        ValueError: If the provided data does not have the same amount of parameters
            as previously provided to the explainer.
    """
    is_shap_explainer = not explainer_class.__module__.startswith(
        "shap.explainers.other."
    )

    if not is_shap_explainer and not computational_representation:
        raise ValueError(
            "Experimental representation is not "
            "supported for non-Kernel SHAP explainer."
        )

    explainer_obj = explainer(
        campaign,
        explainer_class=explainer_class,
        computational_representation=computational_representation,
        **kwargs,
    )

    if data is None:
        if isinstance(explainer_obj.data, np.ndarray):
            data = explainer_obj.data
        else:
            data = explainer_obj.data.data
    elif computational_representation:
        data = campaign.searchspace.transform(data)

    if not is_shap_explainer:
        """Return attributions for non-SHAP explainers."""
        if explainer_class.__module__.endswith("maple"):
            """Additional argument for maple to increase comparability to SHAP."""
            attributions = explainer_obj.attributions(data, multiply_by_input=True)[0]
        else:
            attributions = explainer_obj.attributions(data)[0]
        if computational_representation:
            feature_names = campaign.searchspace.comp_rep_columns
        else:
            feature_names = campaign.searchspace.parameter_names
        explanations = shap.Explanation(
            values=attributions,
            base_values=data,
            data=data,
        )
        explanations.feature_names = list(feature_names)
        return explanations

    if data.shape[1] != explainer_obj.data.data.shape[1]:
        raise ValueError(
            "The provided data does not have the same amount "
            "of parameters as the shap explainer background."
        )
    else:
        shap_explanations = explainer_obj(data)[:, :, 0]

    return shap_explanations
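
The function above decides how to treat an explainer class by inspecting its `__module__` string. The sketch below isolates that check so it can be exercised without shap installed. `is_native_shap_explainer`, `validate_representation`, and the dummy classes are hypothetical, the module paths given to the dummies are assumptions about shap's layout, and the error text is reworded for the sketch.

```python
def is_native_shap_explainer(explainer_class: type) -> bool:
    """True unless the class lives under `shap.explainers.other`."""
    return not explainer_class.__module__.startswith("shap.explainers.other.")


def validate_representation(
    explainer_class: type, computational_representation: bool
) -> None:
    """Reject the experimental representation for non-native explainers."""
    if (
        not is_native_shap_explainer(explainer_class)
        and not computational_representation
    ):
        raise ValueError(
            "Experimental representation is only supported for explainers "
            "outside `shap.explainers.other`."
        )


# Dummy classes with assumed module paths, used instead of real shap classes.
FakeKernel = type("FakeKernel", (), {"__module__": "shap.explainers._kernel"})
FakeMaple = type("FakeMaple", (), {"__module__": "shap.explainers.other._maple"})

validate_representation(FakeKernel, computational_representation=False)  # passes
validate_representation(FakeMaple, computational_representation=True)  # passes
```

Matching on module strings is brittle if the library reorganizes its packages, so a check of this kind is worth covering with a test that uses the real classes.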


def plot_beeswarm(explanation: shap.Explanation, **kwargs) -> None:
    """Plot the Shapley values using a beeswarm plot."""
    shap.plots.beeswarm(explanation, **kwargs)


def plot_waterfall(explanation: shap.Explanation, **kwargs) -> None:
    """Plot the Shapley values using a waterfall plot."""
    shap.plots.waterfall(explanation, **kwargs)


def plot_bar(explanation: shap.Explanation, **kwargs) -> None:
    """Plot the Shapley values using a bar plot."""
    shap.plots.bar(explanation, **kwargs)


def plot_scatter(explanation: shap.Explanation | memoryview, **kwargs) -> None:
    """Plot the Shapley values using a scatter plot while leaving out string values.

    Args:
        explanation: The Shapley values to be plotted.
        **kwargs: Additional keyword arguments to be passed to the scatter plot.

    Raises:
        ValueError: If the provided explanation object does not match the
            required types.
    """
    if isinstance(explanation, memoryview):
        data = explanation.obj
    elif isinstance(explanation, shap.Explanation):
        data = explanation.data.data.obj
    else:
        raise ValueError("The provided explanation argument is not of a valid type.")

    def is_not_numeric_column(col):
        return np.array([not isinstance(v, numbers.Number) for v in col]).any()

    if data.ndim == 1:
        if is_not_numeric_column(data):
            warnings.warn(
                "Cannot plot scatter plot for the provided "
                "explanation as it contains non-numeric values."
            )
        else:
            shap.plots.scatter(explanation, **kwargs)
    else:
        number_enum = [i for i, x in enumerate(data[1]) if not isinstance(x, str)]
        if len(number_enum) < len(explanation.feature_names):
            warnings.warn(
                "Cannot plot SHAP scatter plot for all "
                "parameters as some contain non-numeric values."
            )
        shap.plots.scatter(explanation[:, number_enum], **kwargs)
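
The column filtering in `plot_scatter` can be illustrated without shap. The function above inspects a single row and excludes strings. The sketch below is a variation, not the PR's logic: it checks every row and keeps only columns whose entries are all numeric. `numeric_column_indices` is a hypothetical helper.

```python
import numbers
from typing import Sequence


def numeric_column_indices(rows: Sequence[Sequence[object]]) -> list[int]:
    """Return the indices of columns that are numeric in every row."""
    if not rows:
        return []
    n_cols = len(rows[0])
    return [
        j
        for j in range(n_cols)
        if all(isinstance(row[j], numbers.Number) for row in rows)
    ]


# Mixed table: column 1 holds categorical labels and should be skipped.
rows = [
    [1.0, "A", 3],
    [2.5, "B", 4],
]
numeric_cols = numeric_column_indices(rows)
```

Checking all rows avoids relying on one row being representative, at the cost of a full pass over the data.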
3 changes: 3 additions & 0 deletions mypy.ini
@@ -64,3 +64,6 @@ ignore_missing_imports = True

[mypy-polars]
ignore_missing_imports = True

[mypy-shap.*]
ignore_missing_imports = True
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -81,6 +81,7 @@ onnx = [

dev = [
    "baybe[chem]",
    "baybe[diagnostics]",
    "baybe[docs]",
    "baybe[examples]",
    "baybe[lint]",
@@ -94,6 +95,11 @@ dev = [
    "uv>=0.3.0", # `uv lock` (for lockfiles) is stable since 0.3.0: https://github.com/astral-sh/uv/issues/2679#event-13950215962
]

diagnostics = [
    "shap>=0.46.0",
    "lime>=0.2.0.1"
]

docs = [
    "baybe[examples]", # docs cannot be built without running examples
    "furo>=2023.09.10",
2 changes: 2 additions & 0 deletions pytest.ini
@@ -10,6 +10,8 @@ addopts =
    --ignore=baybe/_optional
    --ignore=baybe/utils/chemistry.py
    --ignore=tests/simulate_telemetry.py
    --ignore=baybe/utils/diagnostics.py
    --ignore=tests/utils/test_diagnostics.py
Collaborator commented:
Seems you're ignoring the tests you created? Is this on purpose?

testpaths =
    baybe
    tests