Skip to content

Commit

Permalink
Enh: compute normal log density with cholesky factor
Browse files Browse the repository at this point in the history
  • Loading branch information
gil2rok committed Aug 16, 2024
1 parent b13eb12 commit 4ea435c
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions blackjax/vi/fullrank_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def init(
"""Initialize the full-rank VI state."""
mu = jax.tree.map(jnp.zeros_like, position)
dim = jax.flatten_util.ravel_pytree(mu)[0].shape[0]
chol_params = jax.flatten_util.ravel_pytree(jnp.tril(jnp.eye(dim)))[0]
chol_params, _ = jax.flatten_util.ravel_pytree(jnp.tril(jnp.eye(dim)))
opt_state = optimizer.init((mu, chol_params))
return FRVIState(mu, chol_params, opt_state)

Expand Down Expand Up @@ -170,19 +170,22 @@ def _unflatten_cholesky(chol_params):
def _sample(rng_key, mu, chol_params, num_samples):
mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu)
chol_factor = _unflatten_cholesky(chol_params)
eps = jax.random.normal(rng_key, (num_samples, mu_flatten.shape[0]))
eps = jax.random.normal(rng_key, (num_samples, mu_flatten.size))
flatten_sample = eps @ chol_factor.T + mu_flatten
return jax.vmap(unravel_fn)(flatten_sample)


def generate_fullrank_logdensity(mu, chol_params):
mu_flatten, _ = jax.flatten_util.ravel_pytree(mu)
chol_factor = _unflatten_cholesky(chol_params)
cov = chol_factor @ chol_factor.T
log_det = 2 * jnp.sum(jnp.log(jnp.diag(chol_factor)))
const = -0.5 * mu_flatten.size * jnp.log(2 * jnp.pi)

def fullrank_logdensity(position):
position_flatten = jax.flatten_util.ravel_pytree(position)[0]
# TODO: inefficient because of redundant cholesky decomposition
return jsp.stats.multivariate_normal.logpdf(position_flatten, mu_flatten, cov)
position_flatten, _ = jax.flatten_util.ravel_pytree(position)
centered_position = position_flatten - mu_flatten
y = jsp.linalg.solve_triangular(chol_factor, centered_position, lower=True)
mahalanobis_dist = jnp.sum(y ** 2)
return const - 0.5 * (log_det + mahalanobis_dist)

return fullrank_logdensity

0 comments on commit 4ea435c

Please sign in to comment.