From e326a6fc2fedd1a18284780efbd70edf4a7b2010 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Sat, 22 Jul 2023 10:27:32 +0200 Subject: [PATCH] update to VariableInfo class to allow slopes with user provided multiple values --- bambi/plots/utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/bambi/plots/utils.py b/bambi/plots/utils.py index 548c4d443..fd47c18c5 100644 --- a/bambi/plots/utils.py +++ b/bambi/plots/utils.py @@ -19,8 +19,10 @@ class VariableInfo: kind: str grid: Union[bool, None] = None eps: Union[float, None] = None + user_passed: bool = False name: str = field(init=False) values: Union[int, float] = field(init=False) + passed_values: int = field(init=False) def __post_init__(self): """ @@ -29,10 +31,13 @@ def __post_init__(self): own values, and dtype of the variable """ if isinstance(self.variable, dict): - self.values = np.array(list(self.variable.values())[0]) + self.user_passed = True + self.passed_values = np.array(list(self.variable.values())[0]) + self.values = self.passed_values if self.kind == "slopes": - # TODO: does not work if users passes list or array - self.values = np.array([self.values, self.values + self.eps]) + self.values = self.epsilon_difference(self.passed_values, self.eps) + if self.values.ndim > 1: + self.values = self.values.flatten() self.name = list(self.variable.keys())[0] elif isinstance(self.variable, (list, str)): self.name = self.variable @@ -61,13 +66,11 @@ def set_default_variable_values(self): names = [component.name] for name in names: if name == self.name: - # for numeric predictors, select the mean. predictor_data = self.model.data[name] dtype = predictor_data.dtype if component.kind == "numeric": if self.grid: predictor_data = np.mean(predictor_data) - if self.kind == "slopes": values = self.epsilon_difference(predictor_data, self.eps) elif self.kind == "comparisons": @@ -322,7 +325,6 @@ def get_unique_levels(x: np.ndarray) -> np.ndarray: Get unique levels of a categoric variable. """ if hasattr(x, "dtype") and hasattr(x.dtype, "categories"): - # levels = list(x.dtype.categories) levels = np.array((x.dtype.categories)) else: levels = np.unique(x)