Support observations that require broadcasting in SplitReparam.
Ben Zickel committed Aug 3, 2024
1 parent 871abb8 commit 8dda7b0
Showing 3 changed files with 63 additions and 0 deletions.
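SplitReparam decomposes a sample site into several downstream sites along an event dimension and reassembles them with torch.cat. When one of the split sites is observed, its value may carry fewer batch dimensions than the sampled parts (for example, under vectorized ELBO particles), and plain concatenation fails. A minimal sketch of the failure mode this commit fixes, with illustrative shapes and hypothetical site names:

import torch

# Under vectorized ELBO estimation a sampled split gains a particle batch
# dimension that a fixed observation lacks (shapes are illustrative).
sampled = torch.randn(20, 6)  # "x_split_0": 20 particles, event size 6
observed = torch.randn(2)     # "x_split_1": observed value, event size 2

# torch.cat requires matching shapes outside the concatenation dimension:
# torch.cat([sampled, observed], dim=-1)  # raises RuntimeError

# Broadcasting every dimension except the split dimension first makes
# the parts concatenable:
combined = torch.cat([sampled, observed.expand(20, 2)], dim=-1)  # shape (20, 8)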
2 changes: 2 additions & 0 deletions pyro/infer/reparam/split.py
@@ -6,6 +6,7 @@
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.ops.tensor_utils import broadcast_tensors_without_dim

from .reparam import Reparam

@@ -64,6 +65,7 @@ def apply(self, msg):

        # Combine parts into value.
        if value is None:
            value_split = broadcast_tensors_without_dim(value_split, -self.event_dim)
            value = torch.cat(value_split, dim=-self.event_dim)

        if poutine.get_mask() is False:
32 changes: 32 additions & 0 deletions pyro/ops/tensor_utils.py
@@ -470,3 +470,35 @@ def safe_normalize(x, *, p=2):
    x = x / norm.clamp(min=torch.finfo(x.dtype).tiny)
    x.data[..., 0][x.data.eq(0).all(dim=-1)] = 1  # Avoid the singularity.
    return x


def broadcast_tensors_without_dim(tensors, dim):
    """
    Broadcast tensors to the same shape without changing the size of
    dimension ``dim`` of each tensor.

    The broadcasting is performed in the same way as done in
    :func:`torch.broadcast_tensors`, while leaving the size of
    dimension ``dim`` of each tensor unchanged. The returned tensors
    can therefore be concatenated along dimension ``dim``.

    :param list tensors: List of `torch.Tensor` objects.
    :param int dim: Dimension to leave out of broadcasting.
    :returns: List of `torch.Tensor` objects.
    """
    if dim >= 0:
        # A non-negative dim indexes tensors of different ranks inconsistently.
        shape_len = len(tensors[0].shape)
        for tensor in tensors[1:]:
            if len(tensor.shape) != shape_len:
                raise ValueError(
                    "dim must be negative when tensors have different numbers of dimensions"
                )
    # Compute the common broadcast shape, ignoring dimension dim.
    shapes = [list(tensor.shape) for tensor in tensors]
    for shape in shapes:
        shape[dim] = 1
    shape = torch.broadcast_shapes(*shapes)
    # Restore each tensor's original size along dim, then expand the rest.
    shapes = [list(shape) for count in range(len(tensors))]
    for shape, tensor in zip(shapes, tensors):
        shape[dim] = tensor.shape[dim]
    return [tensor.expand(shape) for shape, tensor in zip(shapes, tensors)]
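As a quick illustration of the new helper (shapes chosen arbitrarily): each input is broadcast over every dimension except dim, leaving the parts concatenable along dim.

import torch
from pyro.ops.tensor_utils import broadcast_tensors_without_dim

a = torch.randn(5, 1, 6)  # size 6 along dim=-1
b = torch.randn(3, 2)     # size 2 along dim=-1
out = broadcast_tensors_without_dim([a, b], dim=-1)
assert [tuple(t.shape) for t in out] == [(5, 3, 6), (5, 3, 2)]
assert torch.cat(out, dim=-1).shape == (5, 3, 8)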
29 changes: 29 additions & 0 deletions tests/infer/reparam/test_split.py
@@ -8,6 +8,7 @@
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide.initialization import InitMessenger, init_to_median
from pyro.infer.reparam import SplitReparam
from tests.common import assert_close

@@ -100,3 +101,31 @@ def model():
        return pyro.sample("x", dist.Normal(loc, scale).to_event(len(event_shape)))

    check_init_reparam(model, SplitReparam(splits, dim))


def test_observe():
    def model():
        x_dist = dist.TransformedDistribution(
            dist.Normal(0, 1).expand((8,)).to_event(1), dist.transforms.HaarTransform()
        )
        return pyro.sample("x", x_dist)

    # Build reparameterized model
    rep = SplitReparam([6, 2], -1)
    reparam_model = poutine.reparam(model, {"x": rep})

    # Sample from the reparameterized model to create an observation
    initialized_reparam_model = InitMessenger(init_to_median)(reparam_model)
    trace = poutine.trace(initialized_reparam_model).get_trace()
    observation = {"x_split_1": trace.nodes["x_split_1"]["value"]}

    # Create a model conditioned on the observation
    conditioned_reparam_model = poutine.condition(reparam_model, observation)

    # Fit a guide for the conditioned model
    guide = pyro.infer.autoguide.AutoMultivariateNormal(conditioned_reparam_model)
    optim = pyro.optim.Adam(dict(lr=0.1))
    loss = pyro.infer.Trace_ELBO(num_particles=20, vectorize_particles=True)
    svi = pyro.infer.SVI(conditioned_reparam_model, guide, optim, loss)
    for iter_count in range(10):
        svi.step()
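Note the interaction this test exercises: with num_particles=20 and vectorize_particles=True, the sampled x_split_0 site acquires a particle batch dimension, while the observed x_split_1 value keeps the batch-free shape recorded in the trace, so recombining the parts goes through the new broadcasting step in SplitReparam.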
