Skip to content

Commit

Permalink
advanced interpret usage (#762)
Browse files Browse the repository at this point in the history
* re-run notebooks and advanced usage docs

* added select_draws and data_grid functions

* move data generation functions to create_data.py for better cohesion

* move sorting of dict values to ConditionalInfo dataclass

* removal of keyword args. to functions in create_data.py

* create_grid function for internal and user-level functions

* add select_draws and data_grid as modules

* remove code-cells

* initial tests for interpret helper functions

* add kwargs, docstrings, and error handling

* improved docstrings and inline comments

* remove functions that have been deleted from utils.py

* lowercase inline comments

* re-run docs and add advanced interpret docs

* finalize tests

* update interpret logger tests to reflect new message

* update logger to parse create_data func

* remove double backticks

* eps logic and add logger decorator

* re-run slopes and advanced usage notebooks

* remove elif block and update filterwarnings to specific message

* remove elif block and update filterwarnings to specific message

* remove elif block and update filterwarnings to specific message
  • Loading branch information
GStechschulte authored Dec 6, 2023
1 parent 312afa2 commit 8fa47aa
Show file tree
Hide file tree
Showing 16 changed files with 4,658 additions and 1,173 deletions.
3 changes: 3 additions & 0 deletions bambi/interpret/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import logging

from bambi.interpret.effects import comparisons, predictions, slopes
from bambi.interpret.helpers import data_grid, select_draws
from bambi.interpret.plotting import plot_comparisons, plot_predictions, plot_slopes

__all__ = [
"comparisons",
"data_grid",
"logger",
"select_draws",
"slopes",
"predictions",
"plot_comparisons",
Expand Down
242 changes: 127 additions & 115 deletions bambi/interpret/create_data.py
Original file line number Diff line number Diff line change
@@ -1,136 +1,131 @@
import itertools

from typing import Union
from statistics import mode

import numpy as np
import pandas as pd

from pandas.api.types import (
is_categorical_dtype,
is_float_dtype,
is_integer_dtype,
is_numeric_dtype,
is_object_dtype,
is_string_dtype,
)

from bambi import Model
from bambi.interpret.utils import (
ConditionalInfo,
enforce_dtypes,
get_covariates,
get_model_covariates,
make_group_panel_values,
make_main_values,
set_default_values,
VariableInfo,
)

from bambi.interpret.logs import log_interpret_defaults

@log_interpret_defaults
def create_grid(
    condition: ConditionalInfo, variable: Union[VariableInfo, None] = None, **kwargs
) -> pd.DataFrame:
    """Creates a grid of data by using the covariates passed into the 'conditional'
    and 'variable' argument.

    Values for the grid are either:
        1.) computed using an equally spaced grid (`np.linspace`) for numeric
            covariates, or the unique observed values for categorical, string,
            and object covariates.
        2.) a user specified value or range of values if `condition.user_passed = True`

    Covariates of the model that are not specified are filled in by
    `set_default_values` before the pairwise grid is built.

    Parameters
    ----------
    condition : ConditionalInfo
        Information about data passed to the conditional parameter of 'comparisons',
        'predictions', or 'slopes' related functions.
    variable : VariableInfo, optional
        Information about data passed to the variable of interest parameter. This
        is 'contrast' for 'comparisons', 'wrt' for 'slopes', and 'None' for 'predictions'.
    **kwargs : dict
        Optional keywords arguments such as 'effect_type' (the effect being computed),
        and 'num' (the number of values to return when computing a `np.linspace` grid,
        50 if not given).

    Returns
    -------
    pd.DataFrame
        A dataframe containing values used as input to the fitted Bambi model to
        generate predictions.

    Raises
    ------
    TypeError
        If a conditional covariate has a dtype that is neither numeric nor
        categorical, string, or object.
    """
    model, observed_data = condition.model, condition.model.data

    if condition.user_passed:
        # shallow copy of user-passed data dictionary
        data_dict = {**condition.conditional}
    else:
        data_dict = {}
        # the grid size does not depend on the covariate, so look it up once
        num = kwargs.get("num", 50)
        # values here are the names of the covariates
        for covariate in condition.covariates.values():
            x = observed_data[covariate]
            # integer and float columns are both numeric, so both get an equally
            # spaced grid; a separate integer branch placed after this check
            # could never be reached
            if is_numeric_dtype(x):
                values = np.linspace(np.min(x), np.max(x), num)
            elif (
                isinstance(x.dtype, pd.CategoricalDtype)
                or is_string_dtype(x)
                or is_object_dtype(x)
            ):
                values = np.unique(x)
            else:
                # 'covariate' is the column name itself (a string)
                raise TypeError(
                    f"Unsupported data type of {x.dtype} for covariate '{covariate}'"
                )

            data_dict[covariate] = values

    if variable:
        data_dict[variable.name] = variable.values

    # set typical values as defaults for unspecified covariates
    data_dict = set_default_values(model, data_dict)
    grid = _pairwise_grid(data_dict)

    # can't enforce dtype on 'with respect to' variable for 'slopes' as it
    # may remove floating point in the epsilon
    effect = kwargs.get("effect_type", None)
    except_col = variable.name if effect == "slopes" else None

    grid = enforce_dtypes(observed_data, grid, except_col)

    # after computing default values, fractional values may have been computed.
    # enforcing the dtype of "int" may create duplicate rows as it will round
    # the fractional values.
    grid = grid.drop_duplicates()

    return grid.reset_index(drop=True)

def _grid_level(
condition_info: ConditionalInfo,
variable_info: Union[VariableInfo, None],
user_passed: bool,
kind: str,
) -> pd.DataFrame:
"""Creates a "grid" of data by using the covariates passed into the
`conditional` argument. Values for the grid are either: (1) computed
using a equally spaced grid, mean, and or mode (depending on the
covariate dtype), and (2) a user specified value or range of values.

def _pairwise_grid(data_dict: dict) -> pd.DataFrame:
"""Creates a pairwise grid (cartesian product) of data by using the
key-values of the dictionary.
Parameters
----------
condition_info : ConditionalInfo
Information about the conditional argument passed into the plot
function.
variable_info : VariableInfo, optional
Information about the variable of interest. This is `contrast` for
'comparisons', `wrt` for 'slopes', and `None` for 'predictions'.
user_passed : bool
Whether the user passed a value(s) for the `conditional` argument.
kind : str
The kind of effect being computed. Either "comparisons", "predictions",
or "slopes".
data_dict : dict
A dictionary containing the covariates as keys and their values as the
values.
Returns
-------
pd.DataFrame
A dataframe containing values used as input to the fitted Bambi model to
generate predictions.
"""
covariates = get_covariates(condition_info.covariates)

if kind == "predictions":
# Compute pairwise grid of values if the user passed a dict.
if user_passed:
data_dict = {**condition_info.conditional}
data_dict = set_default_values(condition_info.model, data_dict, kind=kind)
for key, value in data_dict.items():
if not isinstance(value, (list, np.ndarray)):
data_dict[key] = [value]
data_grid = _pairwise_grid(data_dict)
else:
# Compute a grid of values
main_values = make_main_values(
condition_info.model.data[covariates.main], covariates.main
)
data_dict = {covariates.main: main_values}
data_dict = make_group_panel_values(
condition_info.model.data,
data_dict,
covariates.main,
covariates.group,
covariates.panel,
kind=kind,
)
data_dict = set_default_values(condition_info.model, data_dict, kind=kind)
data_grid = pd.DataFrame(data_dict)
else:
# Compute pairwise grid of values if the user passed a dict.
if user_passed:
data_dict = {**condition_info.conditional}
else:
# Compute a grid of values
main_values = make_main_values(
condition_info.model.data[covariates.main], covariates.main
)
data_dict = {covariates.main: main_values}
data_dict = make_group_panel_values(
condition_info.model.data,
data_dict,
covariates.main,
covariates.group,
covariates.panel,
kind=kind,
)

data_dict[variable_info.name] = variable_info.values
data_dict = set_default_values(condition_info.model, data_dict, kind=kind)
data_grid = _pairwise_grid(data_dict)

# Can't enforce dtype on numeric 'wrt' for 'slopes 'as it may remove floating point epsilons
except_col = None if kind in ("comparisons", "predictions") else {variable_info.name}
data_grid = enforce_dtypes(condition_info.model.data, data_grid, except_col)

# After computing default values, fractional values may have been computed.
# Enforcing the dtype of "int" may create duplicate rows as it will round
# the fractional values.
data_grid = data_grid.drop_duplicates()

return data_grid.reset_index(drop=True)
keys, values = zip(*data_dict.items())
cross_joined_data = pd.DataFrame([dict(zip(keys, v)) for v in itertools.product(*values)])
return cross_joined_data


def _differences_unit_level(variable_info: VariableInfo, kind: str) -> pd.DataFrame:
def _differences_unit_level(variable_info: VariableInfo, effect_type: str) -> pd.DataFrame:
"""Creates the data for unit-level contrasts by using the observed (empirical)
data. All covariates in the model are included in the data, except for the
contrast predictor. The contrast predictor is replaced with either: (1) the
Expand All @@ -141,8 +136,8 @@ def _differences_unit_level(variable_info: VariableInfo, kind: str) -> pd.DataFr
variable_info : VariableInfo
Information about the variable of interest. This is `contrast` for
'comparisons' and `wrt` for 'slopes'.
kind : str
The kind of effect being computed. Either "comparisons" or "slopes".
effect_type : str
The type of effect being computed. Either "comparisons" or "slopes".
Returns
-------
Expand All @@ -153,10 +148,9 @@ def _differences_unit_level(variable_info: VariableInfo, kind: str) -> pd.DataFr
"""
covariates = get_model_covariates(variable_info.model)
df = variable_info.model.data[covariates].drop(labels=variable_info.name, axis=1)

variable_vals = variable_info.values

if kind == "comparisons":
if effect_type == "comparisons":
variable_vals = np.array(variable_info.values)[..., None]
variable_vals = np.repeat(variable_vals, variable_info.model.data.shape[0], axis=1)

Expand All @@ -165,11 +159,13 @@ def _differences_unit_level(variable_info: VariableInfo, kind: str) -> pd.DataFr
unit_level_df_dict[f"contrast_{idx}"] = df.copy()
unit_level_df_dict[f"contrast_{idx}"][variable_info.name] = value

return pd.concat(unit_level_df_dict.values())
unit_level_df = pd.concat(unit_level_df_dict.values())

return unit_level_df.reset_index(drop=True)


def create_differences_data(
condition_info: ConditionalInfo, variable_info: VariableInfo, user_passed: bool, kind: str
condition_info: ConditionalInfo, variable_info: VariableInfo, effect_type: str
) -> pd.DataFrame:
"""Creates either unit level or grid level data for 'comparisons' and 'slopes'
depending if the user passed covariate values.
Expand All @@ -182,10 +178,8 @@ def create_differences_data(
variable_info : VariableInfo
Information about the variable of interest. This is `contrast` for
'comparisons' and `wrt` for 'slopes'.
user_passed : bool
Whether the user passed a value(s) for the `conditional` argument.
kind : str
The kind of effect being computed. Either "comparisons" or "slopes".
effect_type : str
The type of effect being computed. Either "comparisons" or "slopes".
Returns
-------
Expand All @@ -195,14 +189,13 @@ def create_differences_data(
is returned. Otherwise, a grid of values is created using the covariates
passed into the `conditional` argument.
"""

if not condition_info.covariates:
return _differences_unit_level(variable_info, kind)
return _differences_unit_level(variable_info, effect_type)

return _grid_level(condition_info, variable_info, user_passed, kind)
return create_grid(condition_info, variable_info, effect_type=effect_type)


def create_predictions_data(condition_info: ConditionalInfo, user_passed: bool) -> pd.DataFrame:
def create_predictions_data(condition_info: ConditionalInfo) -> pd.DataFrame:
"""Creates either unit level or grid level data for 'predictions' depending
if the user passed covariates.
Expand All @@ -211,8 +204,6 @@ def create_predictions_data(condition_info: ConditionalInfo, user_passed: bool)
condition_info : ConditionalInfo
Information about the conditional argument passed into the plot
function.
user_passed : bool
Whether the user passed a value(s) for the `conditional` argument.
Returns
-------
Expand All @@ -222,9 +213,30 @@ def create_predictions_data(condition_info: ConditionalInfo, user_passed: bool)
is returned. Otherwise, a grid of values is created using the covariates
passed into the `conditional` argument.
"""
# Unit level data used the observed (empirical) data
# unit level data uses the observed (empirical) data
if not condition_info.covariates:
covariates = get_model_covariates(condition_info.model)
return condition_info.model.data[covariates]

return _grid_level(condition_info, None, user_passed, "predictions")
return create_grid(condition_info, None)


@log_interpret_defaults
def set_default_values(model: Model, data_dict: dict) -> dict:
    """
    Set default values for each variable in the model if the user did not
    pass them in the data_dict.

    Parameters
    ----------
    model : Model
        The model whose covariates (as returned by `get_model_covariates`) are
        checked, and whose `data` supplies the observed values.
    data_dict : dict
        A dictionary with covariate names as keys and their values as the
        values. Missing covariates are added to this dictionary in place.

    Returns
    -------
    dict
        The same `data_dict` object, where each previously missing covariate
        holds a one-element array: the mean for numeric data, or the mode for
        categorical, string, and object data.

    Raises
    ------
    TypeError
        If a missing covariate has a dtype that is neither numeric nor
        categorical, string, or object.
    """
    # set unspecified covariates to "typical" values
    for name in get_model_covariates(model):
        if name in data_dict:
            continue

        x = model.data[name]
        # integer and float dtypes are both covered by the numeric check
        if is_numeric_dtype(x):
            data_dict[name] = np.array([np.mean(x)])
        elif (
            isinstance(x.dtype, pd.CategoricalDtype)
            or is_string_dtype(x)
            or is_object_dtype(x)
        ):
            data_dict[name] = np.array([mode(x)])
        else:
            raise TypeError(f"Unsupported data type of {x.dtype} for covariate '{name}'")

    return data_dict
Loading

0 comments on commit 8fa47aa

Please sign in to comment.