Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plot_slopes and slopes #699

Merged
merged 28 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b125a47
add slopes and plot_slopes
GStechschulte Jul 20, 2023
bb4ec6a
common create data function for comparisons and slopes
GStechschulte Jul 20, 2023
da1b6d4
add slopes and PredictiveDifference class for computing effects and b…
GStechschulte Jul 20, 2023
c7f6410
Assign colors for single covariates by @tjburch
GStechschulte Jul 20, 2023
689641a
add plot_slopes and common plotting function for slopes and comparisons
GStechschulte Jul 20, 2023
6935ceb
common VariableInfo class and default value computation for slopes an…
GStechschulte Jul 20, 2023
8bd459c
reorder imports alphabetically
GStechschulte Jul 22, 2023
ff60f08
improved docstring for plot_slopes
GStechschulte Jul 22, 2023
4a4482b
move private inner functions outside of 'create_differences_data' func
GStechschulte Jul 22, 2023
dd8bd43
slopes supports multiple values, added args. to 'get_estimate', and i…
GStechschulte Jul 22, 2023
dbb7cc1
add color='C0' to fix color bug
GStechschulte Jul 22, 2023
e326a6f
update to VariableInfo class to allow slopes with user provided multi…
GStechschulte Jul 22, 2023
392625f
add support for semi-elasticities, move slopes and setting of variabl…
GStechschulte Jul 24, 2023
a8f72b3
update ValueError to include semi-elasticities
GStechschulte Jul 24, 2023
d76e3e1
raise ValueError is slopes not in semi-elasticities
GStechschulte Jul 24, 2023
b883df6
run black code formatting
GStechschulte Jul 24, 2023
c04858e
if slopes effect, convert columns that are not 'wrt' to original dtype
GStechschulte Jul 26, 2023
eebfb34
pass transforms as arg. to PredictiveDifferences
GStechschulte Jul 26, 2023
9a541ac
raise ValueError if user attempts to use 'plot_slopes' without averag…
GStechschulte Jul 26, 2023
43e9ead
fix set_default_values bug for 'comparisons' by applying the mean to …
GStechschulte Jul 26, 2023
1d8888d
add 'plot_slopes' tests
GStechschulte Jul 26, 2023
dc77947
docstring fixes / enhancements
GStechschulte Jul 26, 2023
27a6360
run black formatting
GStechschulte Jul 27, 2023
7d6f82d
bug fixed when user provided values > 3 and better error raise ValueE…
GStechschulte Aug 1, 2023
9dbd181
improved error handling for and
GStechschulte Aug 2, 2023
e426978
improved error handling for and
GStechschulte Aug 2, 2023
785e88d
run black formatting
GStechschulte Aug 3, 2023
541621b
add and or improve docstrings to classes and functions
GStechschulte Aug 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions bambi/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from bambi.plots.effects import comparisons, predictions
from bambi.plots.plotting import plot_cap, plot_comparison
from bambi.plots.effects import comparisons, predictions, slopes
from bambi.plots.plotting import plot_cap, plot_comparison, plot_slopes


__all__ = ["comparisons", "predictions", "plot_cap", "plot_comparison"]
__all__ = ["comparisons", "slopes", "predictions", "plot_cap", "plot_comparison", "plot_slopes"]
182 changes: 82 additions & 100 deletions bambi/plots/create_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,124 +6,106 @@
from bambi.models import Model
from bambi.plots.utils import (
ConditionalInfo,
ContrastInfo,
enforce_dtypes,
get_covariates,
get_model_covariates,
make_group_panel_values,
make_main_values,
set_default_values,
VariableInfo,
)


def _grid_level(
condition_info: ConditionalInfo, variable_info: VariableInfo, user_passed: bool, kind: str
):
"""
Creates a "grid" of data by using the covariates passed into the
`conditional` argument. Values for the grid are either: (1) computed
using a equally spaced grid, mean, and or mode (depending on the
covariate dtype), and (2) a user specified value or range of values.
"""
covariates = get_covariates(condition_info.covariates)

if user_passed:
data_dict = {**condition_info.conditional}
else:
main_values = make_main_values(condition_info.model.data[covariates.main])
data_dict = {covariates.main: main_values}
data_dict = make_group_panel_values(
condition_info.model.data,
data_dict,
covariates.main,
covariates.group,
covariates.panel,
kind=kind,
)

data_dict[variable_info.name] = variable_info.values
comparison_data = set_default_values(condition_info.model, data_dict, kind=kind)
# use cartesian product (cross join) to create pairwise grid
keys, values = zip(*comparison_data.items())
pairwise_grid = pd.DataFrame([dict(zip(keys, v)) for v in itertools.product(*values)])
# can't enforce dtype on numeric 'wrt' as it may remove floating point epsilons
if kind == "comparisons":
pairwise_grid = enforce_dtypes(condition_info.model.data, pairwise_grid)
elif kind == "slopes":
pairwise_grid = enforce_dtypes(condition_info.model.data, pairwise_grid, variable_info.name)

return pairwise_grid


def _unit_level(variable_info: VariableInfo, kind: str) -> pd.DataFrame:
"""
Creates the data for unit-level contrasts by using the observed (empirical)
data. All covariates in the model are included in the data, except for the
contrast predictor. The contrast predictor is replaced with either: (1) the
default contrast value, or (2) the user specified contrast value.
"""
covariates = get_model_covariates(variable_info.model)
df = variable_info.model.data[covariates].drop(labels=variable_info.name, axis=1)

variable_vals = variable_info.values

if kind == "comparisons":
variable_vals = np.array(variable_info.values)[..., None]
variable_vals = np.repeat(variable_vals, variable_info.model.data.shape[0], axis=1)

contrast_df_dict = {}
for idx, value in enumerate(variable_vals):
contrast_df_dict[f"contrast_{idx}"] = df.copy()
contrast_df_dict[f"contrast_{idx}"][variable_info.name] = value

return pd.concat(contrast_df_dict.values())


def create_differences_data(
GStechschulte marked this conversation as resolved.
Show resolved Hide resolved
condition_info: ConditionalInfo, variable_info: VariableInfo, user_passed: bool, kind: str
) -> pd.DataFrame:
"""
Creates either unit level or grid level data for 'comparisons' and 'slopes'
depending if the user passed covariate values.
"""

if not condition_info.covariates:
return _unit_level(variable_info, kind)
else:
return _grid_level(condition_info, variable_info, user_passed, kind)


def create_cap_data(model: Model, covariates: dict) -> pd.DataFrame:
GStechschulte marked this conversation as resolved.
Show resolved Hide resolved
"""Create data for a Conditional Adjusted Predictions

Parameters
----------
model : bambi.Model
An instance of a Bambi model
covariates : dict
A dictionary of length between one and three.
Keys must be taken from ("horizontal", "color", "panel").
The values indicate the names of variables.

Returns
-------
pandas.DataFrame
The data for the Conditional Adjusted Predictions dataframe and or
plotting.
"""
Creates a data grid for conditional adjusted predictions using the covariates
passed by the user.
"""
data = model.data
covariates = get_covariates(covariates)
main, group, panel = covariates.main, covariates.group, covariates.panel

# Obtain data for main variable
main_values = make_main_values(data[main])
data_dict = {main: main_values}

# Obtain data for group and panel variables if not None
data_dict = make_group_panel_values(data, data_dict, main, group, panel, kind="predictions")
data_dict = set_default_values(model, data_dict, kind="predictions")
return enforce_dtypes(data, pd.DataFrame(data_dict))


def create_comparisons_data(
condition: ConditionalInfo, contrast: ContrastInfo, user_passed: bool = False
) -> pd.DataFrame:
"""Create data for a Conditional Adjusted Comparisons

Parameters
----------
condition: ConditionalInfo
A dataclass instance containing the model, contrast, and conditional
covariates to be used in the comparisons.
contrast: ContrastInfo
A dataclass instance containing the model, and contrast name and values.
user_passed: bool, optional
Whether the user passed their own 'conditional' data. Defaults to False.

Returns
-------
pd.DataFrame
The data for the Conditional Adjusted Comparisons dataframe and or
plotting.
"""

def _grid_level(condition: ConditionalInfo, contrast: ContrastInfo):
"""
Creates the data for grid-level contrasts by using the covariates passed
into the `conditional` arg. Values for the grid are either: (1) computed
using a equally spaced grid, mean, and or mode (depending on the covariate
dtype), and (2) a user specified value or range of values.
"""
covariates = get_covariates(condition.covariates)

if user_passed:
data_dict = {**condition.conditional}
else:
main_values = make_main_values(condition.model.data[covariates.main])
data_dict = {covariates.main: main_values}
data_dict = make_group_panel_values(
condition.model.data,
data_dict,
covariates.main,
covariates.group,
covariates.panel,
kind="comparison",
)

data_dict[contrast.name] = contrast.values
comparison_data = set_default_values(condition.model, data_dict, kind="comparison")
# use cartesian product (cross join) to create contrasts
keys, values = zip(*comparison_data.items())
contrast_dict = [dict(zip(keys, v)) for v in itertools.product(*values)]

return enforce_dtypes(condition.model.data, pd.DataFrame(contrast_dict))

def _unit_level(contrast: ContrastInfo):
"""
Creates the data for unit-level contrasts by using the observed (empirical)
data. All covariates in the model are included in the data, except for the
contrast predictor. The contrast predictor is replaced with either: (1) the
default contrast value, or (2) the user specified contrast value.
"""
covariates = get_model_covariates(contrast.model)
df = contrast.model.data[covariates].drop(labels=contrast.name, axis=1)

contrast_vals = np.array(contrast.values)[..., None]
contrast_vals = np.repeat(contrast_vals, contrast.model.data.shape[0], axis=1)

contrast_df_dict = {}
for idx, value in enumerate(contrast_vals):
contrast_df_dict[f"contrast_{idx}"] = df.copy()
contrast_df_dict[f"contrast_{idx}"][contrast.name] = value

return pd.concat(contrast_df_dict.values())

if not condition.conditional:
df = _unit_level(contrast)
else:
df = _grid_level(condition, contrast)

return df
return enforce_dtypes(data, pd.DataFrame(data_dict))
Loading
Loading