From 0bf80e0d1f032750b2719c2a371d53ea226b7dd1 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Mon, 21 Aug 2023 21:20:19 +0200 Subject: [PATCH] return type hint include Tuple for variable number of return objects --- bambi/interpret/effects.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/bambi/interpret/effects.py b/bambi/interpret/effects.py index 26fee0794..85bd3e212 100644 --- a/bambi/interpret/effects.py +++ b/bambi/interpret/effects.py @@ -1,7 +1,7 @@ # pylint: disable=ungrouped-imports from dataclasses import dataclass, field import itertools -from typing import Dict, Union +from typing import Dict, Tuple, Union import arviz as az import numpy as np @@ -413,7 +413,7 @@ def predictions( prob=None, transforms=None, return_posterior: bool = False, -) -> pd.DataFrame: +) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]: """Compute Conditional Adjusted Predictions Parameters @@ -531,7 +531,7 @@ def predictions( predictions_summary[response.upper_bound_name] = y_hat_bounds[1] if return_posterior: - return predictions_summary, get_posterior(response.name_obs, idata, cap_data) + return (predictions_summary, get_posterior(response.name_obs, idata, cap_data)) return predictions_summary @@ -547,7 +547,7 @@ def comparisons( prob: Union[float, None] = None, transforms: Union[dict, None] = None, return_posterior: bool = False, -) -> pd.DataFrame: +) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]: """Compute Conditional Adjusted Comparisons Parameters @@ -669,7 +669,7 @@ def comparisons( comparisons_summary = predictive_difference.average_by(variable=average_by) if return_posterior: - return comparisons_summary, get_posterior(response.name_obs, idata, comparisons_data) + return (comparisons_summary, get_posterior(response.name_obs, idata, comparisons_data)) return comparisons_summary @@ -686,7 +686,7 @@ def slopes( prob: Union[float, None] = None, transforms: Union[dict, None] = None, return_posterior: bool = False, -) -> pd.DataFrame: +) -> Union[pd.DataFrame, Tuple[pd.DataFrame, pd.DataFrame]]: """Compute Conditional Adjusted Slopes Parameters @@ -815,6 +815,6 @@ def slopes( slopes_summary = predictive_difference.average_by(variable=average_by) if return_posterior: - return slopes_summary, get_posterior(response.name_obs, idata, slopes_data) + return (slopes_summary, get_posterior(response.name_obs, idata, slopes_data)) return slopes_summary