
Feat: Implement Full-Rank VI #720

Draft · gil2rok wants to merge 12 commits into main
Conversation

@gil2rok (Contributor) commented Aug 14, 2024

Implement variational inference (VI) with full-rank Gaussian approximation

While mean-field VI learns a Gaussian approximation $N(\mu, \mathrm{diag}(\sigma))$ with diagonal covariance, full-rank VI learns a Gaussian approximation $N(\mu, \Sigma)$ with full-rank covariance $\Sigma$.

We use the Cholesky decomposition $\Sigma = C C^T$, parameterized by the log standard deviation $\rho$ and the strictly lower-triangular matrix $L$ such that $C = \mathrm{diag}(\exp(\rho)) + L$. This (1) ensures $\Sigma$ remains symmetric and positive-definite during optimization [1] and (2) admits better sampling and log-density computation (with improved time complexity, space complexity, and numerical stability) [2]. Thus the full-rank Gaussian is parameterized by $(\mu, \rho, L)$.
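
A minimal sketch of this parameterization (variable names and values here are illustrative, not the PR's API):

import jax.numpy as jnp

d = 3
rho = jnp.zeros(d)                    # log standard deviations, so exp(rho) > 0
L = jnp.tril(jnp.ones((d, d)), k=-1)  # strict lower triangle, diagonal excluded
C = jnp.diag(jnp.exp(rho)) + L        # Cholesky factor with positive diagonal
Sigma = C @ C.T                       # symmetric positive-definite by construction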

To-Do:

  • Support arbitrary PyTrees instead of just arrays
  • Decide on format for lower triangular matrix $L$ ($n$-length flattened array vs dense 2-D array for $n = d(d+1)/2$)
  • Confirm initialization strategy from standard normal
  • Write tests (current test fails)

Footnotes

  1. Automatic Differentiation Variational Inference, Section 2.4

  2. Lecture notes

@gil2rok (Contributor, Author) commented Aug 14, 2024

@junpenglao I am trying to better grok JAX's PyTrees. I would love specific feedback on how to support PyTrees instead of just JAX arrays in my code. I believe this mostly affects my init(), _sample(), and generate_fullrank_logdensity() functions.

For example, compare my full-rank sampling implementation:

def _sample(rng_key, mu, rho, L, num_samples):
    # Cholesky factor: strict lower triangle of L plus positive diagonal exp(rho)
    cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(rho))
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    return mu + eps @ cholesky.T

with the one in meanfield_vi.py:

def _sample(rng_key, mu, rho, num_samples):
    sigma = jax.tree.map(jnp.exp, rho)
    mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu)
    sigma_flat, _ = jax.flatten_util.ravel_pytree(sigma)
    flatten_sample = (
        jax.random.normal(rng_key, (num_samples,) + mu_flatten.shape) * sigma_flat
        + mu_flatten
    )
    return jax.vmap(unravel_fn)(flatten_sample)

How and why would I change my code? Any resource recommendations would be greatly appreciated.

@junpenglao mentioned this pull request Aug 15, 2024
@junpenglao (Member) left a comment

A PyTree with a covariance matrix is a headache to deal with. I suggest you take a look at how pathfinder deals with it internally: basically all the state parameters are represented as flattened arrays, and you only unflatten at the end. Considering that what you have right now already assumes everything is a flattened array, you just need to add the flatten and unflatten parts.
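
A minimal sketch of that pattern, assuming a generic PyTree `position` (names here are illustrative, not pathfinder's actual code):

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def init_flat(position):
    # Flatten the PyTree once up front; keep unravel_fn to map back later.
    position_flat, unravel_fn = ravel_pytree(position)
    dim = position_flat.shape[0]
    mu = jnp.zeros(dim)  # all variational state lives in flat arrays
    rho = jnp.zeros(dim)
    L = jnp.zeros((dim, dim))
    return (mu, rho, L), unravel_fn

# ...optimize and sample entirely in the flat space...
# samples_flat has shape (num_samples, dim); unflatten only at the end:
# samples = jax.vmap(unravel_fn)(samples_flat)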

Comment on lines 155 to 165
def generate_fullrank_logdensity(mu, rho, L):
    cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(rho))
    log_det = 2 * jnp.sum(rho)
    const = -0.5 * mu.shape[-1] * jnp.log(2 * jnp.pi)

    def fullrank_logdensity(position):
        y = jsp.linalg.solve_triangular(cholesky, position - mu, lower=True)
        mahalanobis_dist = jnp.sum(y**2, axis=-1)
        return const - 0.5 * log_det - 0.5 * mahalanobis_dist

    return fullrank_logdensity
@junpenglao (Member):

Why not use multivariate_normal.logpdf from JAX?

@gil2rok (Contributor, Author) commented Aug 15, 2024

I wanted to avoid computing the inverse and log determinant of the covariance matrix $\Sigma$ by using the Cholesky factor when computing the logpdf.

Does jax.scipy.stats.multivariate_normal.logpdf take Cholesky factors as input? I want to avoid computing the covariance matrix $\Sigma = C C^T$ and then passing it into JAX's multivariate normal, which decomposes it back into the Cholesky factor $C$.

From https://jax.readthedocs.io/en/latest/_autosummary/jax.random.multivariate_normal.html it appears the multivariate normal log density only accepts the covariance as a dense matrix!


See jax-ml/jax#11386. Thoughts on the tradeoff between readability (with JAX's multivariate normal) and speed (custom implementation)?

@gil2rok (Contributor, Author) commented Aug 15, 2024

A couple of notes about the Cholesky factor:

  1. I believe my implementation is incorrect. I compute the $d \times d$ Cholesky factor as $C = \mathrm{diag}(\exp(\rho)) + L$ for log standard deviation $\rho$ and lower-triangular matrix $L$. However, I store $L$ as a full $d \times d$ matrix instead of just the lower triangle:

L = jax.tree.map(lambda x: jnp.zeros((*x.shape, *x.shape)), position)

During optimization, this means we may learn non-zero values for the diagonal and upper triangle of $L$! I will instead store $L$ as a flat array of length $n$, the number of elements in the lower triangle of a $d \times d$ matrix excluding the diagonal, so $n = \sum_{i=1}^{d-1} i = d(d-1)/2$. We now construct the Cholesky factor as:

def unflatten_lower_triangular(tril_flat):
    n = tril_flat.size  # number of elements in the strict lower triangle
    d = (int(jnp.sqrt(1 + 8 * n)) + 1) // 2  # solve n = d(d - 1) / 2 for d

    lower_tri_matrix = jnp.zeros((d, d))
    indices = jnp.tril_indices(d, k=-1)  # strict lower triangle, diagonal excluded
    return lower_tri_matrix.at[indices].set(tril_flat)

C = jnp.diag(jnp.exp(rho)) + unflatten_lower_triangular(L)
  2. Given my full-rank Gaussian parameterization $(\mu, \rho, L)$, how should I flatten/unflatten to work with arbitrary PyTrees? Following @junpenglao's advice, I see Pathfinder's approximate function:
def approximate(...):
    initial_position_flatten, unravel_fn = ravel_pytree(initial_position)
    ...
    unravel_fn_mapped = jax.vmap(unravel_fn)
    pathfinder_result = PathfinderState(
        elbo,
        unravel_fn_mapped(position),
        unravel_fn_mapped(grad_position),
        alpha,
        beta,
        gamma,
    )
    ...

How would I apply this to my _sample() function? I'm struggling to understand why flattening helps us handle arbitrary PyTrees, and how this works for the lower triangle $L$ vs the mean $\mu$ or log std $\rho$. Is this correct, or is more flattening needed?

def _sample(rng_key, mu, rho, L, num_samples):
    cholesky = jnp.diag(jnp.exp(rho)) + unflatten_lower_triangular(L)
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    return mu + eps @ cholesky.T
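
For what it's worth, here is one way the flattening could slot in, assuming mu is an arbitrary PyTree while rho and L already live in the flattened space (a sketch, not the final API):

def _sample(rng_key, mu, rho, L, num_samples):
    # Flatten the PyTree mean once; all linear algebra happens in flat space.
    mu_flat, unravel_fn = jax.flatten_util.ravel_pytree(mu)
    cholesky = jnp.diag(jnp.exp(rho)) + unflatten_lower_triangular(L)
    eps = jax.random.normal(rng_key, (num_samples,) + mu_flat.shape)
    samples_flat = mu_flat + eps @ cholesky.T
    # Unflatten only at the end, mapping each flat sample back to the PyTree.
    return jax.vmap(unravel_fn)(samples_flat)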

Define `chol_params` as a flattened Cholesky factor PyTree that consists of the diagonal elements followed by the off-diagonal elements in row-major order, for n = dim * (dim + 1) / 2 elements.

The diagonal (first dim elements) is passed through a softplus function to ensure positivity, which is crucial for maintaining a valid covariance matrix.

This parameterization allows unconstrained optimization while ensuring the resulting covariance matrix Sigma = CC^T is symmetric and positive definite.

The `chol_params` are then reshaped into a lower triangular matrix `chol_factor` using the `jnp.tril` and `jnp.diag` functions.
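
A sketch of the conversion this commit message describes (my reconstruction, not the committed code):

import jax
import jax.numpy as jnp

def _unflatten_cholesky(chol_params, dim):
    # First `dim` entries parameterize the diagonal; softplus keeps it positive.
    diag = jax.nn.softplus(chol_params[:dim])
    # Remaining entries fill the strict lower triangle in row-major order.
    chol_factor = jnp.zeros((dim, dim))
    chol_factor = chol_factor.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:])
    return chol_factor + jnp.diag(diag)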
@gil2rok (Contributor, Author) commented Aug 16, 2024

  1. Largely disregard my previous comments about the Cholesky factor in Feat: Implement Full-Rank VI #720 (comment). I would love some feedback on my current implementation, which more closely follows the original pull request Implement Fullrank vi #479.
  2. For simplicity, I changed the code to use jax.scipy.stats.multivariate_normal.logpdf even though it redundantly recomputes the Cholesky decomposition internally. Do you think I should keep it or revert to my more efficient implementation?

def fullrank_logdensity(position):
    position_flatten = jax.flatten_util.ravel_pytree(position)[0]
    # TODO: inefficient because of redundant cholesky decomposition
    return jsp.stats.multivariate_normal.logpdf(position_flatten, mu_flatten, cov)
@junpenglao (Member):

You have a good point about computational efficiency; let's keep the Cholesky version of the logpdf. Could you rewrite the API to be similar to multivariate_normal.logpdf, and use functools.partial when you call it?

@gil2rok (Contributor, Author):

Can you specify the API you imagine my implementation would have, and how functools.partial would be applied to it? Feel free to reference the code I just committed.
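
For concreteness, here is one possible reading of that suggestion (my guess at the shape of the API, not junpenglao's actual proposal): expose a logpdf with a multivariate_normal.logpdf-style signature that takes the Cholesky pieces directly, then bind the variational parameters with functools.partial.

from functools import partial

import jax.numpy as jnp
import jax.scipy as jsp

def fullrank_logpdf(position, mu, rho, L):
    # Same calling style as multivariate_normal.logpdf(x, mean, cov),
    # but parameterized by the Cholesky pieces instead of a dense covariance.
    cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(rho))
    y = jsp.linalg.solve_triangular(cholesky, position - mu, lower=True)
    log_det = 2 * jnp.sum(rho)
    const = -0.5 * mu.shape[-1] * jnp.log(2 * jnp.pi)
    return const - 0.5 * log_det - 0.5 * jnp.sum(y**2, axis=-1)

# Bind the current variational parameters once, then reuse like a plain logpdf:
# logdensity_fn = partial(fullrank_logpdf, mu=mu, rho=rho, L=L)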

@junpenglao (Member) commented:

I see, you are following the pattern in meanfield VI. I think there are a few places that need refactoring there. Let me send out a PR so you see what I mean.

@gil2rok (Contributor, Author) commented Aug 18, 2024

> I see, you are following the pattern in meanfield VI. I think there are a few places that need refactoring there. Let me send out a PR so you see what I mean.

Once you finish refactoring Meanfield VI, I'm happy to adopt the new style. Let me know when you're done!

Fix testing bug, add docstrings, and change softplus to exponential when converting `chol_params` to `chol_factor` in `_unflatten_cholesky`.

Refactor the `_unflatten_cholesky()` function to take a `dim` argument instead of inferring it (dynamically) from the `chol_params` input vector. This avoids JIT compilation issues.

Also update docstrings.
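
The JIT issue, roughly: computing `dim` from the input inside the function can trip up jax.jit, whereas passing `dim` explicitly as a static argument keeps every shape known at trace time. A sketch of what the refactor could look like (illustrative, using the exponential diagonal from the same commit, not the committed code):

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="dim")
def _unflatten_cholesky(chol_params, dim):
    # `dim` is a static Python int, so all shapes below are fixed at trace time.
    diag = jnp.exp(chol_params[:dim])
    chol_factor = jnp.zeros((dim, dim))
    chol_factor = chol_factor.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:])
    return chol_factor + jnp.diag(diag)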
Add assert statements that verify full-rank VI recovers the true,
full-rank covariance matrix.
@gil2rok (Contributor, Author) commented Sep 11, 2024

Update: I need to figure out why the full covariance matrix is not being recovered. I may need to come back to this in a few weeks because of some deadlines.

@junpenglao self-assigned this Sep 12, 2024