-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/shap utils #391
Open
Alex6022
wants to merge
21
commits into
emdgroup:main
Choose a base branch
from
Alex6022:feature/shap-utils
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+360
−2
Open
Feature/shap utils #391
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
0c8e945
Optional import of shap package.
Alex6022 cbe2e82
1st implementation of SHAP utilities in experimental space and with p…
Alex6022 2597fd4
Implementation option to perform SHAP either in computational or expe…
Alex6022 bc1203e
SHAP package implementation in diagnostics utility, complete tests an…
Alex6022 b348b46
Tests for explainer utilities and generalization for all explainers i…
Alex6022 ae20322
Implemented plotting with non-shap attributions.
Alex6022 de9d1e9
Refactored diagnostics test and optimized handling of maple explainers.
Alex6022 e183957
Shortened plotting method names.
Alex6022 85fb9ba
Merge branch 'emdgroup:main' into feature/shap-utils
Alex6022 c389ac1
Cleanup for PR
Alex6022 55e723c
Renamed diangostics package, enabled optional shap import
Alex6022 1922467
Merge branch 'emdgroup:main' into feature/shap-utils
Alex6022 ee57008
Refactoring of test_diagnostics.py
Alex6022 11b61d1
Merge branch 'emdgroup:main' into feature/shap-utils
Alex6022 08a4c1e
Merge branch 'feature/shap-utils' of https://github.com/Alex6022/bayb…
Alex6022 50846f7
Merge branch 'feature/shap-utils' of https://github.com/Alex6022/bayb…
Alex6022 103a5f7
Fixed changelog merging error
Alex6022 ffda991
Update pyproject.toml
Scienfitz eaa5c38
Rework import flag
Scienfitz 4ca9ffd
Update mypy.ini
Scienfitz 9fddbdd
Rework tests
Scienfitz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -296,6 +296,7 @@ The available groups are: | |||||
- `mypy`: Required for static type checking. | ||||||
- `onnx`: Required for using custom surrogate models in [ONNX format](https://onnx.ai). | ||||||
- `polars`: Required for optimized search space construction via [Polars](https://docs.pola.rs/) | ||||||
- `diagnostics`: Required for feature importance ranking via [SHAP](https://shap.readthedocs.io/) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
- `simulation`: Enabling the [simulation](https://emdgroup.github.io/baybe/stable/_autosummary/baybe.simulation.html) module. | ||||||
- `test`: Required for running the tests. | ||||||
- `dev`: All of the above plus `tox` and `pip-audit`. For code contributors. | ||||||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
"""Optional import for diagnostics utilities.""" | ||
|
||
from baybe.exceptions import OptionalImportError | ||
|
||
try: | ||
import shap | ||
except ModuleNotFoundError as ex: | ||
raise OptionalImportError( | ||
"Explainer functionality is unavailable because 'diagnostics' is not installed." | ||
" Consider installing BayBE with 'diagnostics' dependency, e.g. via " | ||
"`pip install baybe[diagnostics]`." | ||
) from ex | ||
|
||
__all__ = [ | ||
"shap", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
"""Diagnostics utilities.""" | ||
|
||
import numbers | ||
import warnings | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from baybe import Campaign | ||
from baybe._optional.diagnostics import shap | ||
from baybe.utils.dataframe import to_tensor | ||
|
||
|
||
def explainer( | ||
campaign: Campaign, | ||
explainer_class: shap.Explainer = shap.KernelExplainer, | ||
computational_representation: bool = False, | ||
**kwargs, | ||
) -> shap.Explainer: | ||
"""Create an explainer for the provided campaign. | ||
|
||
Args: | ||
campaign: The campaign to be explained. | ||
explainer_class: The explainer to be used. Default is shap.KernelExplainer. | ||
computational_representation: Whether to compute the Shapley values | ||
in computational or experimental searchspace. | ||
Default is False. | ||
**kwargs: Additional keyword arguments to be passed to the explainer. | ||
|
||
Returns: | ||
The explainer for the provided campaign. | ||
|
||
Raises: | ||
ValueError: If no measurements have been provided yet. | ||
""" | ||
if campaign.measurements.empty: | ||
raise ValueError("No measurements have been provided yet.") | ||
|
||
data = campaign.measurements[[p.name for p in campaign.parameters]].copy() | ||
|
||
if computational_representation: | ||
data = campaign.searchspace.transform(data) | ||
|
||
def model(x): | ||
tensor = to_tensor(x) | ||
output = campaign.get_surrogate()._posterior_comp(tensor).mean | ||
|
||
return output.detach().numpy() | ||
else: | ||
|
||
def model(x): | ||
df = pd.DataFrame(x, columns=data.columns) | ||
output = campaign.get_surrogate().posterior(df).mean | ||
|
||
return output.detach().numpy() | ||
|
||
shap_explainer = explainer_class(model, data, **kwargs) | ||
return shap_explainer | ||
|
||
|
||
def explanation( | ||
campaign: Campaign, | ||
data: np.ndarray = None, | ||
explainer_class: shap.Explainer = shap.KernelExplainer, | ||
computational_representation: bool = False, | ||
**kwargs, | ||
) -> shap.Explanation: | ||
"""Compute the Shapley values for the provided campaign and data. | ||
|
||
Args: | ||
campaign: The campaign to be explained. | ||
data: The data to be explained. | ||
Default is None which uses the campaign's measurements. | ||
explainer_class: The explainer to be used. | ||
Default is shap.KernelExplainer. | ||
computational_representation: Whether to compute the Shapley values | ||
in computational or experimental searchspace. | ||
Default is False. | ||
**kwargs: Additional keyword arguments to be passed to the explainer. | ||
|
||
Returns: | ||
The Shapley values for the provided campaign. | ||
|
||
Raises: | ||
ValueError: If the provided data does not have the same amount of parameters | ||
as previously provided to the explainer. | ||
""" | ||
is_shap_explainer = not explainer_class.__module__.startswith( | ||
"shap.explainers.other." | ||
) | ||
|
||
if not is_shap_explainer and not computational_representation: | ||
raise ValueError( | ||
"Experimental representation is not " | ||
"supported for non-Kernel SHAP explainer." | ||
) | ||
|
||
explainer_obj = explainer( | ||
campaign, | ||
explainer_class=explainer_class, | ||
computational_representation=computational_representation, | ||
**kwargs, | ||
) | ||
|
||
if data is None: | ||
if isinstance(explainer_obj.data, np.ndarray): | ||
data = explainer_obj.data | ||
else: | ||
data = explainer_obj.data.data | ||
elif computational_representation: | ||
data = campaign.searchspace.transform(data) | ||
|
||
if not is_shap_explainer: | ||
"""Return attributions for non-SHAP explainers.""" | ||
if explainer_class.__module__.endswith("maple"): | ||
"""Additional argument for maple to increase comparability to SHAP.""" | ||
attributions = explainer_obj.attributions(data, multiply_by_input=True)[0] | ||
else: | ||
attributions = explainer_obj.attributions(data)[0] | ||
if computational_representation: | ||
feature_names = campaign.searchspace.comp_rep_columns | ||
else: | ||
feature_names = campaign.searchspace.parameter_names | ||
explanations = shap.Explanation( | ||
values=attributions, | ||
base_values=data, | ||
data=data, | ||
) | ||
explanations.feature_names = list(feature_names) | ||
return explanations | ||
|
||
if data.shape[1] != explainer_obj.data.data.shape[1]: | ||
raise ValueError( | ||
"The provided data does not have the same amount " | ||
"of parameters as the shap explainer background." | ||
) | ||
else: | ||
shap_explanations = explainer_obj(data)[:, :, 0] | ||
|
||
return shap_explanations | ||
|
||
|
||
def plot_beeswarm(explanation: shap.Explanation, **kwargs) -> None: | ||
"""Plot the Shapley values using a beeswarm plot.""" | ||
shap.plots.beeswarm(explanation, **kwargs) | ||
|
||
|
||
def plot_waterfall(explanation: shap.Explanation, **kwargs) -> None: | ||
"""Plot the Shapley values using a waterfall plot.""" | ||
shap.plots.waterfall(explanation, **kwargs) | ||
|
||
|
||
def plot_bar(explanation: shap.Explanation, **kwargs) -> None: | ||
"""Plot the Shapley values using a bar plot.""" | ||
shap.plots.bar(explanation, **kwargs) | ||
|
||
|
||
def plot_scatter(explanation: shap.Explanation | memoryview, **kwargs) -> None: | ||
"""Plot the Shapley values using a scatter plot while leaving out string values. | ||
|
||
Args: | ||
explanation: The Shapley values to be plotted. | ||
**kwargs: Additional keyword arguments to be passed to the scatter plot. | ||
|
||
Raises: | ||
ValueError: If the provided explanation object does not match the | ||
required types. | ||
""" | ||
if isinstance(explanation, memoryview): | ||
data = explanation.obj | ||
elif isinstance(explanation, shap.Explanation): | ||
data = explanation.data.data.obj | ||
else: | ||
raise ValueError("The provided explanation argument is not of a valid type.") | ||
|
||
def is_not_numeric_column(col): | ||
return np.array([not isinstance(v, numbers.Number) for v in col]).any() | ||
|
||
if data.ndim == 1: | ||
if is_not_numeric_column(data): | ||
warnings.warn( | ||
"Cannot plot scatter plot for the provided " | ||
"explanation as it contains non-numeric values." | ||
) | ||
else: | ||
shap.plots.scatter(explanation, **kwargs) | ||
else: | ||
number_enum = [i for i, x in enumerate(data[1]) if not isinstance(x, str)] | ||
if len(number_enum) < len(explanation.feature_names): | ||
warnings.warn( | ||
"Cannot plot SHAP scatter plot for all " | ||
"parameters as some contain non-numeric values." | ||
) | ||
shap.plots.scatter(explanation[:, number_enum], **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,8 @@ addopts = | |
--ignore=baybe/_optional | ||
--ignore=baybe/utils/chemistry.py | ||
--ignore=tests/simulate_telemetry.py | ||
--ignore=baybe/utils/diagnostics.py | ||
--ignore=tests/utils/test_diagnostics.py | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. seems youre ignoring the tests you created? is this on purpose? |
||
testpaths = | ||
baybe | ||
tests |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.