Skip to content

Commit

Permalink
update to VariableInfo class to allow slopes with user provided multi…
Browse files Browse the repository at this point in the history
…ple values
  • Loading branch information
GStechschulte committed Jul 22, 2023
1 parent dbb7cc1 commit e326a6f
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions bambi/plots/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ class VariableInfo:
kind: str
grid: Union[bool, None] = None
eps: Union[float, None] = None
user_passed: bool = False
name: str = field(init=False)
values: Union[int, float] = field(init=False)
passed_values: int = field(init=False)

def __post_init__(self):
"""
Expand All @@ -29,10 +31,13 @@ def __post_init__(self):
own values, and dtype of the variable
"""
if isinstance(self.variable, dict):
self.values = np.array(list(self.variable.values())[0])
self.user_passed = True
self.passed_values = np.array(list(self.variable.values())[0])
self.values = self.passed_values
if self.kind == "slopes":
# TODO: does not work if users passes list or array
self.values = np.array([self.values, self.values + self.eps])
self.values = self.epsilon_difference(self.passed_values, self.eps)
if self.values.ndim > 1:
self.values = self.values.flatten()
self.name = list(self.variable.keys())[0]
elif isinstance(self.variable, (list, str)):
self.name = self.variable
Expand Down Expand Up @@ -61,13 +66,11 @@ def set_default_variable_values(self):
names = [component.name]
for name in names:
if name == self.name:
# for numeric predictors, select the mean.
predictor_data = self.model.data[name]
dtype = predictor_data.dtype
if component.kind == "numeric":
if self.grid:
predictor_data = np.mean(predictor_data)

if self.kind == "slopes":
values = self.epsilon_difference(predictor_data, self.eps)
elif self.kind == "comparisons":
Expand Down Expand Up @@ -322,7 +325,6 @@ def get_unique_levels(x: np.ndarray) -> np.ndarray:
Get unique levels of a categoric variable.
"""
if hasattr(x, "dtype") and hasattr(x.dtype, "categories"):
# levels = list(x.dtype.categories)
levels = np.array((x.dtype.categories))
else:
levels = np.unique(x)
Expand Down

0 comments on commit e326a6f

Please sign in to comment.