Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename samples argument to draws in sample_prior_predictive #7366

Merged
merged 4 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/learn/core_notebooks/posterior_predictive.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@
" sigma = pm.Exponential(\"sigma\", 1.0)\n",
"\n",
" pm.Normal(\"obs\", mu=mu, sigma=sigma, observed=outcome_scaled)\n",
" idata = pm.sample_prior_predictive(samples=50, random_seed=rng)"
" idata = pm.sample_prior_predictive(draws=50, random_seed=rng)"
]
},
{
Expand Down Expand Up @@ -225,7 +225,7 @@
" sigma = pm.Exponential(\"sigma\", 1.0)\n",
"\n",
" pm.Normal(\"obs\", mu=mu, sigma=sigma, observed=outcome_scaled)\n",
" idata = pm.sample_prior_predictive(samples=50, random_seed=rng)"
" idata = pm.sample_prior_predictive(draws=50, random_seed=rng)"
]
},
{
Expand Down
18 changes: 15 additions & 3 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,19 +338,20 @@ def observed_dependent_deterministics(model: Model):


def sample_prior_predictive(
samples: int = 500,
draws: int = 500,
model: Model | None = None,
var_names: Iterable[str] | None = None,
random_seed: RandomState = None,
return_inferencedata: bool = True,
idata_kwargs: dict | None = None,
compile_kwargs: dict | None = None,
samples: int | None = None,
) -> InferenceData | dict[str, np.ndarray]:
"""Generate samples from the prior predictive distribution.

Parameters
----------
samples : int
draws : int
Number of samples from the prior predictive to generate. Defaults to 500.
model : Model (optional if in ``with`` context)
var_names : Iterable[str]
Expand All @@ -366,13 +367,24 @@ def sample_prior_predictive(
Keyword arguments for :func:`pymc.to_inference_data`
compile_kwargs: dict, optional
Keyword arguments for :func:`pymc.pytensorf.compile_pymc`.
samples : int
Number of samples from the prior predictive to generate. Deprecated in favor of `draws`.

Returns
-------
arviz.InferenceData or Dict
An ArviZ ``InferenceData`` object containing the prior and prior predictive samples (default),
or a dictionary with variable names as keys and samples as numpy arrays.
"""
if samples is not None:
warnings.warn(
f"The samples argument has been deprecated in favor of draws. Use draws={samples} going forward.",
DeprecationWarning,
stacklevel=1,
)

draws = samples

model = modelcontext(model)

if model.potentials:
Expand Down Expand Up @@ -415,7 +427,7 @@ def sample_prior_predictive(

# All model variables have a name, but mypy does not know this
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
values = zip(*(sampler_fn() for i in range(samples)))
values = zip(*(sampler_fn() for i in range(draws)))

data = {k: np.stack(v) for k, v in zip(names, values)}
if data is None:
Expand Down
10 changes: 5 additions & 5 deletions tests/distributions/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def test_single_poisson_predictive_sampling_shape(self):

n_samples = 30
with model:
prior = sample_prior_predictive(samples=n_samples, return_inferencedata=False)
prior = sample_prior_predictive(draws=n_samples, return_inferencedata=False)
ppc = sample_posterior_predictive(
n_samples * [self.get_initial_point(model)], return_inferencedata=False
)
Expand Down Expand Up @@ -607,7 +607,7 @@ def test_list_mvnormals_predictive_sampling_shape(self):

n_samples = 20
with model:
prior = sample_prior_predictive(samples=n_samples, return_inferencedata=False)
prior = sample_prior_predictive(draws=n_samples, return_inferencedata=False)
ppc = sample_posterior_predictive(
n_samples * [self.get_initial_point(model)], return_inferencedata=False
)
Expand Down Expand Up @@ -1028,7 +1028,7 @@ def test_with_multinomial(self, seeded_test, batch_shape):
comp_dists=comp_dists,
shape=(*batch_shape, 3),
)
prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False)
prior = sample_prior_predictive(draws=self.n_samples, return_inferencedata=False)

assert prior["mixture"].shape == (self.n_samples, *batch_shape, 3)
assert draw(mixture, draws=self.size).shape == (self.size, *batch_shape, 3)
Expand Down Expand Up @@ -1060,7 +1060,7 @@ def test_with_mvnormal(self, seeded_test):
with Model() as model:
comp_dists = MvNormal.dist(mu=mu, chol=chol, shape=(self.mixture_comps, 3))
mixture = Mixture("mixture", w=w, comp_dists=comp_dists, shape=(3,))
prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False)
prior = sample_prior_predictive(draws=self.n_samples, return_inferencedata=False)

assert prior["mixture"].shape == (self.n_samples, 3)
assert draw(mixture, draws=self.size).shape == (self.size, 3)
Expand All @@ -1084,7 +1084,7 @@ def test_broadcasting_in_shape(self):
mu = Gamma("mu", 1.0, 1.0, shape=2)
comp_dists = Poisson.dist(mu, shape=2)
mix = Mixture("mix", w=np.ones(2) / 2, comp_dists=comp_dists, shape=(1000,))
prior = sample_prior_predictive(samples=self.n_samples, return_inferencedata=False)
prior = sample_prior_predictive(draws=self.n_samples, return_inferencedata=False)

assert prior["mix"].shape == (self.n_samples, 1000)

Expand Down
6 changes: 3 additions & 3 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,7 @@ def test_with_chol_rv(self):
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
)
mv = pm.MvNormal("mv", mu, chol=chol, size=4)
prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False)
prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False)

assert prior["mv"].shape == (10, 4, 3)

Expand All @@ -1462,7 +1462,7 @@ def test_with_cov_rv(
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
)
mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), size=4)
prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False)
prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False)

assert prior["mv"].shape == (10, 4, 3)

Expand All @@ -1473,7 +1473,7 @@ def test_with_lkjcorr_matrix(
corr = pm.LKJCorr("corr", n=3, eta=2, return_matrix=True)
pm.Deterministic("corr_mat", corr)
mv = pm.MvNormal("mv", 0.0, cov=corr, size=4)
prior = pm.sample_prior_predictive(samples=10, return_inferencedata=False)
prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False)

assert prior["corr_mat"].shape == (10, 3, 3) # square
assert (prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]] == 1.0).all() # 1.0 on diagonal
Expand Down
8 changes: 4 additions & 4 deletions tests/gp/test_hsgp_approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_prior(self, model, cov_func, X1, parametrization, rng):
gp = pm.gp.Latent(cov_func=cov_func)
f2 = gp.prior("f2", X=X1)

idata = pm.sample_prior_predictive(samples=1000, random_seed=rng)
idata = pm.sample_prior_predictive(draws=1000, random_seed=rng)

samples1 = az.extract(idata.prior["f1"])["f1"].values.T
samples2 = az.extract(idata.prior["f2"])["f2"].values.T
Expand All @@ -240,7 +240,7 @@ def test_conditional(self, model, cov_func, X1, parametrization):
f = hsgp.prior("f", X=X1)
fc = hsgp.conditional("fc", Xnew=X1)

idata = pm.sample_prior_predictive(samples=1000)
idata = pm.sample_prior_predictive(draws=1000)

samples1 = az.extract(idata.prior["f"])["f"].values.T
samples2 = az.extract(idata.prior["fc"])["fc"].values.T
Expand Down Expand Up @@ -300,7 +300,7 @@ def test_prior(self, model, cov_func, eta, X1, rng):
gp = pm.gp.Latent(cov_func=eta**2 * cov_func)
f2 = gp.prior("f2", X=X1)

idata = pm.sample_prior_predictive(samples=1000, random_seed=rng)
idata = pm.sample_prior_predictive(draws=1000, random_seed=rng)

samples1 = az.extract(idata.prior["f1"])["f1"].values.T
samples2 = az.extract(idata.prior["f2"])["f2"].values.T
Expand All @@ -321,7 +321,7 @@ def test_conditional_periodic(self, model, cov_func, X1):
f = hsgp.prior("f", X=X1)
fc = hsgp.conditional("fc", Xnew=X1)

idata = pm.sample_prior_predictive(samples=1000)
idata = pm.sample_prior_predictive(draws=1000)

samples1 = az.extract(idata.prior["f"])["f"].values.T
samples2 = az.extract(idata.prior["fc"])["fc"].values.T
Expand Down
2 changes: 1 addition & 1 deletion tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def test_none_coords_autonumbering(self):
m.add_coord(name="a", values=None, length=3)
m.add_coord(name="b", values=range(5))
x = pm.Normal("x", dims=("a", "b"))
prior = pm.sample_prior_predictive(samples=2).prior
prior = pm.sample_prior_predictive(draws=2).prior
assert prior["x"].shape == (1, 2, 3, 5)
assert list(prior.coords["a"].values) == list(range(3))
assert list(prior.coords["b"].values) == list(range(5))
Expand Down
2 changes: 1 addition & 1 deletion tests/sampling/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_compute_deterministics():
sigma = Deterministic("sigma", sigma_raw.exp())

dataset = sample_prior_predictive(
samples=5, model=m, var_names=["mu_raw", "sigma_raw"], random_seed=22
draws=5, model=m, var_names=["mu_raw", "sigma_raw"], random_seed=22
).prior

# Test default
Expand Down
35 changes: 22 additions & 13 deletions tests/sampling/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,12 +794,12 @@ def test_logging_sampled_basic_rvs_prior(self, caplog):
z = pm.Normal("z", y, observed=0)

with m:
pm.sample_prior_predictive(samples=1)
pm.sample_prior_predictive(draws=1)
assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x, z]")]
caplog.clear()

with m:
pm.sample_prior_predictive(samples=1, var_names=["x"])
pm.sample_prior_predictive(draws=1, var_names=["x"])
assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x]")]
caplog.clear()

Expand Down Expand Up @@ -1028,7 +1028,7 @@ def test_observed_data_needed_in_pp(self):
mu = x_data.sum(-1)
pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))

prior = pm.sample_prior_predictive(samples=25).prior
prior = pm.sample_prior_predictive(draws=25).prior

fake_idata = InferenceData(posterior=prior)

Expand All @@ -1052,7 +1052,7 @@ def test_observed_data_needed_in_pp(self):
mu = (y_data.sum() * x_data).sum(-1)
pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))

prior = pm.sample_prior_predictive(samples=25).prior
prior = pm.sample_prior_predictive(draws=25).prior

fake_idata = InferenceData(posterior=prior)

Expand Down Expand Up @@ -1135,7 +1135,7 @@ def test_multivariate2(self, seeded_test):
compute_convergence_checks=False,
)
sim_priors = pm.sample_prior_predictive(
return_inferencedata=False, samples=20, model=dm_model
return_inferencedata=False, draws=20, model=dm_model
)
sim_ppc = pm.sample_posterior_predictive(
burned_trace, return_inferencedata=False, model=dm_model
Expand Down Expand Up @@ -1227,7 +1227,7 @@ def test_zeroinflatedpoisson(self):
mu = pm.Beta("mu", alpha=1, beta=1)
psi = pm.HalfNormal("psi", sigma=1)
pm.ZeroInflatedPoisson("suppliers", psi=psi, mu=mu, size=20)
gen_data = pm.sample_prior_predictive(samples=5000)
gen_data = pm.sample_prior_predictive(draws=5000)
assert gen_data.prior["mu"].shape == (1, 5000)
assert gen_data.prior["psi"].shape == (1, 5000)
assert gen_data.prior["suppliers"].shape == (1, 5000, 20)
Expand All @@ -1240,7 +1240,7 @@ def test_potentials_warning(self):

with m:
with pytest.warns(UserWarning, match=warning_msg):
pm.sample_prior_predictive(samples=5)
pm.sample_prior_predictive(draws=5)

def test_transformed_vars_not_supported(self):
with pm.Model() as model:
Expand All @@ -1260,7 +1260,7 @@ def test_issue_4490(self):
c = pm.Normal("c")
d = pm.Normal("d")
prior1 = pm.sample_prior_predictive(
samples=1, var_names=["a", "b", "c", "d"], random_seed=seed
draws=1, var_names=["a", "b", "c", "d"], random_seed=seed
)

with pm.Model() as m2:
Expand All @@ -1269,7 +1269,7 @@ def test_issue_4490(self):
c = pm.Normal("c")
d = pm.Normal("d")
prior2 = pm.sample_prior_predictive(
samples=1, var_names=["b", "a", "d", "c"], random_seed=seed
draws=1, var_names=["b", "a", "d", "c"], random_seed=seed
)

assert prior1.prior["a"] == prior2.prior["a"]
Expand All @@ -1284,7 +1284,7 @@ def test_pytensor_function_kwargs(self):
y = pm.Deterministic("y", x + sharedvar)

prior = pm.sample_prior_predictive(
samples=5,
draws=5,
return_inferencedata=False,
compile_kwargs=dict(
mode=Mode("py"),
Expand All @@ -1308,7 +1308,7 @@ def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture):

with pmodel:
prior = pm.sample_prior_predictive(
samples=20,
draws=20,
return_inferencedata=False,
)
idat = pm.to_inference_data(trace, prior=prior)
Expand Down Expand Up @@ -1367,7 +1367,7 @@ def test_distinct_rvs():
Y_rv = pm.Normal("y")

pp_samples = pm.sample_prior_predictive(
samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
)

assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0]
Expand All @@ -1377,7 +1377,7 @@ def test_distinct_rvs():
Y_rv = pm.Normal("y")

pp_samples_2 = pm.sample_prior_predictive(
samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
)

assert np.array_equal(pp_samples["y"], pp_samples_2["y"])
Expand Down Expand Up @@ -1706,3 +1706,12 @@ def test_observed_dependent_deterministics():
det_mixed = pm.Deterministic("det_mixed", free + obs)

assert set(observed_dependent_deterministics(m)) == {det_obs, det_obs2, det_mixed}


def test_sample_prior_predictive_samples_deprecated_warns() -> None:
    """Check the deprecated ``samples`` keyword of ``sample_prior_predictive``.

    Passing ``samples`` must emit a ``DeprecationWarning`` and the value must
    still be used as the number of draws, so existing callers keep getting
    the amount of samples they asked for.
    """
    with pm.Model() as m:
        pm.Normal("a")

    match = "The samples argument has been deprecated"
    with pytest.warns(DeprecationWarning, match=match):
        idata = pm.sample_prior_predictive(model=m, samples=10)

    # Warning alone is not enough: the deprecated value has to be forwarded
    # to ``draws``. The prior group is laid out as (chain, draw) for a scalar.
    assert idata.prior["a"].shape == (1, 10)
Loading