Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Support for transformed distributions, based on stacking or concatenation transforms, in SplitReparam #3390

Merged
merged 5 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 66 additions & 2 deletions pyro/infer/reparam/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,61 @@
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions.torch_distribution import TorchDistributionMixin

from .reparam import Reparam


def same_support(fn: TorchDistributionMixin, *args):
    """
    Return the support of the distribution ``fn`` unchanged, ignoring the
    split slice/dim arguments. Used by :class:`SplitReparam` as a
    ``support_fn`` that keeps each split piece on the full site support.

    :param fn: distribution instance whose support is queried
    :returns: the distribution's support constraint
    """
    # Extra positional args (slice, dim) are accepted for signature
    # compatibility with the other support_fn callables, but unused here.
    return fn.support


def real_support(fn: TorchDistributionMixin, *args):
    """
    Return an unconstrained (real-valued) support with the same event
    dimensionality as the distribution ``fn``. Used by
    :class:`SplitReparam` as a ``support_fn`` when split pieces should be
    treated as unconstrained.

    :param fn: distribution instance whose event_dim is mirrored
    :returns: an ``independent(real, fn.event_dim)`` constraint
    """
    # Extra positional args (slice, dim) are accepted for signature
    # compatibility with the other support_fn callables, but unused here.
    real = dist.constraints.real
    return dist.constraints.independent(real, fn.event_dim)


def default_support(fn: TorchDistributionMixin, slice, dim):
    """
    Return the support of the distribution ``fn``, sliced so that
    ``stack``/``cat`` constraints along the split dimension keep only the
    components covered by the given slice. Used by :class:`SplitReparam`
    to determine the support of each split value.

    :param fn: distribution instance whose support is queried
    :param slice: slice selecting the split section
        (note: shadows the builtin ``slice``; kept for API compatibility)
    :param dim: (negative) event dimension along which the site is split
    :returns: the support constraint for the selected section
    """
    constraint = fn.support

    # Peel off independent() wrappers, remembering each
    # reinterpreted_batch_ndims so we can re-wrap in the same order later.
    wrapped_ndims = []
    while isinstance(constraint, dist.constraints.independent):
        wrapped_ndims.append(constraint.reinterpreted_batch_ndims)
        constraint = constraint.base_constraint

    # If the innermost constraint stacks or concatenates per-component
    # constraints along the split dimension, keep only the sliced pieces.
    if isinstance(constraint, dist.constraints.stack) and constraint.dim == dim:
        constraint = dist.constraints.stack(constraint.cseq[slice], dim)
    elif isinstance(constraint, dist.constraints.cat) and constraint.dim == dim:
        constraint = dist.constraints.cat(
            constraint.cseq[slice], dim, constraint.lengths[slice]
        )

    # Restore the independent() wrappers, innermost first.
    for ndims in reversed(wrapped_ndims):
        constraint = dist.constraints.independent(constraint, ndims)
    return constraint


class SplitReparam(Reparam):
"""
Reparameterizer to split a random variable along a dimension, similar to
Expand All @@ -28,14 +79,21 @@ class SplitReparam(Reparam):
each chunk.
:type: list(int)
:param int dim: Dimension along which to split. Defaults to -1.
:param callable support_fn: Function which derives the split support
from the site's sampling function, split size, and split dimension.
Default is :func:`default_support` which correctly handles stacking
and concatenation transforms. Other options are :func:`same_support`
which returns the same support as that of the sampling function, and
:func:`real_support` which returns a real support.
"""

def __init__(self, sections, dim):
def __init__(self, sections, dim, support_fn=default_support):
    """
    :param list sections: sizes of the split chunks.
    :param int dim: negative dimension along which to split.
    :param callable support_fn: callable ``(fn, slice, dim) -> constraint``
        deriving each split piece's support; defaults to
        :func:`default_support`.
    """
    # Validation order is preserved: dim first, then the sections list.
    assert isinstance(dim, int) and dim < 0
    assert isinstance(sections, list)
    assert all(isinstance(chunk, int) for chunk in sections)
    self.event_dim = -dim
    self.sections = sections
    self.support_fn = support_fn

def apply(self, msg):
name = msg["name"]
Expand All @@ -53,14 +111,20 @@ def apply(self, msg):
dim = fn.event_dim - self.event_dim
left_shape = fn.event_shape[:dim]
right_shape = fn.event_shape[1 + dim :]
start = 0
for i, size in enumerate(self.sections):
event_shape = left_shape + (size,) + right_shape
value_split[i] = pyro.sample(
f"{name}_split_{i}",
dist.ImproperUniform(fn.support, fn.batch_shape, event_shape),
dist.ImproperUniform(
self.support_fn(fn, slice(start, start + size), -self.event_dim),
fn.batch_shape,
event_shape,
),
obs=value_split[i],
infer={"is_observed": is_observed},
)
start += size

# Combine parts into value.
if value is None:
Expand Down
81 changes: 60 additions & 21 deletions tests/infer/reparam/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

from .util import check_init_reparam


@pytest.mark.parametrize(
event_shape_splits_dim = pytest.mark.parametrize(
"event_shape,splits,dim",
[
((6,), [2, 1, 3], -1),
Expand All @@ -31,7 +30,13 @@
],
ids=str,
)
@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)


batch_shape = pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)


@event_shape_splits_dim
@batch_shape
def test_normal(batch_shape, event_shape, splits, dim):
shape = batch_shape + event_shape
loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
Expand Down Expand Up @@ -72,24 +77,8 @@ def model():
assert_close(actual_grads, expected_grads)


@pytest.mark.parametrize(
"event_shape,splits,dim",
[
((6,), [2, 1, 3], -1),
(
(
2,
5,
),
[2, 3],
-1,
),
((4, 2), [1, 3], -2),
((2, 3, 1), [1, 2], -2),
],
ids=str,
)
@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)
@event_shape_splits_dim
@batch_shape
def test_init(batch_shape, event_shape, splits, dim):
shape = batch_shape + event_shape
loc = torch.empty(shape).uniform_(-1.0, 1.0)
Expand All @@ -100,3 +89,53 @@ def model():
return pyro.sample("x", dist.Normal(loc, scale).to_event(len(event_shape)))

check_init_reparam(model, SplitReparam(splits, dim))


@batch_shape
def test_transformed_distribution(batch_shape):
    """Check SplitReparam on a TransformedDistribution whose transform is a
    StackTransform: the model and guide shapes must match, and each split
    site must have the expected section width along the last dim."""
    n_draws = 10

    stacked = dist.transforms.StackTransform(
        [
            dist.transforms.OrderedTransform(),
            dist.transforms.DiscreteCosineTransform(),
            dist.transforms.HaarTransform(),
        ],
        dim=-1,
    )
    n_parts = len(stacked.transforms)

    def model():
        scale_tril = pyro.sample("scale_tril", dist.LKJCholesky(n_parts, 1))
        with pyro.plate_stack("plates", batch_shape):
            base = dist.MultivariateNormal(
                torch.zeros(n_draws, n_parts), scale_tril=scale_tril
            ).to_event(1)
            return pyro.sample("x", dist.TransformedDistribution(base, [stacked]))

    full_shape = batch_shape + (n_draws, n_parts)
    assert model().shape == full_shape

    pyro.clear_param_store()
    samples = pyro.infer.autoguide.AutoMultivariateNormal(model)()
    assert samples["x"].shape == full_shape

    for sections in ([1, 1, 1], [1, 2], [2, 1]):
        reparam_model = pyro.poutine.reparam(
            model, config={"x": SplitReparam(sections, -1)}
        )

        pyro.clear_param_store()
        samples = pyro.infer.autoguide.AutoMultivariateNormal(reparam_model)()

        for idx, width in enumerate(sections):
            # Each split site replaces the last event dim with its width.
            assert samples[f"x_split_{idx}"].shape == batch_shape + (
                n_draws,
                width,
            )
Loading