Skip to content

Commit

Permalink
passing test
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 committed Jul 17, 2024
1 parent b42d20a commit 92791df
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 25 deletions.
25 changes: 1 addition & 24 deletions pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,29 +233,6 @@ def __get__(
return value


class PyroSamplePlateScope(pyro.poutine.messenger.Messenger):
"""
Handler for executing PyroSample statements in a more intuitive plate context.
"""
def __init__(self, allowed_plates: Iterable[str] = ()):
self._inner_allowed_plates = frozenset(allowed_plates)

def __enter__(self):
self._plates: frozenset[str] = frozenset(p.name for p in pyro.poutine.runtime.get_plates()) | self._inner_allowed_plates
return super().__enter__()

def _is_local_plate(self, m: pyro.poutine.messenger.Messenger) -> bool:
return isinstance(m, pyro.poutine.plate_messenger.PlateMessenger) and m.name not in self._plates

def _pyro_sample(self, msg):
if not msg["infer"].get("_is_global_sample", False):
return
msg["stop"] = True
msg["done"] = True
with pyro.poutine.messenger.block_messenger(lambda m: m is self or self._is_local_plate(m)):
msg["value"] = pyro.sample(msg["name"], msg["fn"], obs=msg["value"] if msg["is_observed"] else None, infer=msg["infer"])


def _make_name(prefix: str, name: str) -> str:
return "{}.{}".format(prefix, name) if prefix else name

Expand Down Expand Up @@ -639,7 +616,7 @@ def __getattr__(self, name: str) -> Any:
value = (
pyro.deterministic(fullname, prior)
if isinstance(prior, torch.Tensor)
else pyro.sample(fullname, prior, infer={"_is_global_sample": True})
else pyro.sample(fullname, prior, infer={"_original_pyrosample_dist": prior})
)
context.set(fullname, value)
return value
Expand Down
27 changes: 26 additions & 1 deletion pyro/poutine/plate_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional
from typing import TYPE_CHECKING, Iterable, Iterator, Optional

import pyro
from pyro.poutine.broadcast_messenger import BroadcastMessenger
from pyro.poutine.messenger import Messenger, block_messengers
from pyro.poutine.subsample_messenger import SubsampleMessenger
Expand Down Expand Up @@ -88,3 +89,27 @@ def predicate(messenger: Messenger) -> bool:
"setting strict=False."
)
yield


class PyroSamplePlateScope(Messenger):
"""
Handler for executing PyroSample statements in a more intuitive plate context.
"""
def __init__(self, allowed_plates: Iterable[str] = ()):
self._inner_allowed_plates = frozenset(allowed_plates)

def __enter__(self):
self._plates: frozenset[str] = frozenset(p.name for p in pyro.poutine.runtime.get_plates()) | self._inner_allowed_plates
return super().__enter__()

def _is_local_plate(self, m: Messenger) -> bool:
return isinstance(m, PlateMessenger) and m.name not in self._plates

def _pyro_sample(self, msg):
if not msg["infer"].get("_original_pyrosample_dist", None):
return
msg["stop"] = True
msg["done"] = True
with block_messengers(lambda m: m is self or self._is_local_plate(m)):
d = msg["infer"].pop("_original_pyrosample_dist")
msg["value"] = pyro.sample(msg["name"], d, obs=msg["value"] if msg["is_observed"] else None, infer=msg["infer"])
30 changes: 30 additions & 0 deletions tests/nn/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,3 +1084,33 @@ def forward(self):
with pyro.settings.context(module_local_params=use_module_local_params):
model = Model()
pyro.render_model(model)


def test_pyrosample_platescope():

class Model(pyro.nn.PyroModule):
def __init__(self, num_inputs, num_outputs):
super().__init__()
self.num_inputs = num_inputs
self.num_outputs = num_outputs
self.linear = pyro.nn.PyroModule[torch.nn.Linear](num_inputs, num_outputs)
self.linear.weight = pyro.nn.PyroSample(dist.Normal(0, 1).expand([num_outputs, num_inputs]).to_event(2))
self.linear.bias = pyro.nn.PyroSample(dist.Normal(0, 1).expand([num_outputs]).to_event(1))

@pyro.nn.PyroSample
def scale(self):
return pyro.distributions.LogNormal(0, 1).expand([self.num_outputs]).to_event(1)

@pyro.poutine.plate_messenger.PyroSamplePlateScope()
def forward(self, x):
with pyro.plate("data", x.shape[-2], dim=-1):
assert len(self.linear.weight.shape) == 2 or self.linear.weight.shape[-3] != 1 # sampled outside data plate
loc = self.linear(x)
assert len(self.scale.shape) == 1 or self.scale.shape[-2] == 1 # sampled outside data plate
y = pyro.sample("y", dist.Normal(loc, self.scale).to_event(1))
assert y.shape[-2] == x.shape[-2] # ordinary pyro.sample statement
return y

model = Model(3, 2)
x = torch.randn(4, 3)
assert model(x).shape == (4, 2)

0 comments on commit 92791df

Please sign in to comment.