Skip to content

Commit

Permalink
Add test to compare draws when var_names is used in pm.sample() (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto authored Apr 28, 2024
1 parent a74c03f commit 29fd732
Showing 1 changed file with 33 additions and 6 deletions.
39 changes: 33 additions & 6 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,12 +695,39 @@ def test_no_init_nuts_compound(caplog):


def test_sample_var_names():
with pm.Model() as model:
a = pm.Normal("a")
b = pm.Deterministic("b", a**2)
idata = pm.sample(10, tune=10, var_names=["a"])
assert "a" in idata.posterior
assert "b" not in idata.posterior
# Generate data
seed = 1234
rng = np.random.default_rng(seed)

group = rng.choice(list("ABCD"), size=100)
x = rng.normal(size=100)
y = rng.normal(size=100)

group_values, group_idx = np.unique(group, return_inverse=True)

coords = {"group": group_values}

# Create model
with pm.Model(coords=coords) as model:
b_group = pm.Normal("b_group", dims="group")
b_x = pm.Normal("b_x")
mu = pm.Deterministic("mu", b_group[group_idx] + b_x * x)
sigma = pm.HalfNormal("sigma")
pm.Normal("y", mu=mu, sigma=sigma, observed=y)

# Sample with and without var_names, but always with the same seed
with model:
idata_1 = pm.sample(tune=100, draws=100, random_seed=seed)
idata_2 = pm.sample(
tune=100, draws=100, var_names=["b_group", "b_x", "sigma"], random_seed=seed
)

assert "mu" in idata_1.posterior
assert "mu" not in idata_2.posterior

assert np.all(idata_1.posterior["b_group"] == idata_2.posterior["b_group"]).item()
assert np.all(idata_1.posterior["b_x"] == idata_2.posterior["b_x"]).item()
assert np.all(idata_1.posterior["sigma"] == idata_2.posterior["sigma"]).item()


class TestAssignStepMethods:
Expand Down

0 comments on commit 29fd732

Please sign in to comment.