diff --git a/pyro/nn/module.py b/pyro/nn/module.py index 86bcff37d0..91a5e01460 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -233,29 +233,6 @@ def __get__( return value -class PyroSamplePlateScope(pyro.poutine.messenger.Messenger): - """ - Handler for executing PyroSample statements in a more intuitive plate context. - """ - def __init__(self, allowed_plates: Iterable[str] = ()): - self._inner_allowed_plates = frozenset(allowed_plates) - - def __enter__(self): - self._plates: frozenset[str] = frozenset(p.name for p in pyro.poutine.runtime.get_plates()) | self._inner_allowed_plates - return super().__enter__() - - def _is_local_plate(self, m: pyro.poutine.messenger.Messenger) -> bool: - return isinstance(m, pyro.poutine.plate_messenger.PlateMessenger) and m.name not in self._plates - - def _pyro_sample(self, msg): - if not msg["infer"].get("_is_global_sample", False): - return - msg["stop"] = True - msg["done"] = True - with pyro.poutine.messenger.block_messenger(lambda m: m is self or self._is_local_plate(m)): - msg["value"] = pyro.sample(msg["name"], msg["fn"], obs=msg["value"] if msg["is_observed"] else None, infer=msg["infer"]) - - def _make_name(prefix: str, name: str) -> str: return "{}.{}".format(prefix, name) if prefix else name @@ -639,7 +616,7 @@ def __getattr__(self, name: str) -> Any: value = ( pyro.deterministic(fullname, prior) if isinstance(prior, torch.Tensor) - else pyro.sample(fullname, prior, infer={"_is_global_sample": True}) + else pyro.sample(fullname, prior, infer={"_original_pyrosample_dist": prior}) ) context.set(fullname, value) return value diff --git a/pyro/poutine/plate_messenger.py b/pyro/poutine/plate_messenger.py index e1484324d6..8a76ec777c 100644 --- a/pyro/poutine/plate_messenger.py +++ b/pyro/poutine/plate_messenger.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from contextlib import contextmanager -from typing import TYPE_CHECKING, Iterator, Optional +from typing import TYPE_CHECKING, Iterable, Iterator, Optional +import pyro from pyro.poutine.broadcast_messenger import BroadcastMessenger from pyro.poutine.messenger import Messenger, block_messengers from pyro.poutine.subsample_messenger import SubsampleMessenger @@ -88,3 +89,27 @@ def predicate(messenger: Messenger) -> bool: "setting strict=False." ) yield + + +class PyroSamplePlateScope(Messenger): + """ + Handler for executing PyroSample statements in a more intuitive plate context. + """ + def __init__(self, allowed_plates: Iterable[str] = ()): + self._inner_allowed_plates = frozenset(allowed_plates) + + def __enter__(self): + self._plates: frozenset[str] = frozenset(p.name for p in pyro.poutine.runtime.get_plates()) | self._inner_allowed_plates + return super().__enter__() + + def _is_local_plate(self, m: Messenger) -> bool: + return isinstance(m, PlateMessenger) and m.name not in self._plates + + def _pyro_sample(self, msg): + if not msg["infer"].get("_original_pyrosample_dist", None): + return + msg["stop"] = True + msg["done"] = True + with block_messengers(lambda m: m is self or self._is_local_plate(m)): + d = msg["infer"].pop("_original_pyrosample_dist") + msg["value"] = pyro.sample(msg["name"], d, obs=msg["value"] if msg["is_observed"] else None, infer=msg["infer"]) diff --git a/tests/nn/test_module.py b/tests/nn/test_module.py index 07c4daedd1..1403bf102d 100644 --- a/tests/nn/test_module.py +++ b/tests/nn/test_module.py @@ -1084,3 +1084,33 @@ def forward(self): with pyro.settings.context(module_local_params=use_module_local_params): model = Model() pyro.render_model(model) + + +def test_pyrosample_platescope(): + + class Model(pyro.nn.PyroModule): + def __init__(self, num_inputs, num_outputs): + super().__init__() + self.num_inputs = num_inputs + self.num_outputs = num_outputs + self.linear = pyro.nn.PyroModule[torch.nn.Linear](num_inputs, num_outputs) + self.linear.weight = pyro.nn.PyroSample(dist.Normal(0, 1).expand([num_outputs, num_inputs]).to_event(2)) + self.linear.bias = pyro.nn.PyroSample(dist.Normal(0, 1).expand([num_outputs]).to_event(1)) + + @pyro.nn.PyroSample + def scale(self): + return pyro.distributions.LogNormal(0, 1).expand([self.num_outputs]).to_event(1) + + @pyro.poutine.plate_messenger.PyroSamplePlateScope() + def forward(self, x): + with pyro.plate("data", x.shape[-2], dim=-1): + assert len(self.linear.weight.shape) == 2 or self.linear.weight.shape[-3] != 1 # sampled outside data plate + loc = self.linear(x) + assert len(self.scale.shape) == 1 or self.scale.shape[-2] == 1 # sampled outside data plate + y = pyro.sample("y", dist.Normal(loc, self.scale).to_event(1)) + assert y.shape[-2] == x.shape[-2] # ordinary pyro.sample statement + return y + + model = Model(3, 2) + x = torch.randn(4, 3) + assert model(x).shape == (4, 2) \ No newline at end of file