diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 71719201a..96df14920 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -18,6 +18,7 @@ from .mcmc.nuts import nuts from .mcmc.periodic_orbital import orbital_hmc from .mcmc.random_walk import additive_step_random_walk, irmh, rmh +from .mcmc.rmhmc import rmhmc from .optimizers import dual_averaging, lbfgs from .sgmcmc.csgld import csgld from .sgmcmc.sghmc import sghmc @@ -37,6 +38,7 @@ "lbfgs", "hmc", # mcmc "dynamic_hmc", + "rmhmc", "mala", "mgrad_gaussian", "nuts", diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index f27b199c6..6e207741d 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -9,6 +9,7 @@ nuts, periodic_orbital, random_walk, + rmhmc, ) __all__ = [ @@ -16,6 +17,7 @@ "elliptical_slice", "ghmc", "hmc", + "rmhmc", "mala", "nuts", "periodic_orbital", diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index 90bdbc60c..b48834e5f 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -22,7 +22,7 @@ from blackjax.base import SamplingAlgorithm from blackjax.mcmc.proposal import safe_energy_diff, static_binomial_sampling from blackjax.mcmc.trajectory import hmc_energy -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey __all__ = [ "HMCState", @@ -101,7 +101,8 @@ def build_kernel( integrator The symplectic integrator to use to integrate the Hamiltonian dynamics. divergence_threshold - Value of the difference in energy above which we consider that the transition is divergent. + Value of the difference in energy above which we consider that the transition is + divergent. Returns ------- @@ -116,18 +117,16 @@ def kernel( state: HMCState, logdensity_fn: Callable, step_size: float, - inverse_mass_matrix: Array, + inverse_mass_matrix: metrics.MetricTypes, num_integration_steps: int, ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the HMC kernel.""" - momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean( - inverse_mass_matrix - ) - symplectic_integrator = integrator(logdensity_fn, kinetic_energy_fn) + metric = metrics.default_metric(inverse_mass_matrix) + symplectic_integrator = integrator(logdensity_fn, metric.kinetic_energy) proposal_generator = hmc_proposal( symplectic_integrator, - kinetic_energy_fn, + metric.kinetic_energy, step_size, num_integration_steps, divergence_threshold, @@ -136,7 +135,7 @@ def kernel( key_momentum, key_integrator = jax.random.split(rng_key, 2) position, logdensity, logdensity_grad = state - momentum = momentum_generator(key_momentum, position) + momentum = metric.sample_momentum(key_momentum, position) integrator_state = integrators.IntegratorState( position, momentum, logdensity, logdensity_grad @@ -154,10 +153,10 @@ def kernel( class hmc: """Implements the (basic) user interface for the HMC kernel. - The general hmc kernel builder (:meth:`blackjax.mcmc.hmc.build_kernel`, alias `blackjax.hmc.build_kernel`) can be - cumbersome to manipulate. Since most users only need to specify the kernel - parameters at initialization time, we provide a helper function that - specializes the general kernel. + The general hmc kernel builder (:meth:`blackjax.mcmc.hmc.build_kernel`, alias + `blackjax.hmc.build_kernel`) can be cumbersome to manipulate. Since most users only + need to specify the kernel parameters at initialization time, we provide a helper + function that specializes the general kernel. 
We also add the general kernel and state generator as an attribute to this class so users only need to pass `blackjax.hmc` to SMC, adaptation, etc. algorithms. @@ -169,7 +168,9 @@ class hmc: .. code:: - hmc = blackjax.hmc(logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps) + hmc = blackjax.hmc( + logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps + ) state = hmc.init(position) new_state, info = hmc.step(rng_key, state) @@ -188,7 +189,14 @@ class hmc: kernel = blackjax.hmc.build_kernel(integrators.mclachlan) state = blackjax.hmc.init(position, logdensity_fn) - state, info = kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps) + state, info = kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + num_integration_steps, + ) Parameters ---------- @@ -198,7 +206,9 @@ class hmc: The value to use for the step size in the symplectic integrator. inverse_mass_matrix The value to use for the inverse mass matrix when drawing a value for - the momentum and computing the kinetic energy. + the momentum and computing the kinetic energy. This argument will be + passed to the ``metrics.default_metric`` function so it supports the + full interface presented there. num_integration_steps The number of steps we take with the symplectic integrator at each sample step before returning a sample. @@ -207,7 +217,8 @@ class hmc: which we say that the transition is divergent. The default value is commonly found in other libraries, and yet is arbitrary. integrator - (algorithm parameter) The symplectic integrator to use to integrate the trajectory.\ + (algorithm parameter) The symplectic integrator to use to integrate the + trajectory. Returns ------- @@ -221,7 +232,7 @@ def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, step_size: float, - inverse_mass_matrix: Array, + inverse_mass_matrix: metrics.MetricTypes, num_integration_steps: int, *, divergence_threshold: int = 1000, @@ -248,7 +259,7 @@ def step_fn(rng_key: PRNGKey, state): def hmc_proposal( integrator: Callable, - kinetic_energy: Callable, + kinetic_energy: metrics.KineticEnergy, step_size: Union[float, ArrayLikeTree], num_integration_steps: int = 1, divergence_threshold: float = 1000, diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index be96fa4b1..489efb011 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -12,19 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
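For readers of the patch, a minimal usage sketch of the ``blackjax.hmc`` interface whose ``inverse_mass_matrix`` argument is generalized above; it assumes this patch is applied, and the standard-normal log-density, step size, and trajectory length are made-up illustrations. An array still selects the static Euclidean metric, while a ``Metric`` or a mass-matrix callable is resolved by ``metrics.default_metric`` (see the metrics changes below).

import jax
import jax.numpy as jnp

import blackjax

# Toy target: an isotropic standard normal (illustrative only).
def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)

hmc = blackjax.hmc(
    logdensity_fn,
    step_size=0.5,                    # illustrative value
    inverse_mass_matrix=jnp.ones(2),  # 1-D array -> static (Euclidean) metric
    num_integration_steps=10,
)
state = hmc.init(jnp.zeros(2))
new_state, info = hmc.step(jax.random.PRNGKey(0), state)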
"""Symplectic, time-reversible, integrators for Hamiltonian trajectories.""" -from typing import Callable, NamedTuple +from typing import Any, Callable, NamedTuple, Tuple import jax import jax.numpy as jnp from jax.flatten_util import ravel_pytree -from blackjax.mcmc.metrics import EuclideanKineticEnergy +from blackjax.mcmc.metrics import KineticEnergy from blackjax.types import ArrayTree __all__ = [ "mclachlan", "velocity_verlet", "yoshida", + "implicit_midpoint", "noneuclidean_leapfrog", "noneuclidean_mclachlan", "noneuclidean_yoshida", @@ -170,7 +171,7 @@ def update( return update -def euclidean_momentum_update_fn(kinetic_energy_fn: EuclideanKineticEnergy): +def euclidean_momentum_update_fn(kinetic_energy_fn: KineticEnergy): kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn) def update( @@ -216,7 +217,7 @@ def generate_euclidean_integrator(cofficients): """ def euclidean_integrator( - logdensity_fn: Callable, kinetic_energy_fn: EuclideanKineticEnergy + logdensity_fn: Callable, kinetic_energy_fn: KineticEnergy ) -> Integrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn) @@ -366,3 +367,115 @@ def noneuclidean_integrator( noneuclidean_leapfrog = generate_noneuclidean_integrator(velocity_verlet_cofficients) noneuclidean_yoshida = generate_noneuclidean_integrator(yoshida_cofficients) noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients) + +FixedPointSolver = Callable[ + [Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], ArrayTree], + Tuple[ArrayTree, ArrayTree, Any], +] + + +class FixedPointIterationInfo(NamedTuple): + success: bool + norm: float + iters: int + + +def solve_fixed_point_iteration( + func: Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], + x0: ArrayTree, + *, + convergence_tol: float = 1e-6, + divergence_tol: float = 1e10, + max_iters: int = 100, + norm_fn: Callable[[ArrayTree], float] = lambda x: jnp.max(jnp.abs(x)), +) -> Tuple[ArrayTree, ArrayTree, FixedPointIterationInfo]: + """Solve for x = func(x) using a fixed point iteration""" + + def compute_norm(x: ArrayTree, xp: ArrayTree) -> float: + return norm_fn(ravel_pytree(jax.tree_util.tree_map(jnp.subtract, x, xp))[0]) + + def cond_fn(args: Tuple[int, ArrayTree, ArrayTree, float]) -> bool: + n, _, _, norm = args + return ( + (n < max_iters) + & jnp.isfinite(norm) + & (norm < divergence_tol) + & (norm > convergence_tol) + ) + + def body_fn( + args: Tuple[int, ArrayTree, ArrayTree, float] + ) -> Tuple[int, ArrayTree, ArrayTree, float]: + n, x, _, _ = args + xn, aux = func(x) + norm = compute_norm(xn, x) + return n + 1, xn, aux, norm + + x, aux = func(x0) + iters, x, aux, norm = jax.lax.while_loop( + cond_fn, body_fn, (0, x, aux, compute_norm(x, x0)) + ) + success = jnp.isfinite(norm) & (norm <= convergence_tol) + return x, aux, FixedPointIterationInfo(success, norm, iters) + + +def implicit_midpoint( + logdensity_fn: Callable, + kinetic_energy_fn: KineticEnergy, + *, + solver: FixedPointSolver = solve_fixed_point_iteration, + **solver_kwargs: Any, +) -> Integrator: + """The implicit midpoint integrator with support for non-stationary kinetic energy + + This is an integrator based on :cite:t:`brofos2021evaluating`, which provides + support for kinetic energies that depend on position. This integrator requires that + the kinetic energy function takes two arguments: position and momentum. + + The ``solver`` parameter allows overloading of the fixed point solver. 
By default, a + simple fixed point iteration is used, but more advanced solvers could be implemented + in the future. + """ + logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn) + kinetic_energy_grad_fn = jax.grad( + lambda q, p: kinetic_energy_fn(p, position=q), argnums=(0, 1) + ) + + def one_step(state: IntegratorState, step_size: float) -> IntegratorState: + position, momentum, _, _ = state + + def _update( + q: ArrayTree, + p: ArrayTree, + dUdq: ArrayTree, + initial: Tuple[ArrayTree, ArrayTree] = (position, momentum), + ) -> Tuple[ArrayTree, ArrayTree]: + dTdq, dHdp = kinetic_energy_grad_fn(q, p) + dHdq = jax.tree_util.tree_map(jnp.subtract, dTdq, dUdq) + + # Take a step from the _initial coordinates_ using the gradients of the + # Hamiltonian evaluated at the current guess for the midpoint + q = jax.tree_util.tree_map( + lambda q_, d_: q_ + 0.5 * step_size * d_, initial[0], dHdp + ) + p = jax.tree_util.tree_map( + lambda p_, d_: p_ - 0.5 * step_size * d_, initial[1], dHdq + ) + return q, p + + # Solve for the midpoint numerically + def _step(args: ArrayTree) -> Tuple[ArrayTree, ArrayTree]: + q, p = args + _, dLdq = logdensity_and_grad_fn(q) + return _update(q, p, dLdq), dLdq + + (q, p), dLdq, info = solver(_step, (position, momentum), **solver_kwargs) + del info # TODO: Track the returned info + + # Take an explicit update as recommended by Brofos & Lederman + _, dLdq = logdensity_and_grad_fn(q) + q, p = _update(q, p, dLdq, initial=(q, p)) + + return IntegratorState(q, p, *logdensity_and_grad_fn(q)) + + return one_step diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index a24bc00b4..1368a8441 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -22,29 +22,82 @@ For a Newtonian hamiltonian dynamic the kinetic energy is given by: .. math:: + K(p) = \frac{1}{2} p^T M^{-1} p We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`. """ -from typing import Callable +from typing import Callable, NamedTuple, Optional, Protocol, Union import jax.numpy as jnp import jax.scipy as jscipy from jax.flatten_util import ravel_pytree +from jax.scipy import stats as sp_stats from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_gaussian_noise -__all__ = ["gaussian_euclidean"] +__all__ = ["default_metric", "gaussian_euclidean", "gaussian_riemannian"] + + +class KineticEnergy(Protocol): + def __call__( + self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None + ) -> float: + ... + + +class CheckTurning(Protocol): + def __call__( + self, + momentum_left: ArrayLikeTree, + momentum_right: ArrayLikeTree, + momentum_sum: ArrayLikeTree, + position_left: Optional[ArrayLikeTree] = None, + position_right: Optional[ArrayLikeTree] = None, + ) -> bool: + ... 
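As a sketch of how the new ``implicit_midpoint`` integrator above can be driven directly (assuming this patch): the toy non-separable Hamiltonian is the same one used in the integrator tests further down, the step size, tolerance, and initial condition are arbitrary, and extra keyword arguments are simply forwarded to ``solve_fixed_point_iteration``.

import jax
import jax.numpy as jnp

from blackjax.mcmc import integrators

# H(q, p) = 0.5 * (q**2 + 1) + 0.5 * p**2 * (1 + q**2), which is not separable.
def neg_potential(q):
    return -0.5 * (q**2 + 1)

def kinetic_energy(p, position=None):
    return 0.5 * p**2 * (1 + position**2)

one_step = integrators.implicit_midpoint(
    neg_potential, kinetic_energy, convergence_tol=1e-8  # forwarded to the solver
)

q, p = jnp.array(-1.0), jnp.array(0.0)
state = integrators.IntegratorState(
    q, p, neg_potential(q), jax.grad(neg_potential)(q)
)
state = one_step(state, step_size=1e-3)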
+ -EuclideanKineticEnergy = Callable[[ArrayLikeTree], float] +class Metric(NamedTuple): + sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree] + kinetic_energy: KineticEnergy + check_turning: CheckTurning + + +MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]] + + +def default_metric(metric: MetricTypes) -> Metric: + """Convert an input metric into a ``Metric`` object following sensible default rules + + The metric can be specified in three different ways: + + - A ``Metric`` object that implements the full interface + - An ``Array`` which is assumed to specify the inverse mass matrix of a static + metric + - A function that takes a coordinate position and returns the mass matrix at that + location + """ + if isinstance(metric, Metric): + return metric + + # If the argument is a callable, we assume that it returns the mass matrix + # at the given position and return the corresponding Riemannian metric. + if callable(metric): + return gaussian_riemannian(metric) + + # If we make it here then the argument should be an array, and we'll assume + # that it specifies a static inverse mass matrix. + return gaussian_euclidean(metric) def gaussian_euclidean( inverse_mass_matrix: Array, -) -> tuple[Callable, EuclideanKineticEnergy, Callable]: - r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum :cite:p:`betancourt2013general`. +) -> Metric: + r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum + :cite:p:`betancourt2013general`. The gaussian euclidean metric is a euclidean metric further characterized by setting the conditional probability density :math:`\pi(momentum|position)` @@ -91,22 +144,28 @@ def gaussian_euclidean( L, identity, lower=True, trans=True ) # Note that mass_matrix_sqrt is a upper triangular matrix here, with - # jscipy.linalg.inv(mass_matrix_sqrt @ mass_matrix_sqrt.T) == inverse_mass_matrix - # An alternative is to compute directly the cholesky factor of the inverse mass matrix - # mass_matrix_sqrt = jscipy.linalg.cholesky(jscipy.linalg.inv(inverse_mass_matrix), lower=True) + # jscipy.linalg.inv(mass_matrix_sqrt @ mass_matrix_sqrt.T) + # == inverse_mass_matrix + # An alternative is to compute directly the cholesky factor of the inverse mass + # matrix + # mass_matrix_sqrt = jscipy.linalg.cholesky( + # jscipy.linalg.inv(inverse_mass_matrix), lower=True) # which the result would instead be a lower triangular matrix. matmul = jnp.matmul else: raise ValueError( "The mass matrix has the wrong number of dimensions:" - f" expected 1 or 2, got {jnp.ndim(inverse_mass_matrix)}." # type: ignore[arg-type] + f" expected 1 or 2, got {ndim}." ) def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree: return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt) - def kinetic_energy(momentum: ArrayLikeTree) -> float: + def kinetic_energy( + momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None + ) -> float: + del position momentum, _ = ravel_pytree(momentum) velocity = matmul(inverse_mass_matrix, momentum) kinetic_energy_val = 0.5 * jnp.dot(velocity, momentum) @@ -116,6 +175,8 @@ def is_turning( momentum_left: ArrayLikeTree, momentum_right: ArrayLikeTree, momentum_sum: ArrayLikeTree, + position_left: Optional[ArrayLikeTree] = None, + position_right: Optional[ArrayLikeTree] = None, ) -> bool: """Generalized U-turn criterion :cite:p:`betancourt2013generalizing,nuts_uturn`. @@ -129,6 +190,8 @@ def is_turning( Sum of the momenta along the trajectory. 
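A short sketch of the ``default_metric`` dispatch described above (assuming this patch): an array is read as a static inverse mass matrix, a callable is read as a position-dependent mass-matrix function, and a ``Metric`` instance is passed through unchanged. The shapes and the quadratic mass-matrix function are arbitrary illustrations.

import jax
import jax.numpy as jnp

from blackjax.mcmc import metrics

key = jax.random.PRNGKey(0)
q = jnp.zeros(3)

# Array -> static Euclidean metric; the position argument is optional here.
euclidean = metrics.default_metric(jnp.ones(3))
p = euclidean.sample_momentum(key, q)
energy = euclidean.kinetic_energy(p)

# Callable -> Riemannian metric built from a mass-matrix function; the
# kinetic energy now has to be evaluated at a position.
riemannian = metrics.default_metric(lambda position: 1.0 + position**2)
p = riemannian.sample_momentum(key, q)
energy = riemannian.kinetic_energy(p, position=q)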
""" + del position_left, position_right + m_left, _ = ravel_pytree(momentum_left) m_right, _ = ravel_pytree(momentum_right) m_sum, _ = ravel_pytree(momentum_sum) @@ -142,4 +205,82 @@ def is_turning( turning_at_right = jnp.dot(velocity_right, rho) <= 0 return turning_at_left | turning_at_right - return momentum_generator, kinetic_energy, is_turning + return Metric(momentum_generator, kinetic_energy, is_turning) + + +def gaussian_riemannian( + mass_matrix_fn: Callable, +) -> Metric: + def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayLikeTree: + mass_matrix = mass_matrix_fn(position) + ndim = jnp.ndim(mass_matrix) + if ndim == 1: + mass_matrix_sqrt = jnp.sqrt(mass_matrix) + elif ndim == 2: + mass_matrix_sqrt = jscipy.linalg.cholesky(mass_matrix, lower=True) + else: + raise ValueError( + "The mass matrix has the wrong number of dimensions:" + f" expected 1 or 2, got {jnp.ndim(mass_matrix)}." + ) + + return generate_gaussian_noise(rng_key, position, sigma=mass_matrix_sqrt) + + def kinetic_energy( + momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None + ) -> float: + if position is None: + raise ValueError( + "A Reinmannian kinetic energy function must be called with the " + "position specified; make sure to use a Reinmannian-capable " + "integrator like `implicit_midpoint`." + ) + + momentum, _ = ravel_pytree(momentum) + mass_matrix = mass_matrix_fn(position) + ndim = jnp.ndim(mass_matrix) + if ndim == 1: + return -jnp.sum(sp_stats.norm.logpdf(momentum, 0.0, jnp.sqrt(mass_matrix))) + elif ndim == 2: + return -sp_stats.multivariate_normal.logpdf( + momentum, jnp.zeros_like(momentum), mass_matrix + ) + else: + raise ValueError( + "The mass matrix has the wrong number of dimensions:" + f" expected 1 or 2, got {jnp.ndim(mass_matrix)}." + ) + + def is_turning( + momentum_left: ArrayLikeTree, + momentum_right: ArrayLikeTree, + momentum_sum: ArrayLikeTree, + position_left: Optional[ArrayLikeTree] = None, + position_right: Optional[ArrayLikeTree] = None, + ) -> bool: + del momentum_left, momentum_right, momentum_sum, position_left, position_right + raise NotImplementedError( + "NUTS sampling is not yet implemented for Riemannian manifolds" + ) + + # Here's a possible implementation of this function, but the NUTS + # proposal will require some refactoring to work properly, since we need + # to be able to access the coordinates at the left and right endpoints + # to compute the mass matrix at those points. 
+ + # m_left, _ = ravel_pytree(momentum_left) + # m_right, _ = ravel_pytree(momentum_right) + # m_sum, _ = ravel_pytree(momentum_sum) + + # mass_matrix_left = mass_matrix_fn(position_left) + # mass_matrix_right = mass_matrix_fn(position_right) + # velocity_left = jnp.linalg.solve(mass_matrix_left, m_left) + # velocity_right = jnp.linalg.solve(mass_matrix_right, m_right) + + # # rho = m_sum + # rho = m_sum - (m_right + m_left) / 2 + # turning_at_left = jnp.dot(velocity_left, rho) <= 0 + # turning_at_right = jnp.dot(velocity_right, rho) <= 0 + # return turning_at_left | turning_at_right + + return Metric(momentum_generator, kinetic_energy, is_turning) diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index 883121514..5ffc083b1 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -25,7 +25,7 @@ import blackjax.mcmc.termination as termination import blackjax.mcmc.trajectory as trajectory from blackjax.base import SamplingAlgorithm -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey __all__ = ["NUTSInfo", "init", "build_kernel", "nuts"] @@ -115,21 +115,17 @@ def kernel( state: hmc.HMCState, logdensity_fn: Callable, step_size: float, - inverse_mass_matrix: Array, + inverse_mass_matrix: metrics.MetricTypes, max_num_doublings: int = 10, ) -> tuple[hmc.HMCState, NUTSInfo]: """Generate a new sample with the NUTS kernel.""" - ( - momentum_generator, - kinetic_energy_fn, - uturn_check_fn, - ) = metrics.gaussian_euclidean(inverse_mass_matrix) - symplectic_integrator = integrator(logdensity_fn, kinetic_energy_fn) + metric = metrics.default_metric(inverse_mass_matrix) + symplectic_integrator = integrator(logdensity_fn, metric.kinetic_energy) proposal_generator = iterative_nuts_proposal( symplectic_integrator, - kinetic_energy_fn, - uturn_check_fn, + metric.kinetic_energy, + metric.check_turning, max_num_doublings, divergence_threshold, ) @@ -137,7 +133,7 @@ def kernel( key_momentum, key_integrator = jax.random.split(rng_key, 2) position, logdensity, logdensity_grad = state - momentum = momentum_generator(key_momentum, position) + momentum = metric.sample_momentum(key_momentum, position) integrator_state = integrators.IntegratorState( position, momentum, logdensity, logdensity_grad @@ -214,7 +210,7 @@ def __new__( # type: ignore[misc] cls, logdensity_fn: Callable, step_size: float, - inverse_mass_matrix: Array, + inverse_mass_matrix: metrics.MetricTypes, *, max_num_doublings: int = 10, divergence_threshold: int = 1000, @@ -241,8 +237,8 @@ def step_fn(rng_key: PRNGKey, state): def iterative_nuts_proposal( integrator: Callable, - kinetic_energy: Callable, - uturn_check_fn: Callable, + kinetic_energy: metrics.KineticEnergy, + uturn_check_fn: metrics.CheckTurning, max_num_expansions: int = 10, divergence_threshold: float = 1000, ) -> Callable: diff --git a/blackjax/mcmc/rmhmc.py b/blackjax/mcmc/rmhmc.py new file mode 100644 index 000000000..edcfb3571 --- /dev/null +++ b/blackjax/mcmc/rmhmc.py @@ -0,0 +1,95 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Union + +import blackjax.mcmc.integrators as integrators +import blackjax.mcmc.metrics as metrics +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc import hmc +from blackjax.types import ArrayTree, PRNGKey + +__all__ = ["init", "build_kernel", "rmhmc"] + + +init = hmc.init +build_kernel = hmc.build_kernel + + +class rmhmc: + """A Riemannian Manifold Hamiltonian Monte Carlo kernel + + Of note, this kernel is simply an alias of the ``hmc`` kernel with a + different choice of default integrator (``implicit_midpoint`` instead of + ``velocity_verlet``) since RMHMC is typically used for Hamiltonian systems + that are not separable. + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + mass_matrix + A function which computes the mass matrix (not inverse) at a given + position when drawing a value for the momentum and computing the kinetic + energy. In practice, this argument will be passed to the + ``metrics.default_metric`` function so it supports all the options + discussed there. + num_integration_steps + The number of steps we take with the symplectic integrator at each + sample step before returning a sample. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the + trajectory. + + Returns + ------- + A ``SamplingAlgorithm``. 
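A usage sketch for the ``rmhmc`` interface documented above, assuming this patch is applied; the log-density, the diagonal mass-matrix function, and the parameter values are made-up illustrations. The mass-matrix callable is routed through ``metrics.default_metric`` to ``gaussian_riemannian``, and the default ``implicit_midpoint`` integrator is used.

import jax
import jax.numpy as jnp

import blackjax

def logdensity_fn(x):
    return -0.5 * jnp.sum(x**2)  # illustrative standard-normal target

def mass_matrix_fn(position):
    return 1.0 + position**2  # position-dependent diagonal mass matrix

rmhmc = blackjax.rmhmc(
    logdensity_fn,
    step_size=0.1,              # illustrative value
    mass_matrix=mass_matrix_fn,
    num_integration_steps=20,   # illustrative value
)
state = rmhmc.init(jnp.zeros(2))
new_state, info = rmhmc.step(jax.random.PRNGKey(0), state)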
+ """ + + init = staticmethod(init) + build_kernel = staticmethod(build_kernel) + + def __new__( # type: ignore[misc] + cls, + logdensity_fn: Callable, + step_size: float, + mass_matrix: Union[metrics.Metric, Callable], + num_integration_steps: int, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.implicit_midpoint, + ) -> SamplingAlgorithm: + kernel = cls.build_kernel(integrator, divergence_threshold) + + def init_fn(position: ArrayTree, rng_key=None): + del rng_key + return cls.init(position, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + mass_matrix, + num_integration_steps, + ) + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index 00f25989d..6338acc2b 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -618,6 +618,8 @@ def update_sum_log_p_accept(inputs): def hmc_energy(kinetic_energy): def energy(state): - return -state.logdensity + kinetic_energy(state.momentum) + return -state.logdensity + kinetic_energy( + state.momentum, position=state.position + ) return energy diff --git a/docs/refs.bib b/docs/refs.bib index eee65c7ea..c5e66ca41 100644 --- a/docs/refs.bib +++ b/docs/refs.bib @@ -143,6 +143,19 @@ @article{mclachlan1995numerical publisher={SIAM} } +@inproceedings{brofos2021evaluating, + title={Evaluating the Implicit Midpoint Integrator for Riemannian Hamiltonian Monte Carlo}, + author={Brofos, James and Lederman, Roy R}, + booktitle={Proceedings of the 38th International Conference on Machine Learning}, + pages={1072--1081}, + year={2021}, + editor={Meila, Marina and Zhang, Tong}, + volume={139}, + series={Proceedings of Machine Learning Research}, + month={18--24 Jul}, + publisher={PMLR} +} + @book{schlick2010molecular, title={Molecular modeling and simulation: an interdisciplinary guide}, author={Schlick, Tamar}, diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 2ef285dd2..2d803308a 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -7,6 +7,7 @@ import numpy as np from absl.testing import absltest, parameterized from jax.flatten_util import ravel_pytree +from scipy.special import ellipj import blackjax.mcmc.integrators as integrators from blackjax.mcmc.integrators import esh_dynamics_momentum_update_one_step @@ -19,7 +20,8 @@ def HarmonicOscillator(inv_mass_matrix, k=1.0, m=1.0): def neg_potential_energy(q): return -jnp.sum(0.5 * k * jnp.square(q["x"])) - def kinetic_energy(p): + def kinetic_energy(p, position=None): + del position v = jnp.multiply(inv_mass_matrix, p["x"]) return jnp.sum(0.5 * jnp.dot(v, p["x"])) @@ -32,7 +34,8 @@ def FreeFall(inv_mass_matrix, g=1.0): def neg_potential_energy(q): return -jnp.sum(g * q["x"]) - def kinetic_energy(p): + def kinetic_energy(p, position=None): + del position v = jnp.multiply(inv_mass_matrix, p["x"]) return jnp.sum(0.5 * jnp.dot(v, p["x"])) @@ -45,7 +48,8 @@ def PlanetaryMotion(inv_mass_matrix): def neg_potential_energy(q): return 1.0 / jnp.power(q["x"] ** 2 + q["y"] ** 2, 0.5) - def kinetic_energy(p): + def kinetic_energy(p, position=None): + del position z = jnp.stack([p["x"], p["y"]], axis=-1) return 0.5 * jnp.dot(inv_mass_matrix, z**2) @@ -59,7 +63,8 @@ def log_density(q): q, _ = ravel_pytree(q) return stats.multivariate_normal.logpdf(q, jnp.zeros_like(q), inv_mass_matrix) - def kinetic_energy(p): + def kinetic_energy(p, position=None): + del position p, _ = 
ravel_pytree(p) return 0.5 * p.T @ inv_mass_matrix @ p @@ -131,6 +136,10 @@ def kinetic_energy(p): "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5}, "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6}, + "implicit_midpoint": { + "algorithm": integrators.implicit_midpoint, + "precision": 1e-4, + }, "noneuclidean_leapfrog": {"algorithm": integrators.noneuclidean_leapfrog}, "noneuclidean_mclachlan": {"algorithm": integrators.noneuclidean_mclachlan}, "noneuclidean_yoshida": {"algorithm": integrators.noneuclidean_yoshida}, @@ -159,6 +168,7 @@ class IntegratorTest(chex.TestCase): "velocity_verlet", "mclachlan", "yoshida", + "implicit_midpoint", ], ) ) @@ -319,6 +329,54 @@ def test_noneuclidean_integrator(self, integrator_name): energy_change = kinetic_energy_change[-1] + potential_energy_change self.assertAlmostEqual(energy_change, 0, delta=1e-3) + @chex.all_variants(with_pmap=False) + def test_non_separable(self): + """Test the integration of a non-separable Hamiltonian with a known + closed-form solution, as defined in https://arxiv.org/abs/1609.02212. + """ + + def neg_potential(q): + return -0.5 * (q**2 + 1) + + def kinetic_energy(p, position=None): + return 0.5 * p**2 * (1 + position**2) + + step = self.variant( + integrators.implicit_midpoint(neg_potential, kinetic_energy) + ) + step_size = 1e-3 + q = jnp.array(-1.0) + p = jnp.array(0.0) + initial_state = integrators.IntegratorState( + q, p, neg_potential(q), jax.grad(neg_potential)(q) + ) + + def scan_body(state, _): + state = step(state, step_size) + return state, state + + final_state, traj = jax.lax.scan( + scan_body, + initial_state, + xs=None, + length=10_000, + ) + + # The closed-form solution is computed as follows: + t = step_size * np.arange(len(traj.position)) + expected = q * ellipj(t * np.sqrt(1 + q**2), q**2 / (1 + q**2))[1] + + # Check that the trajectory matches the closed-form solution to + # acceptable precision + chex.assert_trees_all_close(traj.position, expected, atol=step_size) + + # And check the conservation of energy + energy = -neg_potential(q) + kinetic_energy(p, position=q) + new_energy = -neg_potential(final_state.position) + kinetic_energy( + final_state.momentum, position=final_state.position + ) + self.assertAlmostEqual(energy, new_energy, delta=1e-4) + if __name__ == "__main__": absltest.main() diff --git a/tests/mcmc/test_metrics.py b/tests/mcmc/test_metrics.py index 3501ce0a8..f806a375c 100644 --- a/tests/mcmc/test_metrics.py +++ b/tests/mcmc/test_metrics.py @@ -67,5 +67,81 @@ def test_gaussian_euclidean_dim_2(self): np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) +class GaussianRiemannianMetricsTest(chex.TestCase): + def setUp(self): + super().setUp() + self.key = random.PRNGKey(0) + self.dtype = "float32" + + @parameterized.named_parameters( + {"testcase_name": "0d", "shape": ()}, + {"testcase_name": "3d", "shape": (1, 2, 3)}, + ) + def test_gaussian_riemannian_value_errors(self, shape): + x = jnp.ones(shape=shape) + metric = metrics.gaussian_riemannian(lambda _: x) + with self.assertRaisesRegex( + ValueError, "The mass matrix has the wrong number of dimensions" + ): + metric.sample_momentum(self.key, x) + + with self.assertRaisesRegex( + ValueError, "The mass matrix has the wrong number of dimensions" + ): + metric.kinetic_energy(x, position=x) + + with self.assertRaisesRegex( + ValueError, "must be called with the position specified" + ): + 
metric.kinetic_energy(x) + + @chex.all_variants(with_pmap=False) + def test_gaussian_riemannian_dim_1(self): + inverse_mass_matrix = jnp.asarray([1 / 4], dtype=self.dtype) + mass_matrix = jnp.asarray([4.0], dtype=self.dtype) + momentum, kinetic_energy, _ = metrics.gaussian_riemannian(lambda _: mass_matrix) + + arbitrary_position = jnp.asarray([12345], dtype=self.dtype) + momentum_val = self.variant(momentum)(self.key, arbitrary_position) + + # 2 is square root inverse of 1/4 + expected_momentum_val = 2 * random.normal(self.key) + + kinetic_energy_val = self.variant(kinetic_energy)( + momentum_val, position=arbitrary_position + ) + velocity = inverse_mass_matrix * momentum_val + expected_kinetic_energy_val = 0.5 * velocity * momentum_val + expected_kinetic_energy_val += 0.5 * jnp.sum(jnp.log(2 * jnp.pi * mass_matrix)) + + assert momentum_val == expected_momentum_val + assert kinetic_energy_val == expected_kinetic_energy_val + + @chex.all_variants(with_pmap=False) + def test_gaussian_euclidean_dim_2(self): + inverse_mass_matrix = jnp.asarray( + [[1 / 9, 0.5], [0.5, 1 / 4]], dtype=self.dtype + ) + mass_matrix = jnp.linalg.inv(inverse_mass_matrix) + momentum, kinetic_energy, _ = metrics.gaussian_riemannian(lambda _: mass_matrix) + + arbitrary_position = jnp.asarray([12345, 23456], dtype=self.dtype) + momentum_val = self.variant(momentum)(self.key, arbitrary_position) + + L_inv = linalg.cholesky(linalg.inv(inverse_mass_matrix), lower=True) + expected_momentum_val = L_inv @ random.normal(self.key, shape=(2,)) + + kinetic_energy_val = self.variant(kinetic_energy)( + momentum_val, position=arbitrary_position + ) + velocity = jnp.dot(inverse_mass_matrix, momentum_val) + expected_kinetic_energy_val = 0.5 * jnp.matmul(velocity, momentum_val) + expected_kinetic_energy_val += 0.5 * jnp.linalg.slogdet(mass_matrix)[1] + expected_kinetic_energy_val += 0.5 * len(mass_matrix) * jnp.log(2 * jnp.pi) + + np.testing.assert_allclose(expected_momentum_val, momentum_val) + np.testing.assert_allclose(kinetic_energy_val, expected_kinetic_energy_val) + + if __name__ == "__main__": absltest.main() diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index c719286a4..a4d3bd6fd 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -536,6 +536,11 @@ def test_latent_gaussian(self): ) +def rmhmc_static_mass_matrix_fn(position): + del position + return jnp.array([1.0]) + + normal_test_cases = [ { "algorithm": blackjax.hmc, @@ -620,6 +625,16 @@ def test_latent_gaussian(self): "num_sampling_steps": 20_000, "burnin": 2_000, }, + { + "algorithm": blackjax.rmhmc, + "initial_position": jnp.array(3.0), + "parameters": { + "step_size": 1.0, + "num_integration_steps": 30, + }, + "num_sampling_steps": 6000, + "burnin": 1_000, + }, ] @@ -647,6 +662,9 @@ def test_univariate_normal( if algorithm == blackjax.rmh: parameters["proposal_generator"] = rmh_proposal_distribution + if algorithm == blackjax.rmhmc: + parameters["mass_matrix"] = rmhmc_static_mass_matrix_fn + inference_algorithm = algorithm(self.normal_logprob, **parameters) rng_key = self.key if algorithm == blackjax.elliptical_slice:
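To close the loop on the new test case, an end-to-end sketch that mirrors the univariate-normal RMHMC configuration above (same step size, number of integration steps, starting point, and static mass-matrix function); the small ``inference_loop`` helper and the number of draws are illustrative additions, not part of the library.

import jax
import jax.numpy as jnp
from jax.scipy import stats

import blackjax

def logdensity_fn(x):
    return stats.norm.logpdf(x, 0.0, 1.0).sum()

def inference_loop(rng_key, step_fn, initial_state, num_samples):
    def one_step(state, key):
        state, _ = step_fn(key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states

rmhmc = blackjax.rmhmc(
    logdensity_fn,
    step_size=1.0,
    mass_matrix=lambda _: jnp.array([1.0]),  # static, as in the test case above
    num_integration_steps=30,
)
state = rmhmc.init(jnp.array(3.0))
states = inference_loop(jax.random.PRNGKey(0), rmhmc.step, state, 2_000)
samples = states.position  # draws approximately follow a standard normal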