diff --git a/bambi/interpret/effects.py b/bambi/interpret/effects.py index 26fee0794..85bd3e212 100644 --- a/bambi/interpret/effects.py +++ b/bambi/interpret/effects.py @@ -1,7 +1,7 @@ # pylint: disable=ungrouped-imports from dataclasses import dataclass, field import itertools -from typing import Dict, Union +from typing import Dict, Tuple, Union import arviz as az import numpy as np @@ -413,7 +413,7 @@ def predictions( prob=None, transforms=None, return_posterior: bool = False, -) -> pd.DataFrame: +) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]: """Compute Conditional Adjusted Predictions Parameters @@ -531,7 +531,7 @@ def predictions( predictions_summary[response.upper_bound_name] = y_hat_bounds[1] if return_posterior: - return predictions_summary, get_posterior(response.name_obs, idata, cap_data) + return (predictions_summary, get_posterior(response.name_obs, idata, cap_data)) return predictions_summary @@ -547,7 +547,7 @@ def comparisons( prob: Union[float, None] = None, transforms: Union[dict, None] = None, return_posterior: bool = False, -) -> pd.DataFrame: +) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]: """Compute Conditional Adjusted Comparisons Parameters @@ -669,7 +669,7 @@ def comparisons( comparisons_summary = predictive_difference.average_by(variable=average_by) if return_posterior: - return comparisons_summary, get_posterior(response.name_obs, idata, comparisons_data) + return (comparisons_summary, get_posterior(response.name_obs, idata, comparisons_data)) return comparisons_summary @@ -686,7 +686,7 @@ def slopes( prob: Union[float, None] = None, transforms: Union[dict, None] = None, return_posterior: bool = False, -) -> pd.DataFrame: +) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]: """Compute Conditional Adjusted Slopes Parameters @@ -815,6 +815,6 @@ def slopes( slopes_summary = predictive_difference.average_by(variable=average_by) if return_posterior: - return slopes_summary, get_posterior(response.name_obs, idata, slopes_data) + return (slopes_summary, get_posterior(response.name_obs, idata, slopes_data)) return slopes_summary