Skip to content

Commit

Permalink
Add ZeroSumNormal distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
Siddharth Baleja committed Jul 28, 2024
1 parent 8eaa9be commit 657e208
Showing 1 changed file with 83 additions and 25 deletions.
108 changes: 83 additions & 25 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,20 @@

def __getattr__(name):
if name in ("univariate_ordered", "multivariate_ordered"):
warnings.warn(f"{name} has been deprecated, use ordered instead.", FutureWarning)
warnings.warn(
f"{name} has been deprecated, use ordered instead.",
FutureWarning)
return ordered

if name in ("univariate_sum_to_1", "multivariate_sum_to_1"):
warnings.warn(f"{name} has been deprecated, use sum_to_1 instead.", FutureWarning)
warnings.warn(
f"{name} has been deprecated, use sum_to_1 instead.",
FutureWarning)
return sum_to_1

if name == "RVTransform":
warnings.warn("RVTransform has been renamed to Transform", FutureWarning)
warnings.warn(
"RVTransform has been renamed to Transform", FutureWarning)
return Transform

raise AttributeError(f"module {__name__} has no attribute {name}")
Expand Down Expand Up @@ -96,7 +101,9 @@ class Ordered(Transform):

def __init__(self, ndim_supp=None):
if ndim_supp is not None:
warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)
warnings.warn(
"ndim_supp argument is deprecated and has no effect",
FutureWarning)

def backward(self, value, *inputs):
x = pt.zeros(value.shape)
Expand All @@ -107,7 +114,8 @@ def backward(self, value, *inputs):
def forward(self, value, *inputs):
y = pt.zeros(value.shape)
y = pt.set_subtensor(y[..., 0], value[..., 0])
y = pt.set_subtensor(y[..., 1:], pt.log(value[..., 1:] - value[..., :-1]))
log_value = pt.log(value[..., 1:] - value[..., :-1])
y = pt.set_subtensor(y[..., 1:], log_value)
return y

def log_jac_det(self, value, *inputs):
Expand All @@ -116,15 +124,18 @@ def log_jac_det(self, value, *inputs):

class SumTo1(Transform):
"""
Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of values in [0,1]
This Transformation operates on the last dimension of the input tensor.
Transforms K - 1 dimensional simplex space (k values in [0,1] and that
sum to 1) to a K - 1 vector of values in [0,1]. This Transformation
operates on the last dimension of the input tensor.
"""

name = "sumto1"

def __init__(self, ndim_supp=None):
if ndim_supp is not None:
warnings.warn("ndim_supp argument is deprecated and has no effect", FutureWarning)
warnings.warn(
"ndim_supp argument is deprecated and has no effect",
FutureWarning)

def backward(self, value, *inputs):
remaining = 1 - pt.sum(value[..., :], axis=-1, keepdims=True)
Expand All @@ -140,7 +151,8 @@ def log_jac_det(self, value, *inputs):

class CholeskyCovPacked(Transform):
"""
Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
Transforms the diagonal elements of
the LKJCholeskyCov distribution to be on the
log scale
"""

Expand All @@ -157,10 +169,14 @@ def __init__(self, n):
self.diag_idxs = pt.arange(1, n + 1).cumsum() - 1

def backward(self, value, *inputs):
return pt.set_subtensor(value[..., self.diag_idxs], pt.exp(value[..., self.diag_idxs]))
diag_values = value[..., self.diag_idxs]
exp_values = pt.exp(diag_values)
return pt.set_subtensor(value[..., self.diag_idxs], exp_values)

def forward(self, value, *inputs):
return pt.set_subtensor(value[..., self.diag_idxs], pt.log(value[..., self.diag_idxs]))
diag_values = value[..., self.diag_idxs]
log_values = pt.log(diag_values)
return pt.set_subtensor(value[..., self.diag_idxs], log_values)

def log_jac_det(self, value, *inputs):
return pt.sum(value[..., self.diag_idxs], axis=-1)
Expand All @@ -180,8 +196,9 @@ def log_jac_det(self, value, *inputs):


class Interval(IntervalTransform):
"""Wrapper around :class:`pymc.logprob.transforms.IntervalTransform` for use in the
``transform`` argument of a random variable.
"""
Wrapper around :class:`pymc.logprob.transforms.IntervalTransform` for use
in the ``transform`` argument of a random variable.
Parameters
----------
Expand All @@ -192,15 +209,15 @@ class Interval(IntervalTransform):
Upper bound of the interval transform. Must be a constant finite value.
By default (``upper=None``), the interval is not bounded above.
bounds_fn : callable, optional
Alternative to lower and upper. Must return a tuple of lower and upper bounds
as a symbolic function of the respective distribution inputs. If one of lower or
upper is ``None``, the interval is unbounded on that edge.
.. warning:: Expressions returned by `bounds_fn` should depend only on the
distribution inputs or other constants. Expressions that depend on nonlocal
variables, such as other distributions defined in the model context will
likely break sampling.
Alternative to lower and upper. Must return a tuple of lower and upper
bounds as a symbolic function of the respective distribution inputs. If
one of lower or upper is ``None``,the interval is unbounded on
that edge.
.. warning:: Expressions returned by `bounds_fn` should depend only on
the distribution inputs or other constants. Expressions that depend
on nonlocal variables, such as other distributions defined in the
model context will likely break sampling.
Examples
--------
Expand All @@ -220,10 +237,14 @@ def get_bounds(rng, size, mu, sigma):
return 0, None
with pm.Model():
interval = pm.distributions.transforms.Interval(bounds_fn=get_bounds)
interval = pm.distributions.transforms.Interval(
bounds_fn=get_bounds
)
x = pm.Normal("x", transform=interval)
Create a lower-bounded interval transform that depends on a distribution parameter
Create a lower-bounded interval transform that depends on a
distribution parameter
.. code-block:: python
Expand Down Expand Up @@ -267,10 +288,47 @@ class ZeroSumTransform(Transform):
"""
Constrains any random samples to sum to zero along the user-provided ``zerosum_axes``.
This transform is useful when modeling distributions where the sum of certain dimensions
must be zero, such as in some types of constrained latent variable models or in certain
types of signal processing applications.
Parameters
----------
zerosum_axes : list of ints
Must be a list of integers (positive or negative).
zerosum_axes : list of int
List of integers specifying the axes along which the random samples should sum to zero.
Positive integers indicate dimensions in the standard order, while negative integers
can be used to reference dimensions from the end of the shape.
Examples
--------
Suppose you want to ensure that the last dimension of a tensor sums to zero. You can use
`ZeroSumTransform` as follows:
.. code-block:: python
import pymc as pm
with pm.Model() as model:
# Create a 2D variable with the last axis constrained to sum to zero
x = pm.Normal("x", shape=(10, 5), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[-1]))
Methods
-------
forward(value, *rv_inputs)
Transforms the input tensor to ensure that the specified axes sum to zero.
backward(value, *rv_inputs)
Computes the inverse transform to convert back to the original space where the sum was zero.
log_jac_det(value, *rv_inputs)
Returns the log Jacobian determinant of the transform. For this transform, it is zero.
Notes
-----
The `extend_axis` and `extend_axis_rev` methods are used internally to handle the transformation:
- `extend_axis`: Extends the axis by adding an additional element to ensure zero-sum constraint.
- `extend_axis_rev`: Reverses the extension operation applied by `extend_axis`.
"""

name = "zerosum"
Expand Down

0 comments on commit 657e208

Please sign in to comment.