Composing Bijectors with different event_dims (#107)
Summary:
### Motivation
Currently, when you compose two `Bijector`s with different `event_dim`s, e.g.

```python
bijectors = B.Compose(bijectors=[B.AffineAutoregressive(), B.Sigmoid()])
```

you get an error when the `log_detJ` term is calculated, because the per-bijector terms have different shapes: an element-wise bijector like `Sigmoid` (`event_dim=0`) returns one value per element, while `AffineAutoregressive` (`event_dim=1`) returns one value per event.

### Changes proposed
`Compose.__init__` calculates the output `event_dim` as the maximum over its component bijectors, and the calculation of `log_detJ` sums out the extra rightmost dimensions where required.
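
For illustration, a minimal standalone sketch of that reconciliation (hypothetical shapes and a simplified `_sum_rightmost`, mirroring the utility used in the diff below; not flowtorch's exact code):

```python
import torch

def _sum_rightmost(x: torch.Tensor, n: int) -> torch.Tensor:
    # Sum out the rightmost n dimensions of x; a no-op when n == 0.
    return x.sum(dim=tuple(range(-n, 0))) if n > 0 else x

# Per-bijector log-det terms for a batch of 3 samples with event shape (2,).
# AffineAutoregressive (event_dim=1) has already summed over the event
# dimension, while Sigmoid (event_dim=0) is element-wise, so the raw terms
# cannot be added directly.
ldj_affine = torch.zeros(3)      # shape: batch
ldj_sigmoid = torch.zeros(3, 2)  # shape: batch + event

out_event_dim = max(1, 0)  # Compose takes the maximum over its bijectors
total = _sum_rightmost(ldj_affine, out_event_dim - 1) + _sum_rightmost(
    ldj_sigmoid, out_event_dim - 0
)
assert total.shape == torch.Size([3])
```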

Pull Request resolved: #107

Test Plan: See #104 for an example of code that currently fails.

Reviewed By: vmoens

Differential Revision: D36169511

Pulled By: stefanwebb

fbshipit-source-id: 773a198bacc420b72ea6b4126e0e1c0ee54f726a
stefanwebb authored and facebook-github-bot committed May 9, 2022
1 parent a8dbf2b commit 852a960
Showing 1 changed file with 25 additions and 4 deletions.
flowtorch/bijectors/compose.py
```diff
@@ -1,10 +1,12 @@
 # Copyright (c) Meta Platforms, Inc
+import copy
 import warnings
 from typing import Optional, Sequence
 
 import flowtorch.parameters
 import torch
 import torch.distributions
+import torch.distributions.constraints as constraints
 from flowtorch.bijectors.base import Bijector
 from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor
 from flowtorch.bijectors.utils import is_record_flow_graph_enabled, requires_log_detJ
@@ -30,8 +32,17 @@ def __init__(
         self.bijectors.append(bijector(shape=shape))  # type: ignore
         shape = self.bijectors[-1].forward_shape(shape)  # type: ignore
 
-        self.domain = self.bijectors[0].domain  # type: ignore
-        self.codomain = self.bijectors[-1].codomain  # type: ignore
+        # TODO: domain of next bijector must be compatible with codomain
+        # of previous one
+        # TODO: Intelligent way to calculate final codomain, like an algebra
+        # of constraints
+        self.domain = copy.copy(self.bijectors[0].domain)  # type: ignore
+        self.codomain = copy.copy(self.bijectors[-1].codomain)  # type: ignore
+        max_event_dim = max([b.codomain.event_dim for b in self.bijectors])
+        if max_event_dim > self.codomain.event_dim:
+            self.codomain = constraints.independent(
+                self.codomain, max_event_dim - self.codomain.event_dim
+            )
 
         self._context_shape = context_shape
 
@@ -53,6 +64,9 @@ def forward(
                     raise RuntimeError(
                         "neither of x nor y contains the log-abs-det-jacobian"
                     )
+                _log_detJ = _sum_rightmost(
+                    _log_detJ, self.codomain.event_dim - bijector.codomain.event_dim
+                )
                 log_detJ = log_detJ + _log_detJ if log_detJ is not None else _log_detJ
             x_temp = y
 
@@ -85,6 +99,10 @@ def inverse(
                     raise RuntimeError(
                         "neither of x nor y contains the log-abs-det-jacobian"
                     )
+                event_dim: int = bijector.codomain.event_dim  # type: ignore
+                _log_detJ = _sum_rightmost(
+                    _log_detJ, self.codomain.event_dim - event_dim
+                )
                 log_detJ = log_detJ + _log_detJ if log_detJ is not None else _log_detJ
             y_temp = x  # type: ignore
 
@@ -106,7 +124,7 @@ def log_abs_det_jacobian(
         """
         ldj = _sum_rightmost(
             torch.zeros_like(y),
-            self.domain.event_dim,
+            self.codomain.event_dim,
         )
 
         if isinstance(x, BijectiveTensor) and x.has_ancestor(y):
@@ -135,7 +153,10 @@ def log_abs_det_jacobian(
                 y_inv = bijector.inverse(y, context)  # type: ignore
             else:
                 y_inv = parents.pop()
-            ldj += bijector.log_abs_det_jacobian(y_inv, y, context)  # type: ignore
+            _log_detJ = bijector.log_abs_det_jacobian(y_inv, y, context)  # type: ignore
+            event_dim: int = bijector.codomain.event_dim  # type: ignore
+            _log_detJ = _sum_rightmost(_log_detJ, self.codomain.event_dim - event_dim)
+            ldj += _log_detJ
             y = y_inv
         return ldj
 
```
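With this change, the composition from the motivation should evaluate a log-density without shape errors. A usage sketch, assuming the standard flowtorch `Flow` distribution API (see #104 for the originally failing code):

```python
import torch
import flowtorch.bijectors as B
import flowtorch.distributions as D

# The composition from the motivation: event_dim=1 followed by event_dim=0.
bijectors = B.Compose(bijectors=[B.AffineAutoregressive(), B.Sigmoid()])
base_dist = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(2), torch.ones(2)), 1
)
flow = D.Flow(base_dist, bijectors)

x = flow.sample((5,))
print(flow.log_prob(x).shape)  # expected: torch.Size([5])
```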
