Skip to content

Commit

Permalink
Fix: Full-rank VI compatible with JIT compilation
Browse files Browse the repository at this point in the history
Refactor `_unflatten_cholesky()` function to take `dim` argument instead
of infering it (dynamically) from the `chol_params` input vector. This
avoids JIT compilation issues.

Also update docstrings.
  • Loading branch information
gil2rok committed Aug 19, 2024
1 parent bb08e1f commit ced702f
Showing 1 changed file with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions blackjax/vi/fullrank_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def step(
num_samples: int = 5,
stl_estimator: bool = True,
) -> tuple[FRVIState, FRVIInfo]:
"""Approximate the target density using the full-rank Gaussian approximation
"""Approximate the target density using the full-rank Gaussian approximation.
Parameters
----------
Expand Down Expand Up @@ -166,35 +166,36 @@ def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int):
return VIAlgorithm(init_fn, step_fn, sample_fn)


def _unflatten_cholesky(chol_params):
def _unflatten_cholesky(chol_params, dim):
"""Construct the Cholesky factor from a flattened vector of Cholesky parameters.
Transforms a flattened vector representation of the Cholesky factor (`chol_params`)
into its proper lower triangular matrix form (`chol_factor`). It specifically
reshapes the input vector `chol_params` into a lower triangular matrix with zeros
above the diagonal and exponentiates the diagonal elements to ensure positivity.
The Cholesky factor (L) is a lower triangular matrix with positive diagonal
elements used to parameterize the (full-rank) covariance matrix of the Gaussian
elements used to parameterize the full-rank covariance matrix of the Gaussian
approximation as Sigma = LL^T.
This parameterization allows for (1) efficient sampling and log density evaluation,
and (2) ensuring the covariance matrix is symmetric and positive definite during
(unconconstrained) optimization.
Transforms a flattened vector representation of the Cholesky factor (`chol_params`)
into its proper lower triangular matrix form (`chol_factor`). It specifically
reshapes the input vector `chol_params` into a lower triangular matrix with zeros
above the diagonal and exponentiates the diagonal elements to ensure positivity.
Parameters
----------
chol_params
Flattened Cholesky factor of the full-rank covariance matrix.
dim
Dimensionality of the Gaussian distribution.
Returns
-------
chol_factor
Cholesky factor of the full-rank covariance matrix.
"""

n = chol_params.size
dim = int(jnp.sqrt(1 + 8 * n) - 1) // 2
tril = jnp.zeros((dim, dim))
tril = tril.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:])
diag = jnp.exp(chol_params[:dim]) # TODO: replace with softplus?
Expand All @@ -221,18 +222,36 @@ def _sample(rng_key, mu, chol_params, num_samples):
Samples drawn from the full-rank Gaussian approximation.
"""

mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu)
chol_factor = _unflatten_cholesky(chol_params)
eps = jax.random.normal(rng_key, (num_samples,) + mu_flatten.shape)
dim = mu_flatten.size
chol_factor = _unflatten_cholesky(chol_params, dim)
eps = jax.random.normal(rng_key, (num_samples,) + (dim,))
flatten_sample = eps @ chol_factor.T + mu_flatten
return jax.vmap(unravel_fn)(flatten_sample)


def generate_fullrank_logdensity(mu, chol_params):
"""Generate the log-density function of a full-rank Gaussian distribution.
Parameters
----------
mu
Mean of the Gaussian distribution.
chol_params
Flattened Cholesky factor of the Gaussian distribution.
Returns
-------
A function that computes the log-density of the full-rank Gaussian distribution.
"""

mu_flatten, _ = jax.flatten_util.ravel_pytree(mu)
chol_factor = _unflatten_cholesky(chol_params)
dim = mu_flatten.size
chol_factor = _unflatten_cholesky(chol_params, dim)
log_det = 2 * jnp.sum(jnp.log(jnp.diag(chol_factor)))
const = -0.5 * mu_flatten.size * jnp.log(2 * jnp.pi)
const = -0.5 * dim * jnp.log(2 * jnp.pi)

def fullrank_logdensity(position):
position_flatten, _ = jax.flatten_util.ravel_pytree(position)
Expand Down

0 comments on commit ced702f

Please sign in to comment.