-
Notifications
You must be signed in to change notification settings - Fork 106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat: Implement Full-Rank VI #720
Draft
gil2rok
wants to merge
12
commits into
blackjax-devs:main
Choose a base branch
from
gil2rok:fullrank_vi
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 1 commit
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
cbaff41
fullrank vi first draft
gil2rok f80b593
Fix: cholesky factor as flattened PyTree
gil2rok c2b38eb
Doc: clarify chol_params order
gil2rok b13eb12
Doc: formatting
gil2rok 4ea435c
Enh: compute normal log density with cholesky factor
gil2rok 5119889
Doc: formatting
gil2rok 6b9c002
Doc: Clarify Cholesky unflattening
gil2rok 4379a6d
Fix: Non-jitted full-rank VI works
gil2rok bb08e1f
Doc: formatting
gil2rok ced702f
Fix: Full-rank VI compatible with JIT compilation
gil2rok 26da046
Doc: formatting
gil2rok 4b4534f
Tests: Check full-rank covariance matrix
gil2rok File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from . import meanfield_vi, pathfinder, schrodinger_follmer, svgd | ||
from . import fullrank_vi, meanfield_vi, pathfinder, schrodinger_follmer, svgd | ||
|
||
__all__ = ["pathfinder", "meanfield_vi", "svgd", "schrodinger_follmer"] | ||
__all__ = ["fullrank_vi", "meanfield_vi", "pathfinder", "svgd", "schrodinger_follmer"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
# Copyright 2020- The Blackjax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Callable, NamedTuple | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import jax.scipy as jsp | ||
from optax import GradientTransformation, OptState | ||
|
||
from blackjax.base import VIAlgorithm | ||
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey | ||
|
||
__all__ = [ | ||
"FRVIState", | ||
"FRVIInfo", | ||
"sample", | ||
"generate_fullrank_logdensity", | ||
"step", | ||
"as_top_level_api", | ||
] | ||
|
||
|
||
class FRVIState(NamedTuple): | ||
mu: ArrayTree | ||
rho: ArrayTree | ||
L: ArrayTree | ||
opt_state: OptState | ||
|
||
|
||
class FRVIInfo(NamedTuple): | ||
elbo: float | ||
|
||
|
||
def init( | ||
position: ArrayLikeTree, | ||
optimizer: GradientTransformation, | ||
*optimizer_args, | ||
**optimizer_kwargs, | ||
) -> FRVIState: | ||
"""Initialize the full-rank VI state.""" | ||
mu = jax.tree.map(jnp.zeros_like, position) | ||
rho = jax.tree.map(jnp.zeros_like, position) | ||
L = jax.tree.map(lambda x: jnp.zeros((*x.shape, x.shape)), position) | ||
opt_state = optimizer.init((mu, rho, L)) | ||
return FRVIState(mu, rho, L, opt_state) | ||
|
||
|
||
def step( | ||
rng_key: PRNGKey, | ||
state: FRVIState, | ||
logdensity_fn: Callable, | ||
optimizer: GradientTransformation, | ||
num_samples: int = 5, | ||
stl_estimator: bool = True, | ||
) -> tuple[FRVIState, FRVIInfo]: | ||
"""Approximate the target density using the full-rank approximation. | ||
|
||
Parameters | ||
---------- | ||
rng_key | ||
Key for JAX's pseudo-random number generator. | ||
init_state | ||
Initial state of the full-rank approximation. | ||
logdensity_fn | ||
Function that represents the target log-density to approximate. | ||
optimizer | ||
Optax `GradientTransformation` to be used for optimization. | ||
num_samples | ||
The number of samples that are taken from the approximation | ||
at each step to compute the Kullback-Leibler divergence between | ||
the approximation and the target log-density. | ||
stl_estimator | ||
Whether to use stick-the-landing (STL) gradient estimator :cite:p:`roeder2017sticking` for gradient estimation. | ||
The STL estimator has lower gradient variance by removing the score function term | ||
from the gradient. It is suggested by :cite:p:`agrawal2020advances` to always keep it in order for better results. | ||
|
||
""" | ||
|
||
parameters = (state.mu, state.rho, state.L) | ||
|
||
def kl_divergence_fn(parameters): | ||
mu, rho, L = parameters | ||
z = _sample(rng_key, mu, rho, L, num_samples) | ||
if stl_estimator: | ||
parameters = jax.tree_map(jax.lax.stop_gradient, (mu, rho, L)) | ||
logq = jax.vmap(generate_fullrank_logdensity(mu, rho, L))(z) | ||
logp = jax.vmap(logdensity_fn)(z) | ||
return (logq - logp).mean() | ||
|
||
elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters) | ||
updates, new_opt_state = optimizer.update(elbo_grad, state.opt_state, parameters) | ||
new_parameters = jax.tree.map(lambda p, u: p + u, parameters, updates) | ||
|
||
new_mu, new_rho, new_L = new_parameters | ||
return FRVIState(new_mu, new_rho, new_L, new_opt_state), FRVIInfo(elbo) | ||
|
||
|
||
def sample(rng_key: PRNGKey, state: FRVIState, num_samples: int = 1): | ||
"""Sample from the full-rank approximation.""" | ||
return _sample(rng_key, state.mu, state.rho, state.L, num_samples) | ||
|
||
|
||
def as_top_level_api( | ||
logdensity_fn: Callable, | ||
optimizer: GradientTransformation, | ||
num_samples: int = 100, | ||
): | ||
"""High-level implementation of Full-Rank Variational Inference. | ||
|
||
Parameters | ||
---------- | ||
logdensity_fn | ||
A function that represents the log-density function associated with | ||
the distribution we want to sample from. | ||
optimizer | ||
Optax optimizer to use to optimize the ELBO. | ||
num_samples | ||
Number of samples to take at each step to optimize the ELBO. | ||
|
||
Returns | ||
------- | ||
A ``VIAlgorithm``. | ||
|
||
""" | ||
|
||
def init_fn(position: ArrayLikeTree): | ||
return init(position, optimizer) | ||
|
||
def step_fn(rng_key: PRNGKey, state: FRVIState) -> tuple[FRVIState, FRVIInfo]: | ||
return step(rng_key, state, logdensity_fn, optimizer, num_samples) | ||
|
||
def sample_fn(rng_key: PRNGKey, state: FRVIState, num_samples: int): | ||
return sample(rng_key, state, num_samples) | ||
|
||
return VIAlgorithm(init_fn, step_fn, sample_fn) | ||
|
||
|
||
def _sample(rng_key, mu, rho, L, num_samples): | ||
cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(L)) | ||
eps = jax.random.normal(rng_key, (num_samples,) + mu.shape) | ||
return mu + eps @ cholesky.T | ||
|
||
|
||
def generate_fullrank_logdensity(mu, rho, L): | ||
cholesky = jnp.tril(L, k=-1) + jnp.diag(jnp.exp(L)) | ||
log_det = 2 * jnp.sum(rho) | ||
const = -0.5 * mu.shape[-1] * jnp.log(2 * jnp.pi) | ||
|
||
def fullrank_logdensity(position): | ||
y = jsp.linalg.solve_triangular(cholesky, position - mu, lower=True) | ||
mahalanobis_dist = jnp.sum(y ** 2, axis=-1) | ||
return const - 0.5 * log_det - 0.5 * mahalanobis_dist | ||
|
||
return fullrank_logdensity | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import chex | ||
import jax | ||
import jax.numpy as jnp | ||
import jax.scipy.stats as stats | ||
import optax | ||
from absl.testing import absltest | ||
|
||
import blackjax | ||
|
||
|
||
class FullRankVITest(chex.TestCase): | ||
def setUp(self): | ||
super().setUp() | ||
self.key = jax.random.PRNGKey(42) | ||
|
||
@chex.variants(with_jit=True, without_jit=True) | ||
def test_recover_posterior(self): | ||
ground_truth = [ | ||
# loc, scale | ||
(2, 4), | ||
(3, 5), | ||
] | ||
|
||
def logdensity_fn(x): | ||
logpdf = stats.norm.logpdf(x["x_1"], *ground_truth[0]) + stats.norm.logpdf( | ||
x["x_2"], *ground_truth[1] | ||
) | ||
return jnp.sum(logpdf) | ||
|
||
initial_position = {"x_1": 0.0, "x_2": 0.0} | ||
|
||
num_steps = 50_000 | ||
num_samples = 500 | ||
|
||
optimizer = optax.sgd(1e-2) | ||
frvi = blackjax.fullrank_vi(logdensity_fn, optimizer, num_samples) | ||
state = frvi.init(initial_position) | ||
|
||
rng_key = self.key | ||
for i in range(num_steps): | ||
subkey = jax.random.split(rng_key, i) | ||
state, _ = self.variant(frvi.step)(subkey, state) | ||
|
||
loc_1, loc_2 = state.mu["x_1"], state.mu["x_2"] | ||
self.assertAlmostEqual(loc_1, ground_truth[0][0], delta=0.01) | ||
self.assertAlmostEqual(loc_2, ground_truth[1][0], delta=0.01) | ||
|
||
if __name__ == "__main__": | ||
absltest.main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not use
multivariate_normal.logpdf
from JAX?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wanted to avoid computing the inverse and log determinant of the covariance matrix$\Sigma$ by using the Cholesky factor when computing the logpdf.
Does$\Sigma = C C^T$ , and then pass it into JAX's multivariate normal which separates it back into the Cholesky factor $C$ .
jax.random.multivariate_normal.logpdf
take Cholesky factors as input? I want to avoid needing to compute the covariance matrixFrom https://jax.readthedocs.io/en/latest/_autosummary/jax.random.multivariate_normal.html it appears the multivariate normal log density only accepts the covariance as a dense matrix!
See jax-ml/jax#11386. Thoughts on tradeoff btwn readability (with JAX's multivariate normal) and speed (custom implementation)?