From 79fa542c54a7648f824aa24bbf2d44dc43bf80fa Mon Sep 17 00:00:00 2001
From: John Cant
Date: Fri, 21 Jun 2024 14:35:12 +0100
Subject: [PATCH] Port TF bijector to ensure posdef LKJCorr samples

---
 pymc/distributions/multivariate.py |   6 +-
 pymc/distributions/transforms.py   | 107 +++++++++++++++++++++++++++++
 2 files changed, 112 insertions(+), 1 deletion(-)

diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 359b0743dd8..5b1934d3e19 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -1579,7 +1579,11 @@ def logp(value, n, eta):
 
 @_default_transform.register(_LKJCorr)
 def lkjcorr_default_transform(op, rv):
-    return MultivariateIntervalTransform(-1.0, 1.0)
+    # The transform builds its index arrays eagerly, so it needs a concrete
+    # matrix size; ``eval()`` only works when ``n`` has no symbolic inputs.
+    _, _, _, n, *_ = rv.owner.inputs
+    n = n.eval()
+    return transforms.CholeskyCorr(n)
 
 
 class LKJCorr:
diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py
index d8998889cfd..74a4bfd9616 100644
--- a/pymc/distributions/transforms.py
+++ b/pymc/distributions/transforms.py
@@ -17,6 +17,7 @@
 
 import numpy as np
 import pytensor.tensor as pt
+import pytensor
 
 
 # ignore mypy error because it somehow considers that
@@ -45,6 +46,7 @@
     "log",
     "sum_to_1",
     "circular",
+    "CholeskyCorr",
     "CholeskyCovPacked",
     "Chain",
     "ZeroSumTransform",
@@ -138,6 +140,111 @@ def log_jac_det(self, value, *inputs):
         return pt.sum(y, axis=-1)
 
 
+class CholeskyCorr(Transform):
+    """
+    Transforms the off-diagonal elements of a correlation matrix to
+    unconstrained real numbers.
+
+    Note: This is not particular to the LKJ distribution - it is only a
+    transform to help generate cholesky decompositions for random valid
+    correlation matrices.
+
+    Ported from here: https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31
+
+    The backward side of this transformation is the off-diagonal upper
+    triangular elements of a correlation matrix, specified in row major order.
+    """
+
+    name = "cholesky-corr"
+
+    def __init__(self, n):
+        """
+        Parameters
+        ----------
+        n: int
+            Size of correlation matrix
+        """
+        # Coerce to a plain int so NumPy scalars and 0-d arrays (such as the
+        # result of ``TensorVariable.eval()``) are accepted.
+        self.n = int(n)
+        # Number of off-diagonal elements, computed with integer arithmetic.
+        self.m = self.n * (self.n - 1) // 2
+        self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices()
+        self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()
+
+    def _generate_tril_indices(self):
+        """Return shared row/column indices of the strict lower triangle."""
+        row_indices, col_indices = np.tril_indices(self.n, -1)
+        return pytensor.shared(row_indices), pytensor.shared(col_indices)
+
+    def _generate_triu_indices(self):
+        """Return shared row/column indices of the strict upper triangle."""
+        row_indices, col_indices = np.triu_indices(self.n, 1)
+        return pytensor.shared(row_indices), pytensor.shared(col_indices)
+
+    def _jacobian(self, value, *inputs):
+        """Return the jacobian of ``backward`` with respect to ``value``."""
+        return pt.jacobian(self.backward(value), wrt=value)
+
+    def log_jac_det(self, value, *inputs):
+        """
+        Compute log of the absolute determinant of the jacobian.
+
+        There are no clever tricks here - we literally compute the jacobian
+        then compute its determinant then take log. The absolute value is
+        taken first because the log of a negative determinant would be NaN.
+        """
+        jac = self._jacobian(value, *inputs)
+        return pt.log(pt.abs(pt.linalg.det(jac)))
+
+    def forward(self, value, *inputs):
+        """
+        Convert the off-diagonal upper triangular elements of a correlation
+        matrix (row major order) to unconstrained real numbers.
+        """
+        # The correlation matrix is specified via its upper triangular elements
+        corr = pt.set_subtensor(
+            pt.zeros((self.n, self.n))[self.triu_r_idxs, self.triu_c_idxs], value
+        )
+        corr = corr + corr.T + pt.eye(self.n)
+
+        chol = pt.linalg.cholesky(corr)
+
+        # The diagonal of a Cholesky factor of a positive definite matrix is
+        # positive, so abs is a harmless safeguard.
+        row_norms = 1 / pt.abs(pt.diag(chol))
+
+        # Multiply by the row norms to undo the normalization
+        unconstrained = chol * row_norms[:, pt.newaxis]
+
+        return unconstrained[self.tril_r_idxs, self.tril_c_idxs]
+
+    def backward(self, value, *inputs, foo=False):
+        """
+        Convert unconstrained real numbers to the off-diagonal upper triangular
+        elements of a correlation matrix.
+
+        ``foo`` is unused; it is kept only so the signature stays unchanged.
+        """
+        # The diagonals of this matrix are 1, but these ones are just used for
+        # computing a denominator. The diagonals of the cholesky factor are not
+        # returned, but they are not ones.
+        chol_pre_norm = pt.set_subtensor(
+            pt.eye(self.n).astype("floatX")[self.tril_r_idxs, self.tril_c_idxs], value
+        )
+
+        # The row norm is written out by hand because the derivative of
+        # pt.linalg.norm ended up complex, which caused errors.
+        row_norm = pt.pow(pt.abs(pt.pow(chol_pre_norm, 2).sum(1)), 0.5)
+        chol = chol_pre_norm / row_norm[:, pt.newaxis]
+
+        # Undo the cholesky decomposition
+        corr = pt.matmul(chol, chol.T)
+
+        # We want the upper triangular indices here.
+        return corr[self.triu_r_idxs, self.triu_c_idxs]
+
+
 class CholeskyCovPacked(Transform):
     """
     Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the