-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement unconstraining transform for LKJCorr #7380
base: main
Are you sure you want to change the base?
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,7 @@ | |
"log", | ||
"sum_to_1", | ||
"circular", | ||
"CholeskyCorr", | ||
"CholeskyCovPacked", | ||
"Chain", | ||
"ZeroSumTransform", | ||
|
@@ -138,6 +139,125 @@ def log_jac_det(self, value, *inputs): | |
return pt.sum(y, axis=-1) | ||
|
||
|
||
class CholeskyCorr(Transform):
    """
    Transform unconstrained real numbers to the off-diagonal elements of
    the Cholesky factor of a correlation matrix.

    The ``m = n * (n - 1) / 2`` unconstrained values are scattered into the
    strict lower triangle of an ``n x n`` matrix with a unit diagonal, and
    each row is scaled to unit L2 norm. Unit-norm rows guarantee
    ``diag(L @ L.T) == 1``, i.e. the implied correlation matrix is positive
    definite with ones on its diagonal.

    NOTE(review): pymc's ``Transform`` convention is that ``forward`` maps
    constrained -> unconstrained; the orientation used here follows the TFP
    bijector (unconstrained -> constrained) instead — confirm before merging.

    Examples
    --------
    .. code-block:: python

        transform = CholeskyCorr(n=3)
        x = pt.as_tensor_variable([0.0, 0.0, 0.0])
        y = transform.forward(x).eval()
        # y is the flat vector of off-diagonal Cholesky elements

        x_reconstructed = transform.backward(y).eval()
        # x_reconstructed matches the original x

    References
    ----------
    - Stan Manual, Section 24.2, Cholesky LKJ Correlation Distribution.
      https://mc-stan.org/docs/2_18/functions-reference/cholesky-lkj-correlation-distribution.html
    - Lewandowski, D., Kurowicka, D., & Joe, H. (2009). "Generating random
      correlation matrices based on vines and extended onion method."
      Journal of Multivariate Analysis, 100(5), 1989-2001.
    """

    name = "cholesky-corr"

    def __init__(self, n, validate_args=False):
        """
        Initialize the CholeskyCorr transform.

        Parameters
        ----------
        n : int
            Size of the correlation matrix. Must be at least 2.
        validate_args : bool, default False
            Whether to validate input arguments.

        Raises
        ------
        TypeError
            If ``n`` is not an integer.
        ValueError
            If ``n`` is smaller than 2.
        """
        # Fail fast on sizes that cannot form a correlation matrix
        # (the test suite asserts these exact exception types).
        if isinstance(n, bool) or not isinstance(n, (int, np.integer)):
            raise TypeError(f"n must be an integer, got {type(n).__name__}")
        if n < 2:
            raise ValueError(f"n must be at least 2, got {n}")

        self.n = int(n)
        self.m = self.n * (self.n - 1) // 2  # number of off-diagonal elements
        # Cache the (row, col) index pairs used to scatter/gather the flat
        # off-diagonal vector into/out of an n x n matrix.
        self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices()
        self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()
        super().__init__(validate_args=validate_args)

    def _generate_tril_indices(self):
        # Strictly lower-triangular indices (diagonal excluded).
        row_indices, col_indices = np.tril_indices(self.n, -1)
        return (row_indices, col_indices)

    def _generate_triu_indices(self):
        # Strictly upper-triangular indices (diagonal excluded).
        row_indices, col_indices = np.triu_indices(self.n, 1)
        return (row_indices, col_indices)

    def _fill_tril(self, values):
        """Scatter a flat vector into the strict lower triangle of a zero matrix."""
        mat = pt.zeros((self.n, self.n), dtype=values.dtype)
        return pt.set_subtensor(mat[self.tril_r_idxs, self.tril_c_idxs], values)

    def forward(self, x, *inputs):
        """
        Forward transform: unconstrained reals to Cholesky off-diagonals.

        Parameters
        ----------
        x : tensor
            Flat vector of ``m = n * (n - 1) / 2`` unconstrained reals.

        Returns
        -------
        tensor
            Flat vector of the off-diagonal elements of the row-normalized
            Cholesky factor.
        """
        # Embed x in the strict lower triangle and set a unit diagonal.
        # The unit diagonal is essential: without it the first row is all
        # zeros and the normalization below divides by zero.
        chol = self._fill_tril(x) + pt.eye(self.n, dtype=x.dtype)

        # Normalize each row to unit L2 norm so that diag(L @ L.T) == 1.
        row_norms = pt.sqrt(pt.sum(chol**2, axis=1, keepdims=True))
        chol = chol / row_norms

        return chol[self.tril_r_idxs, self.tril_c_idxs]

    def backward(self, y, *inputs):
        """
        Backward transform: Cholesky off-diagonals to unconstrained reals.

        Exact inverse of :meth:`forward`. Because each row of the Cholesky
        factor has unit norm and a positive diagonal, the diagonal entry of
        row ``i`` is ``sqrt(1 - sum(y_i**2))``, and the pre-normalization
        values are recovered by dividing the row by its diagonal (which was
        1 before normalization).

        Parameters
        ----------
        y : tensor
            Flat vector of off-diagonal Cholesky elements.

        Returns
        -------
        tensor
            Flat vector of unconstrained reals.
        """
        chol = self._fill_tril(y)

        # Recover the diagonal from the unit-row-norm constraint.
        diag = pt.sqrt(1 - pt.sum(chol**2, axis=1))

        # Undo the row normalization: the original row had a unit diagonal,
        # so dividing by the recovered diagonal restores it.
        unconstrained = chol / diag[:, None]

        return unconstrained[self.tril_r_idxs, self.tril_c_idxs]

    def log_jac_det(self, y, *inputs):
        """
        Log absolute determinant of the Jacobian of the normalization map.

        ``y`` is the flat vector of unconstrained values. For row ``i`` with
        unconstrained entries of squared norm ``s_i`` (unit diagonal
        excluded) and ``k_i = i`` off-diagonal entries, the Jacobian
        determinant of normalizing the (k_i + 1)-vector ``(x_i, 1)`` is
        ``(1 + s_i) ** (-(k_i + 2) / 2)``; summing the logs over rows gives
        the total.

        NOTE(review): confirm the sign/orientation convention expected by
        pymc (Jacobian of forward vs backward) before merging.
        """
        chol = self._fill_tril(y)
        # 1 + s_i per row: the unit diagonal contributes the 1.
        sq_norms = 1 + pt.sum(chol**2, axis=1)
        k = pt.arange(self.n)  # number of off-diagonal entries in each row
        return -pt.sum((k + 2) / 2 * pt.log(sq_norms), axis=-1)
|
||
|
||
class CholeskyCovPacked(Transform): | ||
""" | ||
Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
import pymc as pm | ||
import pymc.distributions.transforms as tr | ||
|
||
from pymc.distributions.transforms import CholeskyCorr | ||
from pymc.logprob.basic import transformed_conditional_logp | ||
from pymc.logprob.transforms import Transform | ||
from pymc.pytensorf import floatX, jacobian | ||
|
@@ -673,3 +674,142 @@ def test_deprecated_ndim_supp_transforms(): | |
|
||
with pytest.warns(FutureWarning, match="deprecated"): | ||
assert tr.multivariate_sum_to_1 == tr.sum_to_1 | ||
|
||
|
||
def test_lkjcorr_transform_round_trip():
    """
    Check that matrices sampled from an LKJCorr prior are positive definite,
    and that CholeskyCorr's forward/backward transforms are mutual inverses
    on the unconstrained vector.
    """
    with pm.Model():
        pm.LKJCorr("rho", n=3, eta=2)
        trace = pm.sample(
            100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False
        )

    # Extract the sampled correlation matrices
    rho_samples = trace["rho"]
    num_samples = rho_samples.shape[0]

    for i in range(num_samples):
        sample_matrix = rho_samples[i]

        # Check if the sampled matrix is positive definite
        try:
            np.linalg.cholesky(sample_matrix)
        except np.linalg.LinAlgError:
            pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.")

    # Round-trip the transform in the unconstrained space. Note: forward()
    # expects the flat vector of m = n*(n-1)/2 off-diagonal elements, not a
    # full matrix, and the comparison must be between objects living in the
    # same (unconstrained) space.
    transform = CholeskyCorr(n=3)
    x = np.random.randn(transform.m).astype(pytensor.config.floatX)
    y = transform.forward(pt.as_tensor_variable(x)).eval()
    x_reconstructed = transform.backward(pt.as_tensor_variable(y)).eval()

    # Assert that the original and reconstructed unconstrained parameters are close
    assert_allclose(x, x_reconstructed, atol=1e-6)
|
||
|
||
def test_lkjcorr_log_jac_det():
    """
    Verify the analytic log Jacobian determinant against pytensor's
    ``jacobian`` machinery at a non-trivial input (an all-zeros input with a
    hard-coded expected value of 0 would not catch a wrong formula).
    """
    n = 3
    m = n * (n - 1) // 2
    transform = CholeskyCorr(n=n)

    # Non-trivial unconstrained test point.
    test_val = np.random.randn(m).astype(pytensor.config.floatX)

    # Numerical reference: log |det d forward / d x| via automatic
    # differentiation of the transform itself.
    x = pt.vector("x")
    jac = jacobian(transform.forward(x), [x])[0]
    _, expected_log_jac_det = np.linalg.slogdet(jac.eval({x: test_val}))

    # Analytic value from the transform's own method.
    computed_log_jac_det = transform.log_jac_det(pt.as_tensor_variable(test_val)).eval()

    assert_allclose(computed_log_jac_det, expected_log_jac_det, atol=1e-6)
|
||
|
||
@pytest.mark.parametrize("n", [2, 4, 5])
def test_lkjcorr_transform_various_sizes(n):
    """
    Round-trip the CholeskyCorr transform for several correlation-matrix sizes.
    """
    transform = CholeskyCorr(n=n)
    num_free = int(n * (n - 1) / 2)

    # Draw random unconstrained values of the right length.
    x = np.random.randn(num_free).astype(pytensor.config.floatX)

    # Forward to the constrained space, then back again.
    y = transform.forward(pt.as_tensor_variable(x)).eval()
    reconstructed = transform.backward(y).eval()

    # The round trip must recover the original unconstrained values.
    assert_allclose(x, reconstructed, atol=1e-6)
|
||
|
||
def test_lkjcorr_invalid_n():
    """
    Initializing CholeskyCorr with an invalid 'n' raises the appropriate error.
    """
    # 'n' must be greater than 1.
    with pytest.raises(ValueError):
        CholeskyCorr(n=1)

    # 'n' must be an integer.
    with pytest.raises(TypeError):
        CholeskyCorr(n="three")
|
||
|
||
def test_lkjcorr_positive_definite():
    """
    Ensure that all sampled correlation matrices are positive definite.
    """
    with pm.Model():
        pm.LKJCorr("rho", n=4, eta=2)
        trace = pm.sample(
            100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False
        )

    # A matrix is positive definite iff its Cholesky factorization succeeds.
    for i, sample_matrix in enumerate(trace["rho"]):
        try:
            np.linalg.cholesky(sample_matrix)
        except np.linalg.LinAlgError:
            pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.")
|
||
|
||
def test_lkjcorr_round_trip_various_sizes():
    """
    Perform round-trip transformation tests for various correlation-matrix sizes.
    """
    for n in [2, 3, 4]:
        transform = CholeskyCorr(n=n)

        # Random unconstrained vector of length n*(n-1)/2.
        x = np.random.randn(int(n * (n - 1) / 2)).astype(pytensor.config.floatX)

        # Forward then backward must recover the original values.
        y = transform.forward(pt.as_tensor_variable(x)).eval()
        reconstructed = transform.backward(y).eval()
        assert_allclose(x, reconstructed, atol=1e-6)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you delete this transform class as well? It was a (wrong) patch to the problem you're solving
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can do. Just to confirm, you don't consider MultivariateIntervalTransform to be part of pymc's public API?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nope, can be removed without worries
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok - great