diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index 4ae17fcf6..5e4e23990 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -85,7 +85,7 @@ def step( num_samples: int = 5, stl_estimator: bool = True, ) -> tuple[FRVIState, FRVIInfo]: - """Approximate the target density using the full-rank Gaussian approximation + """Approximate the target density using the full-rank Gaussian approximation. Parameters ---------- @@ -166,35 +166,36 @@ def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int): return VIAlgorithm(init_fn, step_fn, sample_fn) -def _unflatten_cholesky(chol_params): +def _unflatten_cholesky(chol_params, dim): """Construct the Cholesky factor from a flattened vector of Cholesky parameters. + Transforms a flattened vector representation of the Cholesky factor (`chol_params`) + into its proper lower triangular matrix form (`chol_factor`). It specifically + reshapes the input vector `chol_params` into a lower triangular matrix with zeros + above the diagonal and exponentiates the diagonal elements to ensure positivity. + The Cholesky factor (L) is a lower triangular matrix with positive diagonal - elements used to parameterize the (full-rank) covariance matrix of the Gaussian + elements used to parameterize the full-rank covariance matrix of the Gaussian approximation as Sigma = LL^T. This parameterization allows for (1) efficient sampling and log density evaluation, and (2) ensuring the covariance matrix is symmetric and positive definite during (unconstrained) optimization. - Transforms a flattened vector representation of the Cholesky factor (`chol_params`) - into its proper lower triangular matrix form (`chol_factor`). It specifically - reshapes the input vector `chol_params` into a lower triangular matrix with zeros - above the diagonal and exponentiates the diagonal elements to ensure positivity. - Parameters ---------- chol_params Flattened Cholesky factor of the full-rank covariance matrix. 
+ dim + Dimensionality of the Gaussian distribution. Returns ------- chol_factor Cholesky factor of the full-rank covariance matrix. + """ - n = chol_params.size - dim = int(jnp.sqrt(1 + 8 * n) - 1) // 2 tril = jnp.zeros((dim, dim)) tril = tril.at[jnp.tril_indices(dim, k=-1)].set(chol_params[dim:]) diag = jnp.exp(chol_params[:dim]) # TODO: replace with softplus? @@ -221,18 +222,36 @@ def _sample(rng_key, mu, chol_params, num_samples): Samples drawn from the full-rank Gaussian approximation. """ + mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu) - chol_factor = _unflatten_cholesky(chol_params) - eps = jax.random.normal(rng_key, (num_samples,) + mu_flatten.shape) + dim = mu_flatten.size + chol_factor = _unflatten_cholesky(chol_params, dim) + eps = jax.random.normal(rng_key, (num_samples,) + (dim,)) flatten_sample = eps @ chol_factor.T + mu_flatten return jax.vmap(unravel_fn)(flatten_sample) def generate_fullrank_logdensity(mu, chol_params): + """Generate the log-density function of a full-rank Gaussian distribution. + + Parameters + ---------- + mu + Mean of the Gaussian distribution. + chol_params + Flattened Cholesky factor of the Gaussian distribution. + + Returns + ------- + A function that computes the log-density of the full-rank Gaussian distribution. + + """ + mu_flatten, _ = jax.flatten_util.ravel_pytree(mu) - chol_factor = _unflatten_cholesky(chol_params) + dim = mu_flatten.size + chol_factor = _unflatten_cholesky(chol_params, dim) log_det = 2 * jnp.sum(jnp.log(jnp.diag(chol_factor))) - const = -0.5 * mu_flatten.size * jnp.log(2 * jnp.pi) + const = -0.5 * dim * jnp.log(2 * jnp.pi) def fullrank_logdensity(position): position_flatten, _ = jax.flatten_util.ravel_pytree(position)