From b3c78513ed27a60d6e5822ca2ae9159dadca046e Mon Sep 17 00:00:00 2001 From: Ben Zickel <35469979+BenZickel@users.noreply.github.com> Date: Fri, 20 Sep 2024 21:18:20 +0300 Subject: [PATCH] Effect handler that conditions a model on sample sites having the same value (#3395) * Add the keep distributions option to the EqualizeMessenger effect handler. * Added tests for the keep distribution option of the EqualizeMessenger effect handler. * exclude ipynb * Make sample site deterministic only if the EqualizeMessenger effect handler does not keep its original distribution. --------- Co-authored-by: Ben Zickel Co-authored-by: Yerdos Ordabayev --- pyro/ops/stats.py | 2 +- pyro/poutine/equalize_messenger.py | 33 ++++++++++++++++++++-- tests/poutine/test_poutines.py | 44 ++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 4 deletions(-) diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py index efa60134e5..a0a546059a 100644 --- a/pyro/ops/stats.py +++ b/pyro/ops/stats.py @@ -511,7 +511,7 @@ def crps_empirical(pred, truth): def energy_score_empirical(pred: torch.Tensor, truth: torch.Tensor) -> torch.Tensor: - """ + r""" Computes negative Energy Score ES* (see equation 22 in [1]) between a set of multivariate samples ``pred`` and a true data vector ``truth``. Running time is quadratic in the number of samples ``n``. In case of univariate samples diff --git a/pyro/poutine/equalize_messenger.py b/pyro/poutine/equalize_messenger.py index e17693267b..1bc79a5521 100644 --- a/pyro/poutine/equalize_messenger.py +++ b/pyro/poutine/equalize_messenger.py @@ -38,18 +38,42 @@ class EqualizeMessenger(Messenger): >>> equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param') + Alternatively, the ``equalize`` messenger can be used to condition a model on primitive statements + having the same value by setting ``keep_dist`` to True. Consider the below model: + + >>> def model(): + ... x = pyro.sample('x', pyro.distributions.Normal(0, 1)) + ... 
y = pyro.sample('y', pyro.distributions.Normal(5, 3)) + ... return x, y + + The model can be conditioned on 'x' and 'y' having the same value by + + >>> conditioned_model = pyro.poutine.equalize(model, ['x', 'y'], keep_dist=True) + + Note that the conditioned model defined above calculates the correct unnormalized log-probability + density, but in order to correctly sample from it one must use SVI or MCMC techniques. + :param fn: a stochastic function (callable containing Pyro primitive calls) :param sites: a string or list of strings to match site names (the strings can be regular expressions) :param type: a string specifying the site type (default is 'sample') + :param bool keep_dist: Whether to keep the distributions of the second and subsequent + matching primitive statements. If kept, this is equivalent to conditioning the model + on all matching primitive statements having the same value, as opposed to having the + second and subsequent matching primitive statements replaced by delta sampling functions. + Defaults to False. 
:returns: stochastic function decorated with a :class:`~pyro.poutine.equalize_messenger.EqualizeMessenger` """ def __init__( - self, sites: Union[str, List[str]], type: Optional[str] = "sample" + self, + sites: Union[str, List[str]], + type: Optional[str] = "sample", + keep_dist: bool = False, ) -> None: super().__init__() self.sites = [sites] if isinstance(sites, str) else sites self.type = type + self.keep_dist = keep_dist def __enter__(self) -> Self: self.value = None @@ -72,6 +96,9 @@ def _process_message(self, msg: Message) -> None: if self.value is not None and self._is_matching(msg): # type: ignore[unreachable] msg["value"] = self.value # type: ignore[unreachable] if msg["type"] == "sample": - msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim).mask(False) - msg["infer"] = {"_deterministic": True} msg["is_observed"] = True + if not self.keep_dist: + msg["infer"] = {"_deterministic": True} + msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim).mask( + False + ) diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py index 751cfecf1e..311837bc87 100644 --- a/tests/poutine/test_poutines.py +++ b/tests/poutine/test_poutines.py @@ -805,6 +805,50 @@ def test_render_model(self): pyro.render_model(model) +@pytest.mark.parametrize("keep_dist", [False, True]) +@pytest.mark.parametrize( + "loc_x, scale_x, loc_y, scale_y", [(0.0, 1.0, 5.0, 2.0), (5.0, 2.0, 0.0, 1.0)] +) +def test_condition_by_equalize(loc_x, scale_x, loc_y, scale_y, keep_dist): + # Create model and equalize it. 
+ def model(): + x = pyro.sample("x", dist.Normal(loc_x, scale_x)) + y = pyro.sample("y", dist.Normal(loc_y, scale_y)) + return x, y + + equalized_model = pyro.poutine.equalize(model, ["x", "y"], keep_dist=keep_dist) + + # Fit guide to model + guide = pyro.infer.autoguide.AutoNormal(equalized_model) + optim = pyro.optim.Adam(dict(lr=0.1)) + svi = pyro.infer.SVI( + equalized_model, + guide, + optim, + loss=pyro.infer.TraceGraph_ELBO(num_particles=1000, vectorize_particles=True), + ) + for step_num in range(100): + svi.step() + + # Get guide distribution parameters + loc, scale = guide._get_loc_and_scale("x") + loc = float(loc.detach().numpy()) + scale = float(scale.detach().numpy()) + + # Verify against expected distribution parameters + if keep_dist: + # Both 'x' and 'y' are sampled and the model is conditioned on 'x' and 'y' having the same value. + expected_var = 1 / (1 / scale_x**2 + 1 / scale_y**2) + expected_loc = (loc_x / scale_x**2 + loc_y / scale_y**2) * expected_var + expected_scale = expected_var**0.5 + else: + # The random variable 'x' is sampled and its value is assigned to 'y'. + expected_loc = loc_x + expected_scale = scale_x + assert_close(loc, expected_loc, atol=0.05) + assert_close(scale, expected_scale, atol=0.05) + + @pytest.mark.parametrize("first_available_dim", [-1, -2, -3]) @pytest.mark.parametrize("depth", [0, 1, 2]) def test_enumerate_poutine(depth, first_available_dim):