Skip to content

Commit

Permalink
Simplify and expand test ranges in test_normal_horseshoe_sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 28, 2022
1 parent fe5190e commit b461e81
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions tests/test_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ def test_horseshoe_match(srng):


@pytest.mark.parametrize(
"N, p, nonzero_atol",
"N, p, rtol",
[
(50, 10, np.array([1.0, 0.5, 0.5, 3e-1, 3e-1])),
(50, 55, np.array([1.5, 0.5, 0.5, 0.75, 3e-1])),
(50, 10, 0.5),
(50, 75, 0.5),
],
)
def test_normal_horseshoe_sampler(srng, N, p, nonzero_atol):
def test_normal_horseshoe_sampler(srng, N, p, rtol):
"""Check the results of a normal regression model with a Horseshoe prior.
This test example is modified from section 3.2 of Makalic & Schmidt (2016)
Expand Down Expand Up @@ -131,8 +131,7 @@ def test_normal_horseshoe_sampler(srng, N, p, nonzero_atol):
assert np.all(lambda_post_val >= 0)

beta_post_median = np.median(beta_post_vals[100::2], axis=0)
assert np.allclose(beta_post_median[:5], true_beta[:5], atol=nonzero_atol)
assert np.all(np.abs(beta_post_median[5:]) < 1)
assert np.allclose(beta_post_median[:5], true_beta[:5], atol=1e-1, rtol=rtol)


@pytest.mark.parametrize(
Expand Down

0 comments on commit b461e81

Please sign in to comment.