add tests for non-plotting functions, module docstring, and update function calls
GStechschulte committed Aug 21, 2023
1 parent 0bf80e0 commit 5f4c7cf
Showing 2 changed files with 87 additions and 49 deletions.
34 changes: 34 additions & 0 deletions tests/test_interpret.py
@@ -0,0 +1,34 @@
"""
This module contains tests for the non-plotting functions of the 'interpret'
sub-package. In some cases, 'comparisons()', 'predictions()' and 'slopes()'
contain arguments not in their respective plotting functions. Such arguments
are tested here.
"""
import pandas as pd
import pytest

import bambi as bmb


@pytest.fixture(scope="module")
def mtcars():
data = bmb.load_data('mtcars')
data["am"] = pd.Categorical(data["am"], categories=[0, 1], ordered=True)
model = bmb.Model("mpg ~ hp * drat * am", data)
idata = model.fit(tune=500, draws=500, random_seed=1234)
return model, idata


@pytest.mark.parametrize("return_posterior", [True, False])
def test_return_posterior(mtcars, return_posterior):
model, idata = mtcars

bmb.interpret.predictions(
model, idata, ["hp", "wt"], return_posterior=return_posterior
)
bmb.interpret.comparisons(
model, idata, "hp", "wt", return_posterior=return_posterior
)
bmb.interpret.slopes(
model, idata, "hp", "wt", return_posterior=return_posterior
)
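For context, a minimal sketch of calling these non-plotting functions directly, assuming the bmb.interpret API at this commit; the tabular return type and the exact effect of return_posterior are assumptions, not confirmed by the diff:

import bambi as bmb

# Same model as the 'mtcars' fixture above.
model = bmb.Model("mpg ~ hp * drat * am", bmb.load_data("mtcars"))
idata = model.fit(tune=500, draws=500, random_seed=1234)

# Assumed: returns a tabular summary of the estimated comparisons.
summary = bmb.interpret.comparisons(model, idata, "hp", "wt")

# Assumed: return_posterior=True additionally exposes the posterior draws
# behind the summary, which is what test_return_posterior exercises.
result = bmb.interpret.comparisons(model, idata, "hp", "wt", return_posterior=True)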
102 changes: 53 additions & 49 deletions tests/test_plots.py → tests/test_interpret_plots.py
@@ -1,12 +1,16 @@
-from os.path import dirname, join
-
+"""
+This module contains tests for the plotting functions of the 'interpret'
+sub-package. In most cases, testing the plotting function implicitly tests
+the underlying functions that are called by the plotting functions (e.g.,
+'bmb.interpret.plot_slopes()' calls 'slopes()') as they share mostly the same
+arguments.
+"""
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
 import pytest

 import bambi as bmb
-from bambi.interpret import plot_comparisons, plot_predictions, plot_slopes


 @pytest.fixture(scope="module")
@@ -31,91 +35,91 @@ class TestCommon:
     @pytest.mark.parametrize("pps", [False, True])
     def test_use_hdi(self, mtcars, pps):
         model, idata = mtcars
-        plot_comparisons(model, idata, "hp", "am", use_hdi=False)
-        plot_predictions(
+        bmb.interpret.plot_comparisons(model, idata, "hp", "am", use_hdi=False)
+        bmb.interpret.plot_predictions(
             model,
             idata,
             ["hp", "cyl", "gear"],
             pps=pps,
             use_hdi=False
         )
-        plot_slopes(model, idata, "hp", "am", use_hdi=False)
+        bmb.interpret.plot_slopes(model, idata, "hp", "am", use_hdi=False)


     @pytest.mark.parametrize("pps", [False, True])
     def test_hdi_prob(self, mtcars, pps):
         model, idata = mtcars
-        plot_comparisons(model, idata, "am", "hp", prob=0.8)
-        plot_predictions(
+        bmb.interpret.plot_comparisons(model, idata, "am", "hp", prob=0.8)
+        bmb.interpret.plot_predictions(
             model,
             idata,
             ["hp", "cyl", "gear"],
             pps=pps,
             prob=0.8
         )
-        plot_slopes(model, idata, "hp", "am", prob=0.8)
+        bmb.interpret.plot_slopes(model, idata, "hp", "am", prob=0.8)

         with pytest.raises(
             ValueError, match="'prob' must be greater than 0 and smaller than 1. It is 1.1."
         ):
-            plot_comparisons(model, idata, "am", "hp", prob=1.1)
-            plot_predictions(
+            bmb.interpret.plot_comparisons(model, idata, "am", "hp", prob=1.1)
+            bmb.interpret.plot_predictions(
                 model,
                 idata,
                 ["hp", "cyl", "gear"],
                 pps=pps,
                 prob=1.1)
-            plot_slopes(model, idata, "hp", "am", prob=1.1)
+            bmb.interpret.plot_slopes(model, idata, "hp", "am", prob=1.1)

         with pytest.raises(
             ValueError, match="'prob' must be greater than 0 and smaller than 1. It is -0.1."
         ):
-            plot_comparisons(model, idata, "am", "hp", prob=-0.1)
-            plot_predictions(
+            bmb.interpret.plot_comparisons(model, idata, "am", "hp", prob=-0.1)
+            bmb.interpret.plot_predictions(
                 model,
                 idata,
                 ["hp", "cyl", "gear"],
                 pps=pps,
                 prob=-0.1)
-            plot_slopes(model, idata, "hp", "am", prob=0.1)
+            bmb.interpret.plot_slopes(model, idata, "hp", "am", prob=0.1)


     @pytest.mark.parametrize("pps", [False, True])
     def test_legend(self, mtcars, pps):
         model, idata = mtcars
-        plot_comparisons(model, idata, "am", "hp", legend=False)
-        plot_predictions(model, idata, ["hp"], pps=pps,legend=False)
-        plot_slopes(model, idata, "hp", "am", legend=False)
+        bmb.interpret.plot_comparisons(model, idata, "am", "hp", legend=False)
+        bmb.interpret.plot_predictions(model, idata, ["hp"], pps=pps, legend=False)
+        bmb.interpret.plot_slopes(model, idata, "hp", "am", legend=False)


     @pytest.mark.parametrize("pps", [False, True])
     def test_ax(self, mtcars, pps):
         model, idata = mtcars
         fig, ax = plt.subplots()
-        fig_r, ax_r = plot_comparisons(model, idata, "am", "hp", ax=ax)
+        fig_r, ax_r = bmb.interpret.plot_comparisons(model, idata, "am", "hp", ax=ax)

         assert isinstance(ax_r, np.ndarray)
         assert fig is fig_r
         assert ax is ax_r[0]

         fig, ax = plt.subplots()
-        fig_r, ax_r = plot_predictions(model, idata, ["hp"], pps=pps, ax=ax)
+        fig_r, ax_r = bmb.interpret.plot_predictions(model, idata, ["hp"], pps=pps, ax=ax)

         assert isinstance(ax_r, np.ndarray)
         assert fig is fig_r
         assert ax is ax_r[0]

         fig, ax = plt.subplots()
-        fig_r, ax_r = plot_slopes(model, idata, "hp", "am", ax=ax)
+        fig_r, ax_r = bmb.interpret.plot_slopes(model, idata, "hp", "am", ax=ax)

         assert isinstance(ax_r, np.ndarray)
         assert fig is fig_r
         assert ax is ax_r[0]


-class TestCap:
+class TestPredictions:
     """
-    Tests the 'plot_predictions' function for different combinations of main, group,
+    Tests the 'bmb.interpret.plot_predictions' function for different combinations of main, group,
     and panel variables.
     """
     @pytest.mark.parametrize("pps", [False, True])
@@ -129,7 +133,7 @@ class TestCap:
     )
     def test_basic(self, mtcars, covariates, pps):
         model, idata = mtcars
-        plot_predictions(model, idata, covariates, pps=pps)
+        bmb.interpret.plot_predictions(model, idata, covariates, pps=pps)


     @pytest.mark.parametrize("pps", [False, True])
@@ -143,7 +147,7 @@ def test_basic(self, mtcars, covariates, pps):
     )
     def test_with_groups(self, mtcars, covariates, pps):
         model, idata = mtcars
-        plot_predictions(model, idata, covariates, pps=pps)
+        bmb.interpret.plot_predictions(model, idata, covariates, pps=pps)


     @pytest.mark.parametrize("pps", [False, True])
@@ -156,13 +160,13 @@ def test_with_groups(self, mtcars, covariates, pps):
     )
     def test_with_group_and_panel(self, mtcars, covariates, pps):
         model, idata = mtcars
-        plot_predictions(model, idata, covariates, pps=pps)
+        bmb.interpret.plot_predictions(model, idata, covariates, pps=pps)


     @pytest.mark.parametrize("pps", [False, True])
     def test_fig_kwargs(self, mtcars, pps):
         model, idata = mtcars
-        plot_predictions(
+        bmb.interpret.plot_predictions(
             model,
             idata,
             ["hp", "cyl", "gear"],
@@ -174,7 +178,7 @@ def test_fig_kwargs(self, mtcars, pps):
     @pytest.mark.parametrize("pps", [False, True])
     def test_subplot_kwargs(self, mtcars, pps):
         model, idata = mtcars
-        plot_predictions(
+        bmb.interpret.plot_predictions(
             model,
             idata,
             ["hp", "drat"],
@@ -193,7 +197,7 @@ def test_subplot_kwargs(self, mtcars, pps):
     )
     def test_transforms(self, mtcars, transforms, pps):
         model, idata = mtcars
-        plot_predictions(model, idata, ["hp"], pps=pps, transforms=transforms)
+        bmb.interpret.plot_predictions(model, idata, ["hp"], pps=pps, transforms=transforms)


     @pytest.mark.parametrize("pps", [False, True])
@@ -213,22 +217,22 @@ def test_multiple_outputs_with_alias(self, pps):
         # Without alias
         idata = model.fit(tune=100, draws=100, random_seed=1234)
         # Test default target
-        plot_predictions(model, idata, "x", pps=pps)
+        bmb.interpret.plot_predictions(model, idata, "x", pps=pps)
         # Test user supplied target argument
-        plot_predictions(model, idata, "x", "alpha", pps=False)
+        bmb.interpret.plot_predictions(model, idata, "x", "alpha", pps=False)

         # With alias
         alias = {"alpha": {"Intercept": "sd_intercept", "x": "sd_x", "alpha": "sd_alpha"}}
         model.set_alias(alias)
         idata = model.fit(tune=100, draws=100, random_seed=1234)

         # Test user supplied target argument
-        plot_predictions(model, idata, "x", "alpha", pps=False)
+        bmb.interpret.plot_predictions(model, idata, "x", "alpha", pps=False)


 class TestComparison:
     """
-    Tests the plot_comparisons function for different combinations of
+    Tests the bmb.interpret.plot_comparisons function for different combinations of
     contrast and conditional variables, and user inputs.
     """
     @pytest.mark.parametrize(
@@ -239,7 +243,7 @@ class TestComparison:
     )
     def test_basic(self, mtcars, contrast, conditional):
         model, idata = mtcars
-        plot_comparisons(model, idata, contrast, conditional)
+        bmb.interpret.plot_comparisons(model, idata, contrast, conditional)


     @pytest.mark.parametrize(
@@ -250,7 +254,7 @@ def test_basic(self, mtcars, contrast, conditional):
     )
     def test_with_groups(self, mtcars, contrast, conditional):
         model, idata = mtcars
-        plot_comparisons(model, idata, contrast, conditional)
+        bmb.interpret.plot_comparisons(model, idata, contrast, conditional)


     @pytest.mark.parametrize(
@@ -261,7 +265,7 @@ def test_with_groups(self, mtcars, contrast, conditional):
     )
     def test_with_user_values(self, mtcars, contrast, conditional):
         model, idata = mtcars
-        plot_comparisons(model, idata, contrast, conditional)
+        bmb.interpret.plot_comparisons(model, idata, contrast, conditional)


     @pytest.mark.parametrize(
@@ -271,7 +275,7 @@ def test_with_user_values(self, mtcars, contrast, conditional):
     )
     def test_subplot_kwargs(self, mtcars, contrast, conditional, subplot_kwargs):
         model, idata = mtcars
-        plot_comparisons(model, idata, contrast, conditional, subplot_kwargs=subplot_kwargs)
+        bmb.interpret.plot_comparisons(model, idata, contrast, conditional, subplot_kwargs=subplot_kwargs)


     @pytest.mark.parametrize(
@@ -282,23 +286,23 @@ def test_subplot_kwargs(self, mtcars, contrast, conditional, subplot_kwargs):
     )
     def test_transforms(self, mtcars, contrast, conditional, transforms):
         model, idata = mtcars
-        plot_comparisons(model, idata, contrast, conditional, transforms=transforms)
+        bmb.interpret.plot_comparisons(model, idata, contrast, conditional, transforms=transforms)


     @pytest.mark.parametrize("average_by", ["am", "drat", ["am", "drat"]])
     def test_average_by(self, mtcars, average_by):
         model, idata = mtcars

         # grid of values with average_by
-        plot_comparisons(model, idata, "hp", ["am", "drat"], average_by)
+        bmb.interpret.plot_comparisons(model, idata, "hp", ["am", "drat"], average_by)

         # unit level with average by
-        plot_comparisons(model, idata, "hp", None, average_by)
+        bmb.interpret.plot_comparisons(model, idata, "hp", None, average_by)


 class TestSlopes:
     """
-    Tests the 'plot_slopes' function for different combinations, elasticity,
+    Tests the 'bmb.interpret.plot_slopes' function for different combinations, elasticity,
     and effect types (unit and average slopes) of 'wrt' and 'conditional'
     variables.
     """
@@ -311,7 +315,7 @@ class TestSlopes:
     )
     def test_basic(self, mtcars, wrt, conditional):
         model, idata = mtcars
-        plot_slopes(model, idata, wrt, conditional)
+        bmb.interpret.plot_slopes(model, idata, wrt, conditional)


     @pytest.mark.parametrize(
@@ -322,7 +326,7 @@ def test_basic(self, mtcars, wrt, conditional):
     )
     def test_with_groups(self, mtcars, wrt, conditional):
         model, idata = mtcars
-        plot_slopes(model, idata, wrt, conditional)
+        bmb.interpret.plot_slopes(model, idata, wrt, conditional)


     @pytest.mark.parametrize(
@@ -336,13 +340,13 @@ def test_with_groups(self, mtcars, wrt, conditional):
     def test_with_user_values(self, mtcars, wrt, conditional, average_by):
         model, idata = mtcars
         # need to average by if greater than 1 value is passed with 'wrt'
-        plot_slopes(model, idata, wrt, conditional, average_by=average_by)
+        bmb.interpret.plot_slopes(model, idata, wrt, conditional, average_by=average_by)


     @pytest.mark.parametrize("slope", ["dydx", "dyex", "eyex", "eydx"])
     def test_elasticity(self, mtcars, slope):
         model, idata = mtcars
-        plot_slopes(model, idata, "hp", "drat", slope=slope)
+        bmb.interpret.plot_slopes(model, idata, "hp", "drat", slope=slope)


     @pytest.mark.parametrize(
@@ -352,7 +356,7 @@ def test_elasticity(self, mtcars, slope):
     )
     def test_subplot_kwargs(self, mtcars, wrt, conditional, subplot_kwargs):
         model, idata = mtcars
-        plot_slopes(model, idata, wrt, conditional, subplot_kwargs=subplot_kwargs)
+        bmb.interpret.plot_slopes(model, idata, wrt, conditional, subplot_kwargs=subplot_kwargs)


     @pytest.mark.parametrize(
@@ -363,15 +367,15 @@ def test_subplot_kwargs(self, mtcars, wrt, conditional, subplot_kwargs):
     )
     def test_transforms(self, mtcars, wrt, conditional, transforms):
         model, idata = mtcars
-        plot_slopes(model, idata, wrt, conditional, transforms=transforms)
+        bmb.interpret.plot_slopes(model, idata, wrt, conditional, transforms=transforms)


     @pytest.mark.parametrize("average_by", ["am", "drat", ["am", "drat"]])
     def test_average_by(self, mtcars, average_by):
         model, idata = mtcars

         # grid of values with average_by
-        plot_slopes(model, idata, "hp", ["am", "drat"], average_by)
+        bmb.interpret.plot_slopes(model, idata, "hp", ["am", "drat"], average_by)

         # unit level with average by
-        plot_slopes(model, idata, "hp", None, average_by)
+        bmb.interpret.plot_slopes(model, idata, "hp", None, average_by)
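As a usage note, 'test_ax' above also documents the return convention these plotting functions are expected to follow: a matplotlib figure plus a numpy array of axes. A minimal sketch of reusing an existing figure through that convention, assuming the API at this commit (the output filename is hypothetical):

import matplotlib.pyplot as plt
import pandas as pd
import bambi as bmb

# Rebuild the model from the mtcars fixture above.
data = bmb.load_data("mtcars")
data["am"] = pd.Categorical(data["am"], categories=[0, 1], ordered=True)
model = bmb.Model("mpg ~ hp * drat * am", data)
idata = model.fit(tune=500, draws=500, random_seed=1234)

fig, ax = plt.subplots()

# Per test_ax: the returned figure is the one that owns the supplied axes,
# and the axes come back as a numpy array (hence ax_r[0]).
fig_r, ax_r = bmb.interpret.plot_slopes(model, idata, "hp", "am", ax=ax)
ax_r[0].set_xlabel("hp")
fig_r.savefig("slopes_hp_by_am.png")  # hypothetical output path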
