Return interpret data #758

Closed

@GStechschulte GStechschulte commented Nov 14, 2023

This PR addresses issues #703 and #751 by adding a parameter, return_idata: bool = False, to comparisons(), predictions(), and slopes(). When it is set to True, the posterior draws are merged with the observation that "produced" each draw, and the result is returned as a data frame.

Most of the code diff is from adding a new test file that tests non-plotting functionality of the interpret sub-package not tested in test_plots.py.

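import bambi as bmb
import pandas as pd

# Load the fish data set and keep only the columns used in the model
fish_data = pd.read_stata("http://www.stata-press.com/data/r11/fish.dta")
cols = ["count", "livebait", "camper", "persons", "child"]
fish_data = fish_data[cols]
# Treat the binary predictors as categorical
fish_data["livebait"] = pd.Categorical(fish_data["livebait"])
fish_data["camper"] = pd.Categorical(fish_data["camper"])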

fish_model = bmb.Model(
    "count ~ livebait + camper + persons + child", 
    fish_data, 
    family='zero_inflated_poisson'
)

fish_idata = fish_model.fit(
    draws=1000, 
    target_accept=0.95, 
    random_seed=1234, 
    chains=4
)

With return_idata=True, a single data frame is returned. It contains the draws from the posterior group of the InferenceData object, the observed data, and the parameter estimates. When a user calls predictions with pps=True, the posterior predictive group is used. {marginaleffects} offers similar functionality for Bayesian models.

Below are a few examples:

bmb.interpret.predictions(
    model=fish_model,
    idata=fish_idata,
    conditional=["persons", "child", "livebait"],
    return_idata=True
) 
chain  draw  livebait_dim  camper_dim  Intercept  livebait  camper    persons   child      count_psi  count_mean  persons_obs  child_obs  livebait_obs  camper_obs
0      0     1.0           1.0         -2.560515  1.877271  0.658819  0.850313  -1.256469  0.619094   0.349454    1.0          0.0        0.0           1.0
0      1     1.0           1.0         -2.746079  1.852347  0.794963  0.848145  -1.302583  0.643842   0.331884    1.0          0.0        0.0           1.0
0      2     1.0           1.0         -2.669619  1.674642  0.693835  0.923590  -1.588417  0.683065   0.349171    1.0          0.0        0.0           1.0
0      3     1.0           1.0         -2.581749  1.474203  0.661846  0.958753  -1.624487  0.684139   0.382453    1.0          0.0        0.0           1.0
0      4     1.0           1.0         -2.866501  2.162873  0.575769  0.860328  -1.311473  0.647937   0.239212    1.0          0.0        0.0           1.0

1200000 rows × 15 columns

Returning the inference data when calling comparisons allows the user to compute more specific or complex comparisons with group-by aggregations:

bmb.interpret.comparisons(
    model=fish_model,
    idata=fish_idata,
    contrast={"persons": [1, 4]},
    conditional={"child": [0, 1, 2], "livebait": [0, 1]},
    return_idata=True
) 
chain  draw  livebait_dim  camper_dim  Intercept  livebait  camper    persons   child      count_psi  count_mean  child_obs  livebait_obs  persons_obs  camper_obs
0      0     1.0           1.0         -2.560515  1.877271  0.658819  0.850313  -1.256469  0.619094   0.349454    0.0        0             1.0          1.0
0      1     1.0           1.0         -2.746079  1.852347  0.794963  0.848145  -1.302583  0.643842   0.331884    0.0        0             1.0          1.0
0      2     1.0           1.0         -2.669619  1.674642  0.693835  0.923590  -1.588417  0.683065   0.349171    0.0        0             1.0          1.0
0      3     1.0           1.0         -2.581749  1.474203  0.661846  0.958753  -1.624487  0.684139   0.382453    0.0        0             1.0          1.0
0      4     1.0           1.0         -2.866501  2.162873  0.575769  0.860328  -1.311473  0.647937   0.239212    0.0        0             1.0          1.0

48000 rows × 15 columns
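
As a rough sketch of the kind of group-by aggregation meant here: the frame below is a synthetic stand-in, not output from the fish model. The column names (chain, draw, persons_obs, child_obs, count_mean) mirror the ones shown above, but the values are randomly generated and the grid is much smaller.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Synthetic stand-in for the returned data frame (assumption): 2 chains x 3 draws,
# crossed with persons in {1, 4} and child in {0, 1}.
grid = pd.MultiIndex.from_product(
    [[0, 1], [0, 1, 2], [1, 4], [0, 1]],
    names=["chain", "draw", "persons_obs", "child_obs"],
).to_frame(index=False)
grid["count_mean"] = rng.gamma(shape=2.0, scale=1.0, size=len(grid))

# Summarize the draws within each covariate combination.
summary = (
    grid.groupby(["child_obs", "persons_obs"])["count_mean"]
    .agg(["mean", "std"])
    .reset_index()
)

# Per-draw contrast (persons=4 minus persons=1), then averaged within each child value.
wide = grid.pivot_table(
    index=["chain", "draw", "child_obs"],
    columns="persons_obs",
    values="count_mean",
)
contrast = (wide[4] - wide[1]).groupby(level="child_obs").mean()

print(summary)
print(contrast)
```

Computing the contrast per draw before averaging keeps the full posterior of the difference available, so intervals can be taken on it instead of on the two means separately.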

Initially, I wanted to return the az.InferenceData object. However, due to the following limitations I settled on a DataFrame:

  • I could add new groups to the inference data object, but then it isn't clear to me how to perform group by aggregations on the posterior dataset while taking into account this new group.

  • Additionally, the data shouldn't be merged as a data variable in the az.InferenceData.posterior dataset because, when aggregations are performed along the coordinates, these aggregations will also be applied to the data used to generate the predictions (since they were merged as a data variable).

  • Lastly, the data could be merged as a coordinate, so that you can specify along which dimension(s) to compute the aggregation. However, I can't seem to group by more than one coordinate. For example, xr.Dataset.groupby([coord1, coord2]) results in the following error:

TypeError: group must be an xarray.DataArray or the name of an xarray variable or dimension. Received ['coord1', 'coord2'] instead.
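
One way around the multi-coordinate limitation, and roughly why a DataFrame is convenient here, is to flatten the dataset with to_dataframe() and do the grouping in pandas. A minimal sketch on a toy dataset follows. The variable mu, the obs dimension, and the covariate values are all made up for illustration and are not taken from this PR.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
n_chain, n_draw, n_obs = 2, 5, 6

# Toy posterior-like dataset; covariates are attached as non-index
# coordinates on the (hypothetical) "obs" dimension.
ds = xr.Dataset(
    {"mu": (("chain", "draw", "obs"), rng.normal(size=(n_chain, n_draw, n_obs)))},
    coords={
        "chain": np.arange(n_chain),
        "draw": np.arange(n_draw),
        "obs": np.arange(n_obs),
        "child": ("obs", [0, 0, 1, 1, 2, 2]),
        "livebait": ("obs", [0, 1, 0, 1, 0, 1]),
    },
)

# Flatten to long format: one row per (chain, draw, obs), with the
# covariate coordinates carried along as ordinary columns.
df = ds.to_dataframe().reset_index()

# Grouping by several columns is straightforward in pandas.
means = df.groupby(["child", "livebait"])["mu"].mean()
print(means)
```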

Note: depending on the model specification and number of chains and draws, it is possible there will be millions of rows returned.
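
A back-of-the-envelope way to estimate the size, assuming the frame holds one row per (chain, draw, data row). That assumption is consistent with the 48000 rows in the comparisons example above, but it is an inference from the output rather than something read from the implementation. The helper name is hypothetical.

```python
from math import prod


def expected_rows(n_chains: int, n_draws: int, n_data_rows: int) -> int:
    """Rows in a long frame with one row per (chain, draw, data row)."""
    return n_chains * n_draws * n_data_rows


# comparisons example: 2 contrast values x 3 child values x 2 livebait values
n_grid = prod([2, 3, 2])
print(expected_rows(4, 1000, n_grid))  # 48000

# predictions example: 1,200,000 rows over 4 x 1000 draws implies 300 data rows
print(1_200_000 // (4 * 1000))  # 300
```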

To do:

  • add case for when pps=True (posterior predictive samples) in predictions. Currently, I only access the posterior group of the InferenceData to build the dataframe.

codecov-commenter commented Nov 14, 2023

Codecov Report

Attention: 1 lines in your changes are missing coverage. Please review.

Comparison is base (dcd879b) 89.90% compared to head (6208fe8) 89.91%.

Files                     Patch %  Lines
bambi/interpret/utils.py  90.00%   1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #758      +/-   ##
==========================================
+ Coverage   89.90%   89.91%   +0.01%     
==========================================
  Files          45       45              
  Lines        3713     3729      +16     
==========================================
+ Hits         3338     3353      +15     
- Misses        375      376       +1     

@GStechschulte GStechschulte marked this pull request as ready for review November 15, 2023 05:51
@@ -473,11 +475,16 @@ def predictions(
sample_new_groups : bool, optional
If the model contains group-level effects, and data is passed for unseen groups, whether
to sample from the new groups. Defaults to ``False``.
return_idata : bool, optional

@tomicapretto tomicapretto Nov 15, 2023

Could we use a different name for this? From the name, I assumed the returned object would be an instance of InferenceData, but I see it's a data frame. Maybe return_data? Or, are you thinking we would be able to return an InferenceData instance in the future?

maybe "return_posterior_draws_dataframe".

@zwelitunyiswa thanks for the suggestion! Tomas and I are looking into returning the InferenceData object after all. So stay tuned to this PR 😄

@@ -393,16 +394,10 @@ def make_group_values(x: np.ndarray, groups_n: int = 5) -> np.ndarray:


def get_group_offset(n, lower: float = 0.05, upper: float = 0.4) -> np.ndarray:
# Complementary log log function, scaled.

Thanks for removing these comments :D

def test_return_idata_common_effects(mtcars, return_idata):
model, idata = mtcars

bmb.interpret.predictions(

Can we add a check for the type of the object returned and the number of rows/columns? Notice we would need to fix the number of chains above.
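
A sketch of what such a check could look like. The helper check_draws_frame is hypothetical, and the frame is a stand-in built by hand, since the actual object returned by predictions is not reproduced here. It also assumes one row per (chain, draw, data row), which is why the number of chains has to be fixed in the fixture.

```python
import pandas as pd


def check_draws_frame(df, n_chains, n_draws, n_data_rows, n_columns):
    """Assert the returned object is a DataFrame with the expected shape."""
    assert isinstance(df, pd.DataFrame)
    assert df.shape == (n_chains * n_draws * n_data_rows, n_columns)


# Stand-in for the returned frame (assumption): 2 chains, 10 draws, 3 data rows.
chains, draws, rows = 2, 10, 3
fake = pd.MultiIndex.from_product(
    [range(chains), range(draws), range(rows)],
    names=["chain", "draw", "obs"],
).to_frame(index=False)
fake["estimate"] = 0.0

check_draws_frame(fake, chains, draws, rows, n_columns=4)
```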

@tomicapretto tomicapretto left a comment

Thanks for the nice feature! Just a couple of suggestions.

@GStechschulte

> Thanks for the nice feature! Just a couple of suggestions.

Thanks for the review. I will incorporate these once we finalize the implementation per our conversation on Slack.

@GStechschulte

Closing in favor of #762

@GStechschulte GStechschulte deleted the return-interpret-data branch January 21, 2024 20:19