-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement unconstraining transform for LKJCorr #7380
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,7 @@ | |
"log", | ||
"sum_to_1", | ||
"circular", | ||
"CholeskyCorr", | ||
"CholeskyCovPacked", | ||
"Chain", | ||
"ZeroSumTransform", | ||
|
@@ -138,6 +139,175 @@ def log_jac_det(self, value, *inputs): | |
return pt.sum(y, axis=-1) | ||
|
||
|
||
class CholeskyCorr(Transform): | ||
""" | ||
Transforms unconstrained real numbers to the off-diagonal elements of | ||
a Cholesky decomposition of a correlation matrix. | ||
|
||
This ensures that the resulting correlation matrix is positive definite. | ||
|
||
#### Mathematical Details | ||
|
||
This bijector provides a change of variables from unconstrained reals to a | ||
parameterization of the CholeskyLKJ distribution. The CholeskyLKJ distribution | ||
[1] is a distribution on the set of Cholesky factors of positive definite | ||
correlation matrices. The CholeskyLKJ probability density function is | ||
obtained from the LKJ density on n x n matrices as follows: | ||
|
||
1 = int p(A | eta) dA | ||
= int Z(eta) * det(A) ** (eta - 1) dA | ||
= int Z(eta) L_ii ** {(n - i - 1) + 2 * (eta - 1)} ^dL_ij (0 <= i < j < n) | ||
|
||
where Z(eta) is the normalizer; the matrix L is the Cholesky factor of the | ||
correlation matrix A; and ^dL_ij denotes the wedge product (or differential) | ||
of the strictly lower triangular entries of L. The entries L_ij are | ||
constrained such that each entry lies in [-1, 1] and the norm of each row is | ||
1. The norm includes the diagonal; which is not included in the wedge product. | ||
To preserve uniqueness, we further specify that the diagonal entries are | ||
positive. | ||
|
||
The image of unconstrained reals under the `CorrelationCholesky` bijector is | ||
the set of correlation matrices which are positive definite. A [correlation | ||
matrix](https://en.wikipedia.org/wiki/Correlation_and_dependence#Correlation_matrices) | ||
can be characterized as a symmetric positive semidefinite matrix with 1s on | ||
the main diagonal. | ||
|
||
For a lower triangular matrix `L` to be a valid Cholesky-factor of a positive | ||
definite correlation matrix, it is necessary and sufficient that each row of | ||
`L` have unit Euclidean norm [1]. To see this, observe that if `L_i` is the | ||
`i`th row of the Cholesky factor corresponding to the correlation matrix `R`, | ||
then the `i`th diagonal entry of `R` satisfies: | ||
|
||
1 = R_i,i = L_i . L_i = ||L_i||^2 | ||
|
||
where '.' is the dot product of vectors and `||...||` denotes the Euclidean | ||
norm. | ||
|
||
Furthermore, observe that `R_i,j` lies in the interval `[-1, 1]`. By the | ||
Cauchy-Schwarz inequality: | ||
|
||
|R_i,j| = |L_i . L_j| <= ||L_i|| ||L_j|| = 1 | ||
|
||
This is a consequence of the fact that `R` is symmetric positive definite with | ||
1s on the main diagonal. | ||
|
||
We choose the mapping from x in `R^{m}` to `R^{n^2}` where `m` is the | ||
`(n - 1)`th triangular number; i.e. `m = 1 + 2 + ... + (n - 1)`. | ||
|
||
L_ij = x_i,j / s_i (for i < j) | ||
L_ii = 1 / s_i | ||
|
||
where s_i = sqrt(1 + x_i,0^2 + x_i,1^2 + ... + x_(i,i-1)^2). We can check that | ||
the required constraints on the image are satisfied. | ||
|
||
#### Examples | ||
|
||
```python | ||
transform = CholeskyCorr(n=3) | ||
x = pt.as_tensor_variable([0.0, 0.0, 0.0]) | ||
y = transform.forward(x).eval() | ||
# y will be the off-diagonal elements of the Cholesky factor | ||
|
||
x_reconstructed = transform.backward(y).eval() | ||
# x_reconstructed should closely match the original x | ||
``` | ||
|
||
#### References | ||
- [Stan Manual. Section 24.2. Cholesky LKJ Correlation Distribution.](https://mc-stan.org/docs/2_18/functions-reference/cholesky-lkj-correlation-distribution.html) | ||
- Lewandowski, D., Kurowicka, D., & Joe, H. (2009). "Generating random correlation matrices based on vines and extended onion method." *Journal of Multivariate Analysis, 100*(5), 1989-2001. | ||
""" | ||
|
||
name = "cholesky-corr" | ||
|
||
def __init__(self, n, validate_args=False): | ||
""" | ||
Initialize the CholeskyCorr transform. | ||
|
||
Parameters | ||
---------- | ||
n : int | ||
Size of the correlation matrix. | ||
validate_args : bool, default False | ||
Whether to validate input arguments. | ||
""" | ||
self.n = n | ||
self.m = int(n * (n - 1) / 2) # Number of off-diagonal elements | ||
self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices() | ||
self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See below, not sure we need to cache these. |
||
super().__init__(validate_args=validate_args) | ||
|
||
def _generate_tril_indices(self): | ||
row_indices, col_indices = np.tril_indices(self.n, -1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if it matters but there is a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's good practice to use the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I originally tried to use the |
||
return (row_indices, col_indices) | ||
|
||
def _generate_triu_indices(self): | ||
row_indices, col_indices = np.triu_indices(self.n, 1) | ||
return (row_indices, col_indices) | ||
|
||
def forward(self, x, *inputs): | ||
""" | ||
Forward transform: Unconstrained real numbers to Cholesky factors. | ||
|
||
Parameters | ||
---------- | ||
x : tensor | ||
Unconstrained real numbers. | ||
|
||
Returns | ||
------- | ||
tensor | ||
Transformed Cholesky factors. | ||
""" | ||
# Initialize a zero matrix | ||
chol = pt.zeros((self.n, self.n), dtype=x.dtype) | ||
|
||
# Assign the unconstrained values to the lower triangular part | ||
chol = pt.set_subtensor(chol[self.tril_r_idxs, self.tril_c_idxs], x) | ||
|
||
# Normalize each row to have unit L2 norm | ||
row_norms = pt.sqrt(pt.sum(chol**2, axis=1, keepdims=True)) | ||
chol = chol / row_norms | ||
|
||
return chol[self.tril_r_idxs, self.tril_c_idxs] | ||
|
||
def backward(self, y, *inputs): | ||
""" | ||
Backward transform: Cholesky factors to unconstrained real numbers. | ||
|
||
Parameters | ||
---------- | ||
y : tensor | ||
Cholesky factors. | ||
|
||
Returns | ||
------- | ||
tensor | ||
Unconstrained real numbers. | ||
""" | ||
# Reconstruct the full Cholesky matrix | ||
chol = pt.zeros((self.n, self.n), dtype=y.dtype) | ||
chol = pt.set_subtensor(chol[self.triu_r_idxs, self.triu_c_idxs], y) | ||
chol = chol + pt.transpose(chol) + pt.eye(self.n, dtype=y.dtype) | ||
|
||
# Perform Cholesky decomposition | ||
chol = pt.linalg.cholesky(chol) | ||
|
||
# Extract the unconstrained parameters by normalizing | ||
row_norms = pt.sqrt(pt.sum(chol**2, axis=1)) | ||
unconstrained = chol / row_norms[:, None] | ||
|
||
return unconstrained[self.tril_r_idxs, self.tril_c_idxs] | ||
|
||
def log_jac_det(self, y, *inputs): | ||
""" | ||
Compute the log determinant of the Jacobian. | ||
|
||
The Jacobian determinant for normalization is the product of row norms. | ||
""" | ||
row_norms = pt.sqrt(pt.sum(y**2, axis=1)) | ||
return -pt.sum(pt.log(row_norms), axis=-1) | ||
|
||
|
||
class CholeskyCovPacked(Transform): | ||
""" | ||
Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you delete this transform class as well? It was a (wrong) patch to the problem you're solving
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can do. Just to confirm, you don't consider MultivariateIntervalTransform to be part of pymc's public API?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope, can be removed without worries
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok - great