Coupling layers #92

Open
wants to merge 16 commits into base: main
2 changes: 2 additions & 0 deletions flowtorch/bijectors/__init__.py
@@ -16,6 +16,7 @@
from flowtorch.bijectors.autoregressive import Autoregressive
from flowtorch.bijectors.base import Bijector
from flowtorch.bijectors.compose import Compose
from flowtorch.bijectors.coupling import Coupling
from flowtorch.bijectors.elementwise import Elementwise
from flowtorch.bijectors.elu import ELU
from flowtorch.bijectors.exp import Exp
@@ -33,6 +34,7 @@
standard_bijectors = [
("Affine", Affine),
("AffineAutoregressive", AffineAutoregressive),
("Coupling", Coupling),
("AffineFixed", AffineFixed),
("ELU", ELU),
("Exp", Exp),
21 changes: 17 additions & 4 deletions flowtorch/bijectors/affine_autoregressive.py
@@ -16,15 +16,28 @@ def __init__(
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
clamp_values: bool = False,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
positive_map: str = "softplus",
positive_bias: Optional[float] = None,
) -> None:
super().__init__(
AffineOp.__init__(
self,
params_fn,
shape=shape,
context_shape=context_shape,
clamp_values=clamp_values,
log_scale_min_clip=log_scale_min_clip,
log_scale_max_clip=log_scale_max_clip,
sigmoid_bias=sigmoid_bias,
positive_map=positive_map,
positive_bias=positive_bias,
)
Autoregressive.__init__(
self,
params_fn,
shape=shape,
context_shape=context_shape,
)
self.log_scale_min_clip = log_scale_min_clip
self.log_scale_max_clip = log_scale_max_clip
self.sigmoid_bias = sigmoid_bias
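
The switch from a single `super().__init__` call to explicit `AffineOp.__init__` and `Autoregressive.__init__` calls is the usual way to give each parent only the arguments it understands. A minimal, self-contained sketch of that pattern (class names are illustrative stand-ins, not flowtorch's):

```python
# Illustrative sketch of explicit parent initialization in multiple
# inheritance; the classes below are hypothetical, not library code.
class OpSketch:
    def __init__(self, params_fn, *, shape, sigmoid_bias=2.0):
        self.params_fn = params_fn
        self.shape = shape
        self.sigmoid_bias = sigmoid_bias


class ConditionerSketch:
    def __init__(self, params_fn, *, shape, context_shape=None):
        self.params_fn = params_fn
        self.shape = shape
        self.context_shape = context_shape


class AffineAutoregressiveSketch(OpSketch, ConditionerSketch):
    def __init__(self, params_fn, *, shape, context_shape=None, sigmoid_bias=2.0):
        # Each base class is initialized explicitly with only the keyword
        # arguments it accepts, instead of relying on cooperative super().
        OpSketch.__init__(self, params_fn, shape=shape, sigmoid_bias=sigmoid_bias)
        ConditionerSketch.__init__(
            self, params_fn, shape=shape, context_shape=context_shape
        )
```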
2 changes: 1 addition & 1 deletion flowtorch/bijectors/autoregressive.py
@@ -60,7 +60,7 @@ def inverse(
# TODO: Make permutation, inverse work for other event shapes
log_detJ: Optional[torch.Tensor] = None
for idx in cast(torch.LongTensor, permutation):
_params = self._params_fn(x_new.clone(), context=context)
_params = self._params_fn(x_new.clone(), None, context=context)
x_temp, log_detJ = self._inverse(y, params=_params)
x_new[..., idx] = x_temp[..., idx]
# _log_detJ = out[1]
12 changes: 7 additions & 5 deletions flowtorch/bijectors/base.py
@@ -1,11 +1,11 @@
# Copyright (c) Meta Platforms, Inc
import warnings
from typing import Optional, Sequence, Tuple, Union, Callable, Iterator
from typing import Callable, Iterator, Optional, Sequence, Tuple, Union

import flowtorch.parameters
import torch
import torch.distributions
from flowtorch.bijectors.bijective_tensor import to_bijective_tensor, BijectiveTensor
from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor
from flowtorch.bijectors.utils import is_record_flow_graph_enabled
from flowtorch.parameters import Parameters
from torch.distributions import constraints
@@ -75,7 +75,9 @@ def forward(
assert isinstance(x, BijectiveTensor)
return x.get_parent_from_bijector(self)

params = self._params_fn(x, context) if self._params_fn is not None else None
params = (
self._params_fn(x, None, context) if self._params_fn is not None else None
)
y, log_detJ = self._forward(x, params)
if (
is_record_flow_graph_enabled()
@@ -119,7 +121,7 @@ def inverse(
return y.get_parent_from_bijector(self)

# TODO: What to do in this line?
params = self._params_fn(x, context) if self._params_fn is not None else None
params = self._params_fn(x, y, context) if self._params_fn is not None else None
x, log_detJ = self._inverse(y, params)

if (
@@ -173,7 +175,7 @@ def log_abs_det_jacobian(
"Computing _log_abs_det_jacobian from values and not from cache."
)
params = (
self._params_fn(x, context) if self._params_fn is not None else None
self._params_fn(x, y, context) if self._params_fn is not None else None
)
return self._log_abs_det_jacobian(x, y, params)
return ladj
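
The edits above widen the parameter-network call from `(x, context)` to `(x, y, context)`: `forward` conditions on `x` only and passes `None` for `y`, while `inverse` and `log_abs_det_jacobian` also hand over `y`. A minimal sketch of a callable following that convention (the stub itself is hypothetical):

```python
# Hypothetical stub following the (x, y, context) convention read off the
# diff above; it simply returns a zero mean and zero unbounded scale.
from typing import Optional, Sequence

import torch


def toy_params_fn(
    x: Optional[torch.Tensor],
    y: Optional[torch.Tensor],
    context: Optional[torch.Tensor] = None,
) -> Sequence[torch.Tensor]:
    # A coupling-style conditioner may look at y when inverting, whereas an
    # autoregressive one ignores it entirely.
    ref = x if x is not None else y
    assert ref is not None
    return torch.zeros_like(ref), torch.zeros_like(ref)


x = torch.randn(8, 4)
params_fwd = toy_params_fn(x, None)                 # as in Bijector.forward
params_inv = toy_params_fn(x, torch.ones_like(x))   # as in Bijector.inverse
```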
2 changes: 1 addition & 1 deletion flowtorch/bijectors/bijective_tensor.py
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc
from typing import Any, Optional, Iterator, Type, TYPE_CHECKING, Union
from typing import Any, Iterator, Optional, Type, TYPE_CHECKING, Union

if TYPE_CHECKING:
from flowtorch.bijectors.base import Bijector
4 changes: 2 additions & 2 deletions flowtorch/bijectors/compose.py
@@ -1,11 +1,11 @@
# Copyright (c) Meta Platforms, Inc
from typing import Optional, Sequence, Iterator
from typing import Iterator, Optional, Sequence

import flowtorch.parameters
import torch
import torch.distributions
from flowtorch.bijectors.base import Bijector
from flowtorch.bijectors.bijective_tensor import to_bijective_tensor, BijectiveTensor
from flowtorch.bijectors.bijective_tensor import BijectiveTensor, to_bijective_tensor
from flowtorch.bijectors.utils import is_record_flow_graph_enabled, requires_log_detJ
from torch.distributions.utils import _sum_rightmost

61 changes: 61 additions & 0 deletions flowtorch/bijectors/coupling.py
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc

from typing import Optional, Sequence, Tuple

import flowtorch.parameters

import torch
from flowtorch.bijectors.ops.affine import Affine as AffineOp
from flowtorch.parameters import DenseCoupling


class Coupling(AffineOp):
def __init__(
self,
params_fn: Optional[flowtorch.Lazy] = None,
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
clamp_values: bool = False,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
positive_map: str = "softplus",
positive_bias: Optional[float] = None,
) -> None:

if params_fn is None:
params_fn = DenseCoupling() # type: ignore

AffineOp.__init__(
self,
params_fn,
shape=shape,
context_shape=context_shape,
clamp_values=clamp_values,
log_scale_min_clip=log_scale_min_clip,
log_scale_max_clip=log_scale_max_clip,
sigmoid_bias=sigmoid_bias,
positive_map=positive_map,
positive_bias=positive_bias,
)

def _forward(
self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
assert self._params_fn is not None

x = x[..., self._params_fn.permutation]
y, ldj = super()._forward(x, params)
y = y[..., self._params_fn.inv_permutation]
return y, ldj

def _inverse(
self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
assert self._params_fn is not None

y = y[..., self._params_fn.inv_permutation]
x, ldj = super()._inverse(y, params)
x = x[..., self._params_fn.permutation]
return x, ldj
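
For context, a minimal usage sketch of the new bijector, assuming the usual flowtorch pattern of pairing a lazily-constructed bijector with a base distribution via `flowtorch.distributions.Flow` (the event shape and standard-normal base are illustrative):

```python
# Minimal usage sketch: Coupling falls back to the new DenseCoupling
# conditioner when no params_fn is given.
import torch

import flowtorch.bijectors as bij
import flowtorch.distributions as D

base = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(4), torch.ones(4)), 1
)
flow = D.Flow(base, bij.Coupling())

z = flow.rsample((16,))   # forward pass through the coupling transform
log_p = flow.log_prob(z)  # inverse pass plus log|det J|
```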
73 changes: 58 additions & 15 deletions flowtorch/bijectors/ops/affine.py
@@ -1,13 +1,24 @@
# Copyright (c) Meta Platforms, Inc

from typing import Optional, Sequence, Tuple
from typing import Callable, Dict, Optional, Sequence, Tuple

import flowtorch

import torch
from flowtorch.bijectors.base import Bijector
from flowtorch.ops import clamp_preserve_gradients
from torch.distributions.utils import _sum_rightmost

_DEFAULT_POSITIVE_BIASES = {
"softplus": torch.expm1(torch.ones(1)).log().item(),
"exp": 0.0,
}
_POSITIVE_MAPS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
"softplus": torch.nn.functional.softplus,
"sigmoid": torch.sigmoid,
"exp": torch.exp,
}


class Affine(Bijector):
r"""
@@ -22,38 +33,64 @@ def __init__(
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
clamp_values: bool = False,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
positive_map: str = "softplus",
positive_bias: Optional[float] = None,
) -> None:
super().__init__(params_fn, shape=shape, context_shape=context_shape)
self.clamp_values = clamp_values
self.log_scale_min_clip = log_scale_min_clip
self.log_scale_max_clip = log_scale_max_clip
self.sigmoid_bias = sigmoid_bias
if positive_bias is None:
positive_bias = _DEFAULT_POSITIVE_BIASES[positive_map]
self.positive_bias = positive_bias
if positive_map not in _POSITIVE_MAPS:
raise RuntimeError(f"Unknown positive map {positive_map}")
self._positive_map = _POSITIVE_MAPS[positive_map]
self._exp_map = self._positive_map is torch.exp and self.positive_bias == 0

def positive_map(self, x: torch.Tensor) -> torch.Tensor:
return self._positive_map(x + self.positive_bias)

def _forward(
self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
assert params is not None

mean, log_scale = params
log_scale = clamp_preserve_gradients(
log_scale, self.log_scale_min_clip, self.log_scale_max_clip
)
scale = torch.exp(log_scale)
mean, unbounded_scale = params
if self.clamp_values:
unbounded_scale = clamp_preserve_gradients(
unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip
)
scale = self.positive_map(unbounded_scale)
log_scale = scale.log() if not self._exp_map else unbounded_scale
y = scale * x + mean
return y, _sum_rightmost(log_scale, self.domain.event_dim)

def _inverse(
self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
assert params is not None
assert (
params is not None
), f"{self.__class__.__name__}._inverse got no parameters"

mean, unbounded_scale = params
if self.clamp_values:
unbounded_scale = clamp_preserve_gradients(
unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip
)

if not self._exp_map:
inverse_scale = self.positive_map(unbounded_scale).reciprocal()
log_scale = inverse_scale.log()
else:
inverse_scale = torch.exp(-unbounded_scale)
log_scale = unbounded_scale

mean, log_scale = params
log_scale = clamp_preserve_gradients(
log_scale, self.log_scale_min_clip, self.log_scale_max_clip
)
inverse_scale = torch.exp(-log_scale)
x_new = (y - mean) * inverse_scale
return x_new, _sum_rightmost(log_scale, self.domain.event_dim)

@@ -65,9 +102,15 @@ def _log_abs_det_jacobian(
) -> torch.Tensor:
assert params is not None

_, log_scale = params
log_scale = clamp_preserve_gradients(
log_scale, self.log_scale_min_clip, self.log_scale_max_clip
_, unbounded_scale = params
if self.clamp_values:
unbounded_scale = clamp_preserve_gradients(
unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip
)
log_scale = (
self.positive_map(unbounded_scale).log()
if not self._exp_map
else unbounded_scale
)
return _sum_rightmost(log_scale, self.domain.event_dim)

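
The default `positive_bias` for the softplus map is `log(e - 1)`, chosen so that an unbounded scale of zero still maps to a scale of exactly one, mirroring `exp(0) = 1`. A standalone check of that convention (not library code):

```python
# Standalone check of the softplus positive-map convention used above:
# softplus(x + log(e - 1)) == log(1 + (e - 1) * exp(x)), which equals 1 at x = 0.
import torch
import torch.nn.functional as F

positive_bias = torch.expm1(torch.ones(1)).log().item()  # log(e - 1)


def positive_map(unbounded_scale: torch.Tensor) -> torch.Tensor:
    return F.softplus(unbounded_scale + positive_bias)


print(positive_map(torch.zeros(3)))        # tensor([1., 1., 1.])
print(positive_map(torch.zeros(3)).log())  # zero log-scale, matching exp(0) = 1
```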
2 changes: 1 addition & 1 deletion flowtorch/distributions/flow.py
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc

from typing import Any, Dict, Optional, Union, Iterator
from typing import Any, Dict, Iterator, Optional, Union

import flowtorch
import torch
3 changes: 2 additions & 1 deletion flowtorch/parameters/__init__.py
@@ -7,7 +7,8 @@
"""

from flowtorch.parameters.base import Parameters
from flowtorch.parameters.coupling import DenseCoupling
from flowtorch.parameters.dense_autoregressive import DenseAutoregressive
from flowtorch.parameters.tensor import Tensor

__all__ = ["Parameters", "DenseAutoregressive", "Tensor"]
__all__ = ["Parameters", "DenseAutoregressive", "Tensor", "DenseCoupling"]
4 changes: 3 additions & 1 deletion flowtorch/parameters/base.py
@@ -25,14 +25,16 @@ def __init__(
def forward(
self,
x: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
) -> Optional[Sequence[torch.Tensor]]:
# TODO: Caching etc.
return self._forward(x, context)
return self._forward(x, y, context)

def _forward(
self,
x: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
) -> Optional[Sequence[torch.Tensor]]:
# I raise an exception rather than using @abstractmethod and
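
On the implementer side, a conditioner now receives the extra `y` argument as well. A minimal, hypothetical `Parameters` subclass with the widened `_forward` signature (the constant parameters it returns are purely illustrative):

```python
# Hypothetical Parameters subclass showing the widened (x, y, context)
# signature; it emits a zero mean and zero unbounded scale.
from typing import Optional, Sequence

import torch
from flowtorch.parameters import Parameters


class ConstantParams(Parameters):
    def _forward(
        self,
        x: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
    ) -> Optional[Sequence[torch.Tensor]]:
        ref = x if x is not None else y
        if ref is None:
            return None
        return torch.zeros_like(ref), torch.zeros_like(ref)
```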