Skip to content

Commit

Permalink
return type hint include Tuple for variable number of return objects
Browse files Browse the repository at this point in the history
  • Loading branch information
GStechschulte committed Aug 21, 2023
1 parent 614e652 commit 0bf80e0
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions bambi/interpret/effects.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=ungrouped-imports
from dataclasses import dataclass, field
import itertools
from typing import Dict, Union
from typing import Dict, Tuple, Union

import arviz as az
import numpy as np
Expand Down Expand Up @@ -413,7 +413,7 @@ def predictions(
prob=None,
transforms=None,
return_posterior: bool = False,
) -> pd.DataFrame:
) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]:
"""Compute Conditional Adjusted Predictions
Parameters
Expand Down Expand Up @@ -531,7 +531,7 @@ def predictions(
predictions_summary[response.upper_bound_name] = y_hat_bounds[1]

if return_posterior:
return predictions_summary, get_posterior(response.name_obs, idata, cap_data)
return (predictions_summary, get_posterior(response.name_obs, idata, cap_data))

return predictions_summary

Expand All @@ -547,7 +547,7 @@ def comparisons(
prob: Union[float, None] = None,
transforms: Union[dict, None] = None,
return_posterior: bool = False,
) -> pd.DataFrame:
) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]:
"""Compute Conditional Adjusted Comparisons
Parameters
Expand Down Expand Up @@ -669,7 +669,7 @@ def comparisons(
comparisons_summary = predictive_difference.average_by(variable=average_by)

if return_posterior:
return comparisons_summary, get_posterior(response.name_obs, idata, comparisons_data)
return (comparisons_summary, get_posterior(response.name_obs, idata, comparisons_data))

return comparisons_summary

Expand All @@ -686,7 +686,7 @@ def slopes(
prob: Union[float, None] = None,
transforms: Union[dict, None] = None,
return_posterior: bool = False,
) -> pd.DataFrame:
) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]:
"""Compute Conditional Adjusted Slopes
Parameters
Expand Down Expand Up @@ -815,6 +815,6 @@ def slopes(
slopes_summary = predictive_difference.average_by(variable=average_by)

if return_posterior:
return slopes_summary, get_posterior(response.name_obs, idata, slopes_data)
return (slopes_summary, get_posterior(response.name_obs, idata, slopes_data))

return slopes_summary

0 comments on commit 0bf80e0

Please sign in to comment.