Skip to content

Commit

Permalink
Port TF bijector to ensure posdef LKJCorr samples
Browse files Browse the repository at this point in the history
  • Loading branch information
johncant committed Jun 21, 2024
1 parent f44071b commit 79fa542
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,7 +1579,9 @@ def logp(value, n, eta):

@_default_transform.register(_LKJCorr)
def lkjcorr_default_transform(op, rv):
return MultivariateIntervalTransform(-1.0, 1.0)
_, _, _, n, *_ = rv.owner.inputs
n = n.eval()
return transforms.CholeskyCorr(n)


class LKJCorr:
Expand Down
111 changes: 111 additions & 0 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import pytensor.tensor as pt
import pytensor


# ignore mypy error because it somehow considers that
Expand Down Expand Up @@ -45,6 +46,7 @@
"log",
"sum_to_1",
"circular",
"CholeskyCorr",
"CholeskyCovPacked",
"Chain",
"ZeroSumTransform",
Expand Down Expand Up @@ -138,6 +140,115 @@ def log_jac_det(self, value, *inputs):
return pt.sum(y, axis=-1)


class CholeskyCorr(Transform):
"""
Transforms the off-diagonal elements of a correlation matrix to
unconstrained real numbers.
Note: This is not particular to the LKJ distribution - it is only a
transform to help generate cholesky decompositions for random valid
correlation matrices.
Ported from here: https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31
The backward side of this transformation is the off-diagonal upper
triangular elements of a correlation matrix, specified in row major order.
"""

name = "cholesky-corr"

def __init__(self, n):
"""
Parameters
----------
n: int
Size of correlation matrix
"""
self.n = n
self.m = int(n*(n-1)/2) # number of off-diagonal elements
self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices()
self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()

def _generate_tril_indices(self):
row_indices, col_indices = np.tril_indices(self.n, -1)
return (
pytensor.shared(row_indices),
pytensor.shared(col_indices)
)

def _generate_triu_indices(self):
row_indices, col_indices = np.triu_indices(self.n, 1)
return (
pytensor.shared(row_indices),
pytensor.shared(col_indices)
)

def _jacobian(self, value, *inputs):
return pt.jacobian(
self.backward(value),
wrt=value
)

def log_jac_det(self, value, *inputs):
"""
Compute log of the determinant of the jacobian.
There are no clever tricks here - we literally compute the jacobian
then compute its determinant then take log.
"""
jac = self._jacobian(value)
return pt.log(pt.linalg.det(jac))

def forward(self, value, *inputs):
"""
Convert the off-diagonal elements of a cholesky decomposition of a
correlation matrix to unconstrained real numbers.
"""
# The correlation matrix is specified via its upper triangular elements
corr = pt.set_subtensor(
pt.zeros((self.n, self.n))[self.triu_r_idxs, self.triu_c_idxs],
value
)
corr = corr + corr.T + pt.eye(self.n)

chol = pt.linalg.cholesky(corr)

# Are the diagonals always guaranteed to be positive?
# I don't know, so we'll use abs
row_norms = 1/pt.abs(pt.diag(chol))

# Multiply by the row norms to undo the normalization
unconstrained = chol*row_norms[:, pt.newaxis]

return unconstrained[self.tril_r_idxs, self.tril_c_idxs]

def backward(self, value, *inputs, foo=False):
"""
Convert unconstrained real numbers to the off-diagonal elements of the
cholesky decomposition of a correlation matrix.
"""
# The diagonals of this matrix are 1, but these ones are just used for
# computing a denominator. The diagonals of the cholesky factor are not
# returned, but they are not ones.
chol_pre_norm = pt.set_subtensor(
pt.eye(self.n).astype("floatX")[self.tril_r_idxs, self.tril_c_idxs],
value
)

# derivative of pt.linalg.norm ended up complex, which caused errors
# row_norm = pt.abs(pt.linalg.norm(chol_pre_norm, axis=1))[:, pt.newaxis].astype("floatX")

row_norm = pt.pow(pt.abs(pt.pow(chol_pre_norm, 2).sum(1)), 0.5)
chol = chol_pre_norm / row_norm[:, pt.newaxis]

# Undo the cholesky decomposition
corr = pt.matmul(chol, chol.T)

# We want the upper triangular indices here.
return corr[self.triu_r_idxs, self.triu_c_idxs]


class CholeskyCovPacked(Transform):
"""
Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
Expand Down

0 comments on commit 79fa542

Please sign in to comment.