Commit

Effect handler that conditions a model on sample sites having the same value (#3395)

* Add the keep-distributions (`keep_dist`) option to the EqualizeMessenger effect handler.

* Add tests for the `keep_dist` option of the EqualizeMessenger effect handler.

* exclude ipynb

* Mark a sample site as deterministic only if the EqualizeMessenger effect handler does not keep its original distribution.

---------

Co-authored-by: Ben Zickel <[email protected]>
Co-authored-by: Yerdos Ordabayev <[email protected]>
3 people committed Sep 20, 2024
1 parent 88ae262 commit b3c7851
Showing 3 changed files with 75 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyro/ops/stats.py
@@ -511,7 +511,7 @@ def crps_empirical(pred, truth):


def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
"""
r"""
Computes negative Energy Score ES* (see equation 22 in [1]) between a
set of multivariate samples ``pred`` and a true data vector ``truth``. Running time
is quadratic in the number of samples ``n``. In case of univariate samples
33 changes: 30 additions & 3 deletions pyro/poutine/equalize_messenger.py
@@ -38,18 +38,42 @@ class EqualizeMessenger(Messenger):
    >>> equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')

    Alternatively, the ``equalize`` messenger can be used to condition a model on primitive statements
    having the same value by setting ``keep_dist`` to True. Consider the below model:

    >>> def model():
    ...     x = pyro.sample('x', pyro.distributions.Normal(0, 1))
    ...     y = pyro.sample('y', pyro.distributions.Normal(5, 3))
    ...     return x, y

    The model can be conditioned on 'x' and 'y' having the same value by

    >>> conditioned_model = pyro.poutine.equalize(model, ['x', 'y'], keep_dist=True)

    Note that the conditioned model defined above calculates the correct unnormalized log-probability
    density, but in order to correctly sample from it one must use SVI or MCMC techniques.

    :param fn: a stochastic function (callable containing Pyro primitive calls)
    :param sites: a string or list of strings to match site names (the strings can be regular expressions)
    :param type: a string specifying the site type (default is 'sample')
    :param bool keep_dist: Whether to keep the distributions of the second and subsequent
        matching primitive statements. If kept, this is equivalent to conditioning the model
        on all matching primitive statements having the same value, as opposed to having the
        second and subsequent matching primitive statements replaced by delta sampling functions.
        Defaults to False.
    :returns: stochastic function decorated with a :class:`~pyro.poutine.equalize_messenger.EqualizeMessenger`
    """

    def __init__(
-        self, sites: Union[str, List[str]], type: Optional[str] = "sample"
+        self,
+        sites: Union[str, List[str]],
+        type: Optional[str] = "sample",
+        keep_dist: bool = False,
    ) -> None:
        super().__init__()
        self.sites = [sites] if isinstance(sites, str) else sites
        self.type = type
+        self.keep_dist = keep_dist

    def __enter__(self) -> Self:
        self.value = None
@@ -72,6 +96,9 @@ def _process_message(self, msg: Message) -> None:
        if self.value is not None and self._is_matching(msg):  # type: ignore[unreachable]
            msg["value"] = self.value  # type: ignore[unreachable]
            if msg["type"] == "sample":
-                msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim).mask(False)
-                msg["infer"] = {"_deterministic": True}
                msg["is_observed"] = True
+                if not self.keep_dist:
+                    msg["infer"] = {"_deterministic": True}
+                    msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim).mask(
+                        False
+                    )
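A minimal usage sketch of the new `keep_dist` option (not part of the diff), reusing the model from the docstring above. The trace inspection and the small NUTS run are illustrative assumptions about typical usage rather than code from this commit.

```python
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import MCMC, NUTS


def model():
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    y = pyro.sample("y", dist.Normal(5.0, 3.0))
    return x, y


# keep_dist=False (the default): the second matching site 'y' is replaced by a
# masked Delta at the value of 'x', so it adds nothing to the trace log-density.
equalized = poutine.equalize(model, ["x", "y"])
tr = poutine.trace(equalized).get_trace()
assert tr.nodes["y"]["value"].item() == tr.nodes["x"]["value"].item()

# keep_dist=True: 'y' keeps its Normal(5, 3) distribution but is marked observed
# at the value of 'x', i.e. the model is conditioned on x == y.
conditioned = poutine.equalize(model, ["x", "y"], keep_dist=True)
tr = poutine.trace(conditioned).get_trace()
assert tr.nodes["y"]["is_observed"]

# Per the docstring, drawing correct samples from the conditioned model requires
# SVI or MCMC; a tiny NUTS run for illustration:
mcmc = MCMC(NUTS(conditioned), num_samples=100, warmup_steps=100)
mcmc.run()
posterior_x = mcmc.get_samples()["x"]
```

In both modes the matched sites share the value drawn at 'x'; the difference is whether 'y' still contributes its own log-density term.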
44 changes: 44 additions & 0 deletions tests/poutine/test_poutines.py
@@ -805,6 +805,50 @@ def test_render_model(self):
pyro.render_model(model)


@pytest.mark.parametrize("keep_dist", [False, True])
@pytest.mark.parametrize(
"loc_x, scale_x, loc_y, scale_y", [(0.0, 1.0, 5.0, 2.0), (5.0, 2.0, 0.0, 1.0)]
)
def test_condition_by_equalize(loc_x, scale_x, loc_y, scale_y, keep_dist):
    # Create model and equalize it.
    def model():
        x = pyro.sample("x", dist.Normal(loc_x, scale_x))
        y = pyro.sample("y", dist.Normal(loc_y, scale_y))
        return x, y

    equalized_model = pyro.poutine.equalize(model, ["x", "y"], keep_dist=keep_dist)

    # Fit guide to model
    guide = pyro.infer.autoguide.AutoNormal(equalized_model)
    optim = pyro.optim.Adam(dict(lr=0.1))
    svi = pyro.infer.SVI(
        equalized_model,
        guide,
        optim,
        loss=pyro.infer.TraceGraph_ELBO(num_particles=1000, vectorize_particles=True),
    )
    for step_num in range(100):
        svi.step()

    # Get guide distribution parameters
    loc, scale = guide._get_loc_and_scale("x")
    loc = float(loc.detach().numpy())
    scale = float(scale.detach().numpy())

    # Verify against expected distribution parameters
    if keep_dist:
        # Both 'x' and 'y' are sampled and the model is conditioned on 'x' and 'y' having the same value.
        expected_var = 1 / (1 / scale_x**2 + 1 / scale_y**2)
        expected_loc = (loc_x / scale_x**2 + loc_y / scale_y**2) * expected_var
        expected_scale = expected_var**0.5
    else:
        # The random variable 'x' is sampled and its value is assigned to 'y'.
        expected_loc = loc_x
        expected_scale = scale_x
    assert_close(loc, expected_loc, atol=0.05)
    assert_close(scale, expected_scale, atol=0.05)
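For reference (not part of the commit), the expected values in the `keep_dist=True` branch follow from the product-of-Gaussians identity: conditioning the two independent Normals on sharing a common value $c$ multiplies their densities, which is again Gaussian with precision-weighted parameters. Writing $\mu_x$, $\sigma_x$ for `loc_x`, `scale_x` and likewise for $y$:

$$
p(c) \propto \mathcal{N}(c \mid \mu_x, \sigma_x^2)\,\mathcal{N}(c \mid \mu_y, \sigma_y^2)
= \mathcal{N}(c \mid \mu_*, \sigma_*^2),
\qquad
\sigma_*^2 = \left(\frac{1}{\sigma_x^2} + \frac{1}{\sigma_y^2}\right)^{-1},
\qquad
\mu_* = \sigma_*^2\left(\frac{\mu_x}{\sigma_x^2} + \frac{\mu_y}{\sigma_y^2}\right)
$$

which is exactly `expected_var`, `expected_loc`, and `expected_scale = expected_var**0.5` in the test above.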


@pytest.mark.parametrize("first_available_dim", [-1, -2, -3])
@pytest.mark.parametrize("depth", [0, 1, 2])
def test_enumerate_poutine(depth, first_available_dim):
