From 29fd73225757e2efc5710596162a5c632516f511 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Capretto?= Date: Sun, 28 Apr 2024 12:12:34 -0300 Subject: [PATCH] Add test to compare draws when `var_names` is used in `pm.sample()` (#7287) --- tests/sampling/test_mcmc.py | 39 +++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 31b48250fc..647e8ada70 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -695,12 +695,39 @@ def test_no_init_nuts_compound(caplog): def test_sample_var_names(): - with pm.Model() as model: - a = pm.Normal("a") - b = pm.Deterministic("b", a**2) - idata = pm.sample(10, tune=10, var_names=["a"]) - assert "a" in idata.posterior - assert "b" not in idata.posterior + # Generate data + seed = 1234 + rng = np.random.default_rng(seed) + + group = rng.choice(list("ABCD"), size=100) + x = rng.normal(size=100) + y = rng.normal(size=100) + + group_values, group_idx = np.unique(group, return_inverse=True) + + coords = {"group": group_values} + + # Create model + with pm.Model(coords=coords) as model: + b_group = pm.Normal("b_group", dims="group") + b_x = pm.Normal("b_x") + mu = pm.Deterministic("mu", b_group[group_idx] + b_x * x) + sigma = pm.HalfNormal("sigma") + pm.Normal("y", mu=mu, sigma=sigma, observed=y) + + # Sample with and without var_names, but always with the same seed + with model: + idata_1 = pm.sample(tune=100, draws=100, random_seed=seed) + idata_2 = pm.sample( + tune=100, draws=100, var_names=["b_group", "b_x", "sigma"], random_seed=seed + ) + + assert "mu" in idata_1.posterior + assert "mu" not in idata_2.posterior + + assert np.all(idata_1.posterior["b_group"] == idata_2.posterior["b_group"]).item() + assert np.all(idata_1.posterior["b_x"] == idata_2.posterior["b_x"]).item() + assert np.all(idata_1.posterior["sigma"] == idata_2.posterior["sigma"]).item() class TestAssignStepMethods: