From 029b981e4af2f0aefe3a17c031ff9c238d527777 Mon Sep 17 00:00:00 2001
From: Junpeng Lao
Date: Tue, 12 Dec 2023 06:20:33 +0100
Subject: [PATCH] Refactor mgrad_gaussian (#628)

* Refactor mgrad_gaussian

* fix formatting

* Add a svd_from_cov helper function
---
 blackjax/mcmc/marginal_latent_gaussian.py | 155 +++++++++++++++++-----
 tests/mcmc/test_latent_gaussian.py        |  17 ++-
 tests/mcmc/test_sampling.py               |  10 +-
 3 files changed, 135 insertions(+), 47 deletions(-)

diff --git a/blackjax/mcmc/marginal_latent_gaussian.py b/blackjax/mcmc/marginal_latent_gaussian.py
index 0a3d9ad0a..8c910769b 100644
--- a/blackjax/mcmc/marginal_latent_gaussian.py
+++ b/blackjax/mcmc/marginal_latent_gaussian.py
@@ -22,7 +22,7 @@
 from blackjax.mcmc.proposal import static_binomial_sampling
 from blackjax.types import Array, PRNGKey
 
-__all__ = ["MarginalState", "MarginalInfo", "init_and_kernel", "mgrad_gaussian"]
+__all__ = ["MarginalState", "MarginalInfo", "init", "build_kernel", "mgrad_gaussian"]
 
 
 # [TODO](https://github.com/blackjax-devs/blackjax/issues/237)
@@ -50,6 +50,40 @@ class MarginalState(NamedTuple):
     U_grad_x: Array
 
 
+class CovarianceSVD(NamedTuple):
+    """Singular Value Decomposition of the covariance matrix.
+
+    U
+        Unitary array of the covariance matrix.
+    Gamma
+        Singular values of the covariance matrix.
+    U_t
+        Transpose of the unitary array of the covariance matrix.
+
+    """
+
+    U: Array
+    Gamma: Array
+    U_t: Array
+
+
+def svd_from_covariance(covariance: Array) -> CovarianceSVD:
+    """Compute the singular value decomposition of the covariance matrix.
+
+    Parameters
+    ----------
+    covariance
+        The covariance matrix.
+
+    Returns
+    -------
+    A ``CovarianceSVD`` object.
+
+    """
+    U, Gamma, U_t = jnp.linalg.svd(covariance, hermitian=True)
+    return CovarianceSVD(U, Gamma, U_t)
+
+
 class MarginalInfo(NamedTuple):
     """Additional information on the RMH chain.
 
@@ -72,28 +106,66 @@ class MarginalInfo(NamedTuple):
     proposal: MarginalState
 
 
-def init_and_kernel(logdensity_fn, covariance, mean=None):
-    """Build the marginal version of the auxiliary gradient-based sampler
+def generate_mean_shifted_logprob(logdensity_fn, mean, covariance):
+    """Generate a log-density function whose gradient is shifted by a constant vector.
+
+    Parameters
+    ----------
+    logdensity_fn
+        The original log-density function.
+    mean
+        The mean of the prior Gaussian density.
+    covariance
+        The covariance of the prior Gaussian density.
+
+    Returns
+    -------
+    The mean-shifted log-density function.
+
+    """
+    shift = linalg.solve(covariance, mean, assume_a="pos")
+
+    def shifted_logdensity_fn(x):
+        return logdensity_fn(x) + jnp.dot(x, shift)
+
+    return shifted_logdensity_fn
+
+
+def init(position, logdensity_fn, U_t):
+    """Initialize the marginal version of the auxiliary gradient-based sampler.
+
+    Parameters
+    ----------
+    position
+        The initial position of the chain.
+    logdensity_fn
+        The logarithm of the likelihood function for the latent Gaussian model.
+    U_t
+        The transpose of the unitary array of the covariance matrix.
+    """
+    logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
+    return MarginalState(
+        position, logdensity, logdensity_grad, U_t @ position, U_t @ logdensity_grad
+    )
+
+
+def build_kernel(cov_svd: CovarianceSVD):
+    """Build the marginal version of the auxiliary gradient-based sampler.
+
+    Parameters
+    ----------
+    cov_svd
+        The singular value decomposition of the covariance matrix.
 
     Returns
     -------
     A kernel that takes a rng_key and a Pytree that contains the current state
     of the chain and that returns a new state of the chain along with
     information about the transition.
-    An init function.
-
     """
-    U, Gamma, U_t = jnp.linalg.svd(covariance, hermitian=True)
+    U, Gamma, U_t = cov_svd
 
-    if mean is not None:
-        shift = linalg.solve(covariance, mean, assume_a="pos")
-        val_and_grad = jax.value_and_grad(
-            lambda x: logdensity_fn(x) + jnp.dot(x, shift)
-        )
-    else:
-        val_and_grad = jax.value_and_grad(logdensity_fn)
-
-    def step(key: PRNGKey, state: MarginalState, delta):
+    def kernel(key: PRNGKey, state: MarginalState, logdensity_fn, delta):
         y_key, u_key = jax.random.split(key, 2)
 
         position, logdensity, logdensity_grad, U_x, U_grad_x = state
@@ -111,7 +183,7 @@ def step(key: PRNGKey, state: MarginalState, delta):
         y = U @ temp
 
         # Bookkeeping
-        log_p_y, grad_y = val_and_grad(y)
+        log_p_y, grad_y = jax.value_and_grad(logdensity_fn)(y)
 
         U_y = U_t @ y
         U_grad_y = U_t @ grad_y
@@ -131,39 +203,34 @@ def step(key: PRNGKey, state: MarginalState, delta):
         info = MarginalInfo(p_accept, do_accept, proposed_state)
         return accepted_state, info
 
-    def init(position):
-        logdensity, logdensity_grad = val_and_grad(position)
-        return MarginalState(
-            position, logdensity, logdensity_grad, U_t @ position, U_t @ logdensity_grad
-        )
-
-    return init, step
+    return kernel
 
 
 class mgrad_gaussian:
     """Implements the marginal sampler for latent Gaussian model of :cite:p:`titsias2018auxiliary`.
 
     It uses a first order approximation to the log_likelihood of a model with Gaussian prior.
-    Interestingly, the only parameter that needs calibrating is the "step size" delta, which can be done very efficiently.
+    Interestingly, the only parameter that needs calibrating is the "step size" delta,
+    which can be done very efficiently.
     Calibrating it to have an acceptance rate of roughly 50% is a good starting point.
 
     Examples
     --------
-    A new marginal latent Gaussian MCMC kernel for a model q(x) ∝ exp(f(x)) N(x; m, C) can be initialized and
-    used for a given "step size" delta with the following code:
+    A new marginal latent Gaussian MCMC kernel for a model q(x) ∝ exp(f(x)) N(x; m, C)
+    can be initialized and used for a given "step size" delta with the following code:
 
     .. code::
 
-        mgrad_gaussian = blackjax.mgrad_gaussian(f, C, use_inverse=False, mean=m)
+        mgrad_gaussian = blackjax.mgrad_gaussian(f, C, mean=m, step_size=delta)
         state = mgrad_gaussian.init(zeros)  # Starting at the mean of the prior
-        new_state, info = mgrad_gaussian.step(rng_key, state, delta)
+        new_state, info = mgrad_gaussian.step(rng_key, state)
 
     We can JIT-compile the step function for better performance
 
     .. code::
 
         step = jax.jit(mgrad_gaussian.step)
-        new_state, info = step(rng_key, state, delta)
+        new_state, info = step(rng_key, state)
 
     Parameters
     ----------
@@ -180,22 +247,40 @@ class mgrad_gaussian:
     """
 
+    init = staticmethod(init)
+    build_kernel = staticmethod(build_kernel)
+
     def __new__(  # type: ignore[misc]
         cls,
         logdensity_fn: Callable,
-        covariance: Array,
+        covariance: Optional[Array] = None,
         mean: Optional[Array] = None,
+        cov_svd: Optional[CovarianceSVD] = None,
+        step_size: float = 1.0,
     ) -> SamplingAlgorithm:
-        init, kernel = init_and_kernel(logdensity_fn, covariance, mean)
+        if cov_svd is None:
+            if covariance is None:
+                raise ValueError("Either covariance or cov_svd must be provided.")
+            cov_svd = svd_from_covariance(covariance)
+
+        U, Gamma, U_t = cov_svd
+
+        if mean is not None:
+            logdensity_fn = generate_mean_shifted_logprob(
+                logdensity_fn, mean, covariance
+            )
+
+        kernel = cls.build_kernel(cov_svd)
 
         def init_fn(position: Array):
-            return init(position)
+            return init(position, logdensity_fn, U_t)
 
-        def step_fn(rng_key: PRNGKey, state, delta: float):
+        def step_fn(rng_key: PRNGKey, state):
             return kernel(
                 rng_key,
                 state,
-                delta,
+                logdensity_fn,
+                step_size,
             )
 
-        return SamplingAlgorithm(init_fn, step_fn)  # type: ignore[arg-type]
+        return SamplingAlgorithm(init_fn, step_fn)
diff --git a/tests/mcmc/test_latent_gaussian.py b/tests/mcmc/test_latent_gaussian.py
index 9f46c9d63..0c7d6bbc1 100644
--- a/tests/mcmc/test_latent_gaussian.py
+++ b/tests/mcmc/test_latent_gaussian.py
@@ -6,7 +6,12 @@
 import numpy as np
 from absl.testing import absltest, parameterized
 
-from blackjax.mcmc.marginal_latent_gaussian import init_and_kernel
+from blackjax.mcmc.marginal_latent_gaussian import (
+    build_kernel,
+    generate_mean_shifted_logprob,
+    init,
+    svd_from_covariance,
+)
 
 
 class GaussianTest(chex.TestCase):
@@ -26,14 +31,16 @@ def test_gaussian(self, seed, mean):
         obs = jax.random.normal(key4, (D,))
 
         log_pdf = lambda x: stats.multivariate_normal.logpdf(x, obs, R)
+        if prior_mean is not None:
+            log_pdf = generate_mean_shifted_logprob(log_pdf, prior_mean, C)
 
         DELTA = 50.0
-
-        init, step = init_and_kernel(log_pdf, C, mean=prior_mean)
-        step = jax.jit(step)
+        cov_svd = svd_from_covariance(C)
+        _step = build_kernel(cov_svd)
+        step = jax.jit(lambda key, state, delta: _step(key, state, log_pdf, delta))
 
         init_x = np.zeros((D,))
-        init_state = init(init_x)
+        init_state = init(init_x, log_pdf, cov_svd.U_t)
 
         keys = jax.random.split(key5, n_samples)
 
diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index 879eba550..b87c14d7c 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -513,13 +513,9 @@ def test_latent_gaussian(self):
         from blackjax import mgrad_gaussian
 
         inference_algorithm = mgrad_gaussian(
-            lambda x: -0.5 * jnp.sum((x - 1.0) ** 2), self.C
-        )
-        inference_algorithm = inference_algorithm._replace(
-            step=functools.partial(
-                inference_algorithm.step,
-                delta=self.delta,
-            )
+            lambda x: -0.5 * jnp.sum((x - 1.0) ** 2),
+            covariance=self.C,
+            step_size=self.delta,
         )
 
         initial_state = inference_algorithm.init(jnp.zeros((1,)))
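
For reference, a minimal end-to-end sketch of the refactored API (not part of the diff). The model (``f``, ``C``, ``m``), the ``step_size`` value, and the loop length are illustrative placeholders; only names introduced or kept by this patch (``blackjax.mgrad_gaussian``, ``svd_from_covariance``, the ``init``/``step`` pair of the returned ``SamplingAlgorithm``) are used.

.. code::

    import jax
    import jax.numpy as jnp

    import blackjax
    from blackjax.mcmc.marginal_latent_gaussian import svd_from_covariance

    # Toy latent Gaussian model q(x) ∝ exp(f(x)) N(x; m, C); all values illustrative.
    D = 10
    m = jnp.ones(D)
    C = jnp.eye(D)
    f = lambda x: -0.5 * jnp.sum((x - 1.0) ** 2)

    # The SVD of C can be computed once and reused across samplers.
    cov_svd = svd_from_covariance(C)
    algo = blackjax.mgrad_gaussian(
        f, covariance=C, mean=m, cov_svd=cov_svd, step_size=0.5
    )

    rng_key = jax.random.PRNGKey(0)
    state = algo.init(m)  # start at the prior mean
    step = jax.jit(algo.step)  # delta is now fixed at construction time as step_size

    for _ in range(1_000):
        rng_key, subkey = jax.random.split(rng_key)
        state, info = step(subkey, state)

Passing ``cov_svd`` alongside ``covariance`` is redundant here; it is shown only because ``covariance`` is still required whenever ``mean`` is given, since the mean shift in ``generate_mean_shifted_logprob`` solves against the covariance.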