diff --git a/pyro/infer/reparam/split.py b/pyro/infer/reparam/split.py
index 83f2224263..e3abf8d74a 100644
--- a/pyro/infer/reparam/split.py
+++ b/pyro/infer/reparam/split.py
@@ -7,6 +7,7 @@
 import pyro.distributions as dist
 import pyro.poutine as poutine
 from pyro.distributions.torch_distribution import TorchDistributionMixin
+from pyro.ops.tensor_utils import broadcast_tensors_without_dim
 
 from .reparam import Reparam
 
@@ -128,6 +129,7 @@ def apply(self, msg):
 
         # Combine parts into value.
         if value is None:
+            value_split = broadcast_tensors_without_dim(value_split, -self.event_dim)
             value = torch.cat(value_split, dim=-self.event_dim)
 
         if poutine.get_mask() is False:
diff --git a/pyro/ops/tensor_utils.py b/pyro/ops/tensor_utils.py
index 17cb367a17..9e644a9886 100644
--- a/pyro/ops/tensor_utils.py
+++ b/pyro/ops/tensor_utils.py
@@ -470,3 +470,35 @@ def safe_normalize(x, *, p=2):
     x = x / norm.clamp(min=torch.finfo(x.dtype).tiny)
     x.data[..., 0][x.data.eq(0).all(dim=-1)] = 1  # Avoid the singularity.
     return x
+
+
+def broadcast_tensors_without_dim(tensors, dim):
+    """
+    Broadcast tensors to the same shape without changing the size of
+    dimension ``dim`` of each tensor.
+
+    The broadcasting is performed in the same way as done in
+    :func:`torch.broadcast_tensors`, while leaving the size of
+    dimension ``dim`` of each tensor unchanged.
+
+    The returned tensors can be concatenated along the dimension ``dim``.
+
+    :param list tensors: List of `torch.Tensor` objects.
+    :param int dim: Dimension to leave out of broadcasting.
+    :returns: List of `torch.Tensor` objects.
+    """
+    if dim >= 0:
+        shape_len = len(tensors[0].shape)
+        for tensor in tensors[1:]:
+            if len(tensor.shape) != shape_len:
+                raise ValueError(
+                    "Dimension dim must be negative for different dimension tensors"
+                )
+    shapes = [list(tensor.shape) for tensor in tensors]
+    for shape in shapes:
+        shape[dim] = 1
+    shape = torch.broadcast_shapes(*shapes)
+    shapes = [list(shape) for count in range(len(tensors))]
+    for shape, tensor in zip(shapes, tensors):
+        shape[dim] = tensor.shape[dim]
+    return [tensor.expand(shape) for shape, tensor in zip(shapes, tensors)]
diff --git a/tests/infer/reparam/test_split.py b/tests/infer/reparam/test_split.py
index b3f43bc5f6..86d7c91546 100644
--- a/tests/infer/reparam/test_split.py
+++ b/tests/infer/reparam/test_split.py
@@ -8,6 +8,7 @@
 import pyro
 import pyro.distributions as dist
 from pyro import poutine
+from pyro.infer.autoguide.initialization import InitMessenger, init_to_median
 from pyro.infer.reparam import SplitReparam
 from tests.common import assert_close
 
@@ -91,6 +92,34 @@ def model():
     check_init_reparam(model, SplitReparam(splits, dim))
 
 
+def test_observe():
+    def model():
+        x_dist = dist.TransformedDistribution(
+            dist.Normal(0, 1).expand((8,)).to_event(1), dist.transforms.HaarTransform()
+        )
+        return pyro.sample("x", x_dist)
+
+    # Build reparameterized model
+    rep = SplitReparam([6, 2], -1)
+    reparam_model = poutine.reparam(model, {"x": rep})
+
+    # Sample from the reparameterized model to create an observation
+    initialized_reparam_model = InitMessenger(init_to_median)(reparam_model)
+    trace = poutine.trace(initialized_reparam_model).get_trace()
+    observation = {"x_split_1": trace.nodes["x_split_1"]["value"]}
+
+    # Create a model conditioned on the observation
+    conditioned_reparam_model = poutine.condition(reparam_model, observation)
+
+    # Fit a guide for the conditioned model
+    guide = pyro.infer.autoguide.AutoMultivariateNormal(conditioned_reparam_model)
+    optim = pyro.optim.Adam(dict(lr=0.1))
+    loss = pyro.infer.Trace_ELBO(num_particles=20, vectorize_particles=True)
+    svi = pyro.infer.SVI(conditioned_reparam_model, guide, optim, loss)
+    for iter_count in range(10):
+        svi.step()
+
+
 @batch_shape
 def test_transformed_distribution(batch_shape):
     num_samples = 10
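
For reviewers, a small usage sketch of the new helper (not part of the diff; the tensor shapes below are invented for illustration). It shows how batch dimensions are broadcast while the split dimension is left alone, so the parts become concatenatable along dim, which is what the updated SplitReparam.apply relies on.

import torch

from pyro.ops.tensor_utils import broadcast_tensors_without_dim

# Two split parts whose batch shapes disagree but whose sizes along
# dim=-1 differ by design (6 and 2), as produced by SplitReparam([6, 2], -1).
a = torch.randn(3, 1, 6)
b = torch.randn(1, 4, 2)

# Batch dimensions are broadcast to (3, 4); dim=-1 is left untouched.
a2, b2 = broadcast_tensors_without_dim([a, b], dim=-1)
assert a2.shape == (3, 4, 6)
assert b2.shape == (3, 4, 2)

# The parts can now be concatenated along the split dimension.
combined = torch.cat([a2, b2], dim=-1)
assert combined.shape == (3, 4, 8)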