Make Predictive work with the SplitReparam reparameterizer [bugfix] (#…
BenZickel committed Aug 4, 2024
1 parent a756d2f commit 6130da0
Showing 2 changed files with 62 additions and 23 deletions.
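
For orientation, here is a minimal, self-contained sketch of the usage pattern this bugfix targets: drawing vectorized posterior samples with Predictive from a model whose site "x" has been rewritten by SplitReparam. It mirrors the new test added below; the toy model, split sizes, and training length are illustrative and not part of the commit.

```python
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI, Predictive, Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal
from pyro.infer.reparam.split import SplitReparam

pyro.clear_param_store()

def model():
    # One 5-dimensional latent site that SplitReparam cuts into a size-2
    # piece ("x_split_0") and a size-3 piece ("x_split_1").
    pyro.sample("x", dist.Normal(torch.zeros(5), 1.0).to_event(1))

reparam_model = poutine.reparam(model, config={"x": SplitReparam([2, 3], -1)})

# Fit a guide to the reparameterized model (a short run; this is only a sketch).
guide = AutoMultivariateNormal(reparam_model)
svi = SVI(reparam_model, guide, pyro.optim.Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(100):
    svi.step()

# Vectorized posterior sampling over the split sites, i.e. the call that
# previously failed and that this commit makes work.
samples = Predictive(
    reparam_model,
    guide=guide,
    num_samples=1000,
    parallel=True,
    return_sites=["x_split_0", "x_split_1"],
)()
print({name: tuple(value.shape) for name, value in samples.items()})
```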
8 changes: 6 additions & 2 deletions pyro/infer/predictive.py
@@ -10,6 +10,7 @@

import pyro
import pyro.poutine as poutine
from pyro.infer.autoguide.initialization import InitMessenger, init_to_sample
from pyro.infer.importance import LogWeightsMixin
from pyro.infer.util import CloneMixin, plate_log_prob_sum
from pyro.poutine.trace_struct import Trace
@@ -86,12 +87,15 @@ def _predictive(
mask=True,
):
model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
initailized_model = InitMessenger(init_to_sample)(model)
max_plate_nesting = _guess_max_plate_nesting(
initailized_model, model_args, model_kwargs
)
vectorize = pyro.plate(
_predictive_vectorize_plate_name, num_samples, dim=-max_plate_nesting - 1
)
model_trace = prune_subsample_sites(
poutine.trace(model).get_trace(*model_args, **model_kwargs)
poutine.trace(initailized_model).get_trace(*model_args, **model_kwargs)
)
reshaped_samples = {}

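The change above wraps the model in InitMessenger(init_to_sample) before max_plate_nesting is guessed and before the model trace is taken, so every sample site already carries a concrete value during those traces, which sites rewritten by reparameterizers such as SplitReparam evidently require. Below is a minimal sketch of that initialization pattern on a toy model (the model and site name are illustrative, not taken from this commit):

```python
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.autoguide.initialization import InitMessenger, init_to_sample

def model():
    with pyro.plate("data", 3):
        return pyro.sample("x", dist.Normal(0.0, 1.0))

# Wrapping the model with InitMessenger(init_to_sample) fills in each sample
# site's value from the init strategy, so a trace taken for shape or
# plate-nesting inspection sees fully realized tensors.
initialized_model = InitMessenger(init_to_sample)(model)
trace = poutine.trace(initialized_model).get_trace()
print(trace.nodes["x"]["value"].shape)  # torch.Size([3])
```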
77 changes: 56 additions & 21 deletions tests/infer/reparam/test_split.py
@@ -13,8 +13,7 @@

from .util import check_init_reparam


@pytest.mark.parametrize(
event_shape_splits_dim = pytest.mark.parametrize(
"event_shape,splits,dim",
[
((6,), [2, 1, 3], -1),
@@ -31,7 +30,13 @@
],
ids=str,
)
@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)


batch_shape = pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)


@event_shape_splits_dim
@batch_shape
def test_normal(batch_shape, event_shape, splits, dim):
shape = batch_shape + event_shape
loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
@@ -72,24 +77,8 @@ def model():
assert_close(actual_grads, expected_grads)


@pytest.mark.parametrize(
"event_shape,splits,dim",
[
((6,), [2, 1, 3], -1),
(
(
2,
5,
),
[2, 3],
-1,
),
((4, 2), [1, 3], -2),
((2, 3, 1), [1, 2], -2),
],
ids=str,
)
@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)
@event_shape_splits_dim
@batch_shape
def test_init(batch_shape, event_shape, splits, dim):
shape = batch_shape + event_shape
loc = torch.empty(shape).uniform_(-1.0, 1.0)
@@ -100,3 +89,49 @@ def model():
return pyro.sample("x", dist.Normal(loc, scale).to_event(len(event_shape)))

check_init_reparam(model, SplitReparam(splits, dim))


@event_shape_splits_dim
@batch_shape
def test_predictive(batch_shape, event_shape, splits, dim):
shape = batch_shape + event_shape
loc = torch.empty(shape).uniform_(-1.0, 1.0)
scale = torch.empty(shape).uniform_(0.5, 1.5)

def model():
with pyro.plate_stack("plates", batch_shape):
pyro.sample("x", dist.Normal(loc, scale).to_event(len(event_shape)))

# Reparametrize model
rep = SplitReparam(splits, dim)
reparam_model = poutine.reparam(model, {"x": rep})

# Fit guide to reparametrized model
guide = pyro.infer.autoguide.guides.AutoMultivariateNormal(reparam_model)
optimizer = pyro.optim.Adam(dict(lr=0.01))
loss = pyro.infer.JitTrace_ELBO(
num_particles=20, vectorize_particles=True, ignore_jit_warnings=True
)
svi = pyro.infer.SVI(reparam_model, guide, optimizer, loss)
for count in range(1001):
loss = svi.step()
if count % 100 == 0:
print(f"iteration {count} loss = {loss}")

# Sample from model using the guide
num_samples = 100000
parallel = True
sites = ["x_split_{}".format(i) for i in range(len(splits))]
values = pyro.infer.Predictive(
reparam_model,
guide=guide,
num_samples=num_samples,
parallel=parallel,
return_sites=sites,
)()

# Verify sampling
mean = torch.cat([values[site].mean(0) for site in sites], dim=dim)
std = torch.cat([values[site].std(0) for site in sites], dim=dim)
assert_close(mean, loc, atol=0.1)
assert_close(std, scale, rtol=0.1)
