Rename samples argument to draws in sample_prior_predictive #7366

Merged
4 commits merged on Jun 17, 2024
Changes from 2 commits
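For orientation, a minimal usage sketch of the renamed keyword (the toy model below is illustrative and not part of the PR; it assumes a PyMC build that includes this change):

```python
import pymc as pm

with pm.Model():
    x = pm.Normal("x")
    # New keyword introduced by this PR
    idata = pm.sample_prior_predictive(draws=100)
    # The old keyword is still accepted, but now emits a DeprecationWarning
    idata_legacy = pm.sample_prior_predictive(samples=100)
```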
18 changes: 15 additions & 3 deletions pymc/sampling/forward.py
@@ -338,19 +338,20 @@ def observed_dependent_deterministics(model: Model):


def sample_prior_predictive(
    samples: int = 500,
    draws: int = 500,
    model: Model | None = None,
    var_names: Iterable[str] | None = None,
    random_seed: RandomState = None,
    return_inferencedata: bool = True,
    idata_kwargs: dict | None = None,
    compile_kwargs: dict | None = None,
    samples: int | None = None,
) -> InferenceData | dict[str, np.ndarray]:
"""Generate samples from the prior predictive distribution.

Parameters
----------
samples : int
draws : int
Number of samples from the prior predictive to generate. Defaults to 500.
model : Model (optional if in ``with`` context)
var_names : Iterable[str]
@@ -366,13 +367,24 @@ def sample_prior_predictive(
        Keyword arguments for :func:`pymc.to_inference_data`
    compile_kwargs: dict, optional
        Keyword arguments for :func:`pymc.pytensorf.compile_pymc`.
    samples : int
        Number of samples from the prior predictive to generate. Deprecated in favor of `draws`.

    Returns
    -------
    arviz.InferenceData or Dict
        An ArviZ ``InferenceData`` object containing the prior and prior predictive samples (default),
        or a dictionary with variable names as keys and samples as numpy arrays.
"""
if samples is not None:
warnings.warn(
f"The samples argument has been deprecated in favor of draws. Use draws={samples} going forward.",
DeprecationWarning,
stacklevel=1,
)

draws = samples

model = modelcontext(model)

if model.potentials:
@@ -415,7 +427,7 @@ def sample_prior_predictive(

    # All model variables have a name, but mypy does not know this
    _log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
    values = zip(*(sampler_fn() for i in range(samples)))
    values = zip(*(sampler_fn() for i in range(draws)))

    data = {k: np.stack(v) for k, v in zip(names, values)}
    if data is None:
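To illustrate the deprecation shim above, a small hedged example (hypothetical script, assuming a PyMC build containing this change) showing that the old keyword is forwarded to `draws` while emitting a warning:

```python
import warnings

import pymc as pm

with pm.Model():
    pm.Normal("x")
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Deprecated keyword: internally forwarded to draws
        idata = pm.sample_prior_predictive(samples=3)

# A DeprecationWarning was raised for the old keyword
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
# The result matches a call with draws=3
assert idata.prior.sizes["draw"] == 3
```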
35 changes: 22 additions & 13 deletions tests/sampling/test_forward.py
@@ -794,12 +794,12 @@ def test_logging_sampled_basic_rvs_prior(self, caplog):
z = pm.Normal("z", y, observed=0)

with m:
pm.sample_prior_predictive(samples=1)
pm.sample_prior_predictive(draws=1)
assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x, z]")]
caplog.clear()

with m:
pm.sample_prior_predictive(samples=1, var_names=["x"])
pm.sample_prior_predictive(draws=1, var_names=["x"])
assert caplog.record_tuples == [("pymc.sampling.forward", logging.INFO, "Sampling: [x]")]
caplog.clear()

@@ -1028,7 +1028,7 @@ def test_observed_data_needed_in_pp(self):
mu = x_data.sum(-1)
pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))

prior = pm.sample_prior_predictive(samples=25).prior
prior = pm.sample_prior_predictive(draws=25).prior

fake_idata = InferenceData(posterior=prior)

@@ -1052,7 +1052,7 @@ def test_observed_data_needed_in_pp(self):
mu = (y_data.sum() * x_data).sum(-1)
pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))

prior = pm.sample_prior_predictive(samples=25).prior
prior = pm.sample_prior_predictive(draws=25).prior

fake_idata = InferenceData(posterior=prior)

@@ -1135,7 +1135,7 @@ def test_multivariate2(self, seeded_test):
compute_convergence_checks=False,
)
sim_priors = pm.sample_prior_predictive(
return_inferencedata=False, samples=20, model=dm_model
return_inferencedata=False, draws=20, model=dm_model
)
sim_ppc = pm.sample_posterior_predictive(
burned_trace, return_inferencedata=False, model=dm_model
@@ -1227,7 +1227,7 @@ def test_zeroinflatedpoisson(self):
mu = pm.Beta("mu", alpha=1, beta=1)
psi = pm.HalfNormal("psi", sigma=1)
pm.ZeroInflatedPoisson("suppliers", psi=psi, mu=mu, size=20)
gen_data = pm.sample_prior_predictive(samples=5000)
gen_data = pm.sample_prior_predictive(draws=5000)
assert gen_data.prior["mu"].shape == (1, 5000)
assert gen_data.prior["psi"].shape == (1, 5000)
assert gen_data.prior["suppliers"].shape == (1, 5000, 20)
@@ -1240,7 +1240,7 @@ def test_potentials_warning(self):

with m:
with pytest.warns(UserWarning, match=warning_msg):
pm.sample_prior_predictive(samples=5)
pm.sample_prior_predictive(draws=5)

def test_transformed_vars_not_supported(self):
with pm.Model() as model:
@@ -1260,7 +1260,7 @@ def test_issue_4490(self):
c = pm.Normal("c")
d = pm.Normal("d")
prior1 = pm.sample_prior_predictive(
samples=1, var_names=["a", "b", "c", "d"], random_seed=seed
draws=1, var_names=["a", "b", "c", "d"], random_seed=seed
)

with pm.Model() as m2:
@@ -1269,7 +1269,7 @@ def test_issue_4490(self):
c = pm.Normal("c")
d = pm.Normal("d")
prior2 = pm.sample_prior_predictive(
samples=1, var_names=["b", "a", "d", "c"], random_seed=seed
draws=1, var_names=["b", "a", "d", "c"], random_seed=seed
)

assert prior1.prior["a"] == prior2.prior["a"]
@@ -1284,7 +1284,7 @@ def test_pytensor_function_kwargs(self):
y = pm.Deterministic("y", x + sharedvar)

prior = pm.sample_prior_predictive(
samples=5,
draws=5,
return_inferencedata=False,
compile_kwargs=dict(
mode=Mode("py"),
@@ -1308,7 +1308,7 @@ def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture):

with pmodel:
prior = pm.sample_prior_predictive(
samples=20,
draws=20,
return_inferencedata=False,
)
idat = pm.to_inference_data(trace, prior=prior)
@@ -1367,7 +1367,7 @@ def test_distinct_rvs():
Y_rv = pm.Normal("y")

pp_samples = pm.sample_prior_predictive(
samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
)

assert X_rv.owner.inputs[0] != Y_rv.owner.inputs[0]
@@ -1377,7 +1377,7 @@ def test_distinct_rvs():
Y_rv = pm.Normal("y")

pp_samples_2 = pm.sample_prior_predictive(
samples=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
draws=2, return_inferencedata=False, random_seed=npr.RandomState(2023532)
)

assert np.array_equal(pp_samples["y"], pp_samples_2["y"])
@@ -1706,3 +1706,12 @@ def test_observed_dependent_deterministics():
det_mixed = pm.Deterministic("det_mixed", free + obs)

assert set(observed_dependent_deterministics(m)) == {det_obs, det_obs2, det_mixed}


def test_sample_prior_predictive_samples_deprecated_warns() -> None:
    with pm.Model() as m:
        pm.Normal("a")

    match = "The samples argument has been deprecated"
    with pytest.warns(DeprecationWarning, match=match):
        pm.sample_prior_predictive(model=m, samples=10)
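As an aside, downstream code that must run against both older and newer PyMC releases can feature-detect the keyword; a hedged sketch (the helper name and logic are illustrative, not part of this PR):

```python
import inspect

import pymc as pm

def prior_predictive_compat(n: int, **kwargs):
    """Call sample_prior_predictive with `draws` when available, else `samples`."""
    params = inspect.signature(pm.sample_prior_predictive).parameters
    key = "draws" if "draws" in params else "samples"
    return pm.sample_prior_predictive(**{key: n}, **kwargs)

with pm.Model():
    pm.Normal("x")
    idata = prior_predictive_compat(50)
```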