Additional scale functions for AffineOp (#109)
Summary:
### Motivation
As pointed out in #85, it may be preferable to use `softplus` rather than `exp` to calculate the scale parameter of the affine map in `bij.ops.Affine`.
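
For intuition (an illustration, not part of this commit): `exp` can produce extremely large scales for moderately large pre-activations, whereas `softplus` is asymptotically linear and tends to be better behaved numerically.

```python
import torch
import torch.nn.functional as F

u = torch.tensor([-5.0, 0.0, 5.0, 20.0])
print(torch.exp(u))   # ~[6.7e-03, 1.0e+00, 1.5e+02, 4.9e+08] -- blows up quickly
print(F.softplus(u))  # ~[6.7e-03, 6.9e-01, 5.0e+00, 2.0e+01] -- roughly linear for large u
```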

### Changes proposed
Another PR, #92 by vmoens, implements `softplus`, `sigmoid`, and `exp` options for the scale parameter; I have factored that out and simplified some of the design in order to make #92 easier to review. `softplus` is now the default option for `Affine`.
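
A minimal usage sketch of the new `scale_fn` argument (untested here; assumes the lazy-construction API from the flowtorch README, i.e. `flowtorch.bijectors.AffineAutoregressive` and `flowtorch.distributions.Flow`):

```python
import torch
import flowtorch.bijectors as bij
import flowtorch.distributions as dist

base = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(2), torch.ones(2)), 1
)

# `softplus` is now the default; pass scale_fn explicitly to recover the old behaviour
flow = dist.Flow(base, bij.AffineAutoregressive(scale_fn="exp"))
x = flow.rsample((4,))
```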

Pull Request resolved: #109

Test Plan: `pytest tests/`

Reviewed By: vmoens

Differential Revision: D36169529

Pulled By: stefanwebb

fbshipit-source-id: 625387e10399291a5a404c28f4ada743d0945649
stefanwebb authored and facebook-github-bot committed May 9, 2022
1 parent 852a960 commit d4ae3c0
Showing 2 changed files with 67 additions and 18 deletions.
6 changes: 4 additions & 2 deletions flowtorch/bijectors/affine_autoregressive.py
@@ -16,15 +16,17 @@ def __init__(
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
clamp_values: bool = False,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
scale_fn: str = "softplus",
) -> None:
super().__init__(
params_fn,
shape=shape,
context_shape=context_shape,
)
self.clamp_values = clamp_values
self.log_scale_min_clip = log_scale_min_clip
self.log_scale_max_clip = log_scale_max_clip
self.sigmoid_bias = sigmoid_bias
self.scale_fn = scale_fn
79 changes: 63 additions & 16 deletions flowtorch/bijectors/ops/affine.py
@@ -4,6 +4,7 @@

import flowtorch
import torch
import torch.nn.functional as F
from flowtorch.bijectors.base import Bijector
from flowtorch.ops import clamp_preserve_gradients
from torch.distributions.utils import _sum_rightmost
@@ -22,25 +23,66 @@ def __init__(
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
clamp_values: bool = False,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
scale_fn: str = "softplus",
) -> None:
super().__init__(params_fn, shape=shape, context_shape=context_shape)
self.clamp_values = clamp_values
self.log_scale_min_clip = log_scale_min_clip
self.log_scale_max_clip = log_scale_max_clip
self.sigmoid_bias = sigmoid_bias
self.scale_fn = scale_fn

def _scale_fn(
self, unbounded_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# NOTE: Need to hardcode log(f(x)) for numerical stability
if self.scale_fn == "softplus":
scale = F.softplus(unbounded_scale)
log_scale = torch.log(scale)
elif self.scale_fn == "exp":
scale = torch.exp(unbounded_scale)
log_scale = unbounded_scale
elif self.scale_fn == "sigmoid":
scale = torch.sigmoid(unbounded_scale)
log_scale = F.logsigmoid(unbounded_scale)
else:
raise ValueError(f"Unknown scale function: {self.scale_fn}")

return scale, log_scale

def _inv_scale_fn(
self, unbounded_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# NOTE: Need to hardcode 1/f(x) and log(f(x)) for numerical stability
if self.scale_fn == "softplus":
scale = F.softplus(unbounded_scale)
inverse_scale = scale.reciprocal()
log_scale = torch.log(scale)
elif self.scale_fn == "exp":
inverse_scale = torch.exp(-unbounded_scale)
log_scale = unbounded_scale
elif self.scale_fn == "sigmoid":
inverse_scale = torch.exp(-unbounded_scale) + 1.0
log_scale = F.logsigmoid(unbounded_scale)
else:
raise ValueError(f"Unknown scale function: {self.scale_fn}")

return inverse_scale, log_scale

def _forward(
self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
assert params is not None

- mean, log_scale = params
- log_scale = clamp_preserve_gradients(
- log_scale, self.log_scale_min_clip, self.log_scale_max_clip
- )
- scale = torch.exp(log_scale)
+ mean, unbounded_scale = params
+ if self.clamp_values:
+ unbounded_scale = clamp_preserve_gradients(
+ unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip
+ )
+
+ scale, log_scale = self._scale_fn(unbounded_scale)
y = scale * x + mean
return y, _sum_rightmost(log_scale, self.domain.event_dim)

@@ -49,11 +91,13 @@ def _inverse(
) -> Tuple[torch.Tensor, torch.Tensor]:
assert params is not None

- mean, log_scale = params
- log_scale = clamp_preserve_gradients(
- log_scale, self.log_scale_min_clip, self.log_scale_max_clip
- )
- inverse_scale = torch.exp(-log_scale)
+ mean, unbounded_scale = params
+ if self.clamp_values:
+ unbounded_scale = clamp_preserve_gradients(
+ unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip
+ )
+
+ inverse_scale, log_scale = self._inv_scale_fn(unbounded_scale)
x_new = (y - mean) * inverse_scale
return x_new, _sum_rightmost(log_scale, self.domain.event_dim)

@@ -65,10 +109,13 @@ def _log_abs_det_jacobian(
) -> torch.Tensor:
assert params is not None

- _, log_scale = params
- log_scale = clamp_preserve_gradients(
- log_scale, self.log_scale_min_clip, self.log_scale_max_clip
- )
+ _, unbounded_scale = params
+ if self.clamp_values:
+ unbounded_scale = clamp_preserve_gradients(
+ unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip
+ )
+ _, log_scale = self._scale_fn(unbounded_scale)
+
return _sum_rightmost(log_scale, self.domain.event_dim)

def param_shapes(self, shape: torch.Size) -> Tuple[torch.Size, torch.Size]:
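
As a standalone sanity check (not part of the commit) of the closed forms hard-coded in `_scale_fn` and `_inv_scale_fn`, using plain PyTorch:

```python
import torch
import torch.nn.functional as F

u = torch.randn(1000)

# "exp": log_scale is the pre-activation itself, inverse scale is exp(-u)
assert torch.allclose(torch.exp(u).log(), u, atol=1e-5)
assert torch.allclose(torch.exp(-u), torch.exp(u).reciprocal(), rtol=1e-5)

# "sigmoid": log_scale = logsigmoid(u), inverse scale = 1 / sigmoid(u) = 1 + exp(-u)
assert torch.allclose(F.logsigmoid(u), torch.sigmoid(u).log(), atol=1e-5)
assert torch.allclose(1.0 + torch.exp(-u), torch.sigmoid(u).reciprocal(), rtol=1e-5)

# "softplus": no special closed form; inverse scale is simply 1 / softplus(u)
assert torch.allclose(F.softplus(u).reciprocal(), 1.0 / F.softplus(u))
```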
