Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return interpret data #758

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions bambi/interpret/effects.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# pylint: disable=consider-iterating-dictionary
# pylint: disable=too-many-instance-attributes
# pylint: disable=ungrouped-imports
from dataclasses import dataclass, field
import itertools
Expand All @@ -15,6 +17,7 @@
average_over,
ConditionalInfo,
enforce_dtypes,
get_posterior,
identity,
merge,
VariableInfo,
Expand Down Expand Up @@ -122,8 +125,6 @@ def __post_init__(self):
).flatten()


# pylint: disable=consider-iterating-dictionary
# pylint: disable=too-many-instance-attributes
@dataclass
class PredictiveDifferences:
"""Computes predictive differences and their uncertainty intervals for
Expand Down Expand Up @@ -439,6 +440,7 @@ def predictions(
prob=None,
transforms=None,
sample_new_groups=False,
return_idata: bool = False,
) -> pd.DataFrame:
"""Compute Conditional Adjusted Predictions

Expand Down Expand Up @@ -473,11 +475,16 @@ def predictions(
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
return_idata : bool, optional
Copy link
Collaborator

@tomicapretto tomicapretto Nov 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use a different name for this? From the name, I assumed the returned object would be an instance of InferenceData, but I see it's a data frame. Maybe return_data? Or, are you thinking we would be able to return an InferenceData instance in the future?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe "return_posterior_draws_dataframe".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zwelitunyiswa thanks for the suggestion! Tomas and I are looking into returning the InferenceData object after all. So stay tuned to this PR 😄

Whether to return the inference data from the InferenceData object with the predictions
and data used to generate those predictions. Defaults to ``False``.

Returns
-------
cap_data : pandas.DataFrame
A DataFrame with the ``create_cap_data`` and model predictions.
A DataFrame with the ``create_cap_data`` and model predictions. If ``return_idata`` is
``True``, then this DataFrame also includes inference data, observed data, and parameter
estimates.

Raises
------
Expand Down Expand Up @@ -535,6 +542,10 @@ def predictions(
y_hat = response_transform(idata["posterior"][response.name_target])
y_hat_mean = y_hat.mean(("chain", "draw"))

# early return to avoid the computation below
if return_idata:
return get_posterior(response.name_obs, idata, cap_data)

if use_hdi and pps:
y_hat_bounds = az.hdi(y_hat, prob)[response.name].T
elif use_hdi:
Expand Down Expand Up @@ -583,6 +594,7 @@ def comparisons(
prob: Union[float, None] = None,
transforms: Union[dict, None] = None,
sample_new_groups: bool = False,
return_idata: bool = False,
) -> pd.DataFrame:
"""Compute Conditional Adjusted Comparisons

Expand Down Expand Up @@ -615,12 +627,16 @@ def comparisons(
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
return_idata : bool, optional
Whether to return the inference data from the InferenceData object with the predictions
and data used to generate those predictions. Defaults to ``False``.

Returns
-------
pandas.DataFrame
A dataframe with the comparison values, highest density interval, contrast name,
contrast value, and conditional values.
A dataframe with the comparison values, highest density interval, contrast name, contrast
value, and conditional values. If ``return_idata`` is ``True``, then this DataFrame also
includes inference data, observed data, and parameter estimates.

Raises
------
Expand Down Expand Up @@ -695,6 +711,11 @@ def comparisons(
idata, data=comparisons_data, sample_new_groups=sample_new_groups, inplace=False
)

# early return since 'PredictiveDifferences' does not need to be called
if return_idata:
# return get_posterior(response.name_obs, idata, comparisons_data)
return comparisons_data, idata

# returns empty array if model predictions do not have multiple dimensions
response_dim_key = response.name + "_dim"
if response_dim_key in idata.posterior.coords:
Expand Down Expand Up @@ -733,6 +754,7 @@ def slopes(
prob: Union[float, None] = None,
transforms: Union[dict, None] = None,
sample_new_groups: bool = False,
return_idata: bool = False,
) -> pd.DataFrame:
"""Compute Conditional Adjusted Slopes

Expand Down Expand Up @@ -776,12 +798,16 @@ def slopes(
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
return_idata : bool, optional
Whether to return the inference data from the InferenceData object with the predictions
and data used to generate those predictions. Defaults to ``False``.

Returns
-------
pandas.DataFrame
A dataframe with the comparison values, highest density interval, ``wrt`` name,
contrast value, and conditional values.
A dataframe with the slope values, highest density interval, with respect to name, with
respect to value, and conditional values. If ``return_idata`` is ``True``, then this
DataFrame also includes inference data, observed data, and parameter estimates.

Raises
------
Expand Down Expand Up @@ -859,6 +885,10 @@ def slopes(
idata, data=slopes_data, sample_new_groups=sample_new_groups, inplace=False
)

# early return since 'PredictiveDifferences' does not need to be called
if return_idata:
return get_posterior(response.name_obs, idata, slopes_data)

# returns empty array if model predictions do not have multiple dimensions
response_dim_key = response.name + "_dim"
if response_dim_key in idata.posterior.coords:
Expand Down
36 changes: 26 additions & 10 deletions bambi/interpret/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from statistics import mode
from typing import Union

import arviz as az
import numpy as np
from formulae.terms.call import Call
import pandas as pd
Expand Down Expand Up @@ -393,16 +394,10 @@ def make_group_values(x: np.ndarray, groups_n: int = 5) -> np.ndarray:


def get_group_offset(n, lower: float = 0.05, upper: float = 0.4) -> np.ndarray:
# Complementary log log function, scaled.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for removing these comments :D

# See following code to have an idea of how this function looks like
# lower, upper = 0.05, 0.4
# x = np.linspace(2, 9)
# y = get_group_offset(x, lower, upper)
# fig, ax = plt.subplots(figsize=(8, 5))
# ax.plot(x, y)
# ax.axvline(2, color="k", ls="--")
# ax.axhline(lower, color="k", ls="--")
# ax.axhline(upper, color="k", ls="--")
"""
When plotting categoric variables, this function computes the offset of the
stripplot points based on the number of groups ``n``.
"""
intercept, slope = 3.25, 1
return lower + np.exp(-np.exp(intercept - slope * n)) * (upper - lower)

Expand Down Expand Up @@ -434,3 +429,24 @@ def merge(y_hat_mean: xr.DataArray, y_hat_bounds: xr.DataArray, data: pd.DataFra
summary_df = pd.merge(left=data, right=preds_df, left_index=True, right_index=True)

return summary_df.drop(columns=["hdi_x", "hdi_y"])


def get_posterior(
    response_obs: str, idata: az.InferenceData, pred_data: pd.DataFrame
) -> pd.DataFrame:
    """
    Merges the posterior or posterior predictive draws with the corresponding
    observation that produced that draw.

    Parameters
    ----------
    response_obs : str
        Name of the column, in the flattened draws, whose values are matched against the
        index of ``pred_data``.
    idata : az.InferenceData
        Inference data holding the draws. The ``posterior_predictive`` group is used when it
        is present; otherwise the ``posterior`` group is used.
    pred_data : pandas.DataFrame
        Data to attach to the draws. Rows are matched on its index.

    Returns
    -------
    pandas.DataFrame
        The draws joined with ``pred_data``, with a fresh default index. When a column name
        appears in both inputs, the column coming from the draws keeps its name and the
        column coming from ``pred_data`` gets the ``_obs`` suffix.
    """
    # if `pps=True` in 'predictions', then use posterior predictive draws
    if "posterior_predictive" in idata.groups():
        draws_df = idata.posterior_predictive.to_dataframe().reset_index()
    else:
        draws_df = idata.posterior.to_dataframe().reset_index()

    # Let pandas apply the final suffixes itself. Renaming afterwards with
    # ``str.replace("_x", "")`` / ``str.replace("_y", "_obs")`` would also rewrite
    # columns that merely contain those substrings (e.g. ``group_year``).
    merged_df = draws_df.set_index(response_obs).merge(
        pred_data, left_index=True, right_index=True, suffixes=("", "_obs")
    )

    return merged_df.reset_index(drop=True)
Loading
Loading