Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for transformed distributions, based on stacking or concatenation transforms, in SplitReparam #3390

Merged
merged 5 commits into from
Aug 4, 2024

Conversation

BenZickel
Copy link
Contributor

@BenZickel BenZickel commented Jul 30, 2024

The Problem

The SplitReparam reparameterization does not support transformed distributions that are based on stacking or concatenation transforms.

Consider the code

import pyro
import pyro.distributions as dist
from pyro.infer.reparam import SplitReparam

import torch

batch_shape = (6, 5)
num_samples = 10

transform = dist.transforms.StackTransform([
        dist.transforms.OrderedTransform(),
        dist.transforms.DiscreteCosineTransform(),
        dist.transforms.HaarTransform()], dim=-1)

num_transforms = len(transform.transforms)

def model():
    scale_tril = pyro.sample("scale_tril", dist.LKJCholesky(num_transforms, 1))
    with pyro.plate_stack("plates", batch_shape):
        x_dist = dist.TransformedDistribution(
            dist.MultivariateNormal(
                torch.zeros(num_samples, num_transforms), scale_tril=scale_tril
            ).to_event(1),
            [transform])
        return pyro.sample("x", x_dist)

split_model = pyro.poutine.reparam(model, config={"x": SplitReparam([2, 1], -1)})

pyro.clear_param_store()
guide = pyro.infer.autoguide.AutoMultivariateNormal(split_model)
guide_sites = guide()

which raises the error

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\SW\pyro-ppl\pyro\nn\module.py", line 527, in __call__
    result = super().__call__(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl  
    return forward_call(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\infer\autoguide\guides.py", line 759, in forward
    self._setup_prototype(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\infer\autoguide\guides.py", line 875, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\infer\autoguide\guides.py", line 644, in _setup_prototype
    biject_to(site["fn"].support).inv(site["value"]).shape
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 263, in __call__
    return self._inv._inv_call(x)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 170, in _inv_call
    return self._inverse(y)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 455, in _inverse
    return self.base_transform.inv(y)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 263, in __call__
    return self._inv._inv_call(x)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 170, in _inv_call
    return self._inverse(y)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 1172, in _inverse
    assert y.size(self.dim) == len(self.transforms)
AssertionError

The error is due to SplitReparam not creating the right support for the sites of the split reparameterization.

The Solution

Change the way SplitReparam figures out the support of slices of transformed distributions that are based on stacking or concatenation transforms.

@BenZickel BenZickel changed the title Support of transformed distributions, based on stacking or concatenation transforms, in SplitReparam Support for transformed distributions, based on stacking or concatenation transforms, in SplitReparam Jul 30, 2024
fritzo
fritzo previously approved these changes Aug 4, 2024
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Looks like there were merge conflicts that need to be resolved.

@BenZickel BenZickel requested a review from fritzo August 4, 2024 16:39
@fritzo fritzo merged commit 5cebc44 into pyro-ppl:dev Aug 4, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants