diff --git a/blackjax/vi/fullrank_vi.py b/blackjax/vi/fullrank_vi.py index 5e4e23990..321c1fa5c 100644 --- a/blackjax/vi/fullrank_vi.py +++ b/blackjax/vi/fullrank_vi.py @@ -222,7 +222,7 @@ def _sample(rng_key, mu, chol_params, num_samples): Samples drawn from the full-rank Gaussian approximation. """ - + mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu) dim = mu_flatten.size chol_factor = _unflatten_cholesky(chol_params, dim)