Skip to content

Commit

Permalink
Adding Riemannian Manifold HMC (#538)
Browse files Browse the repository at this point in the history
* Adding initial implementation of RMHMC

moving RMHMC to a separate submodule

fixing parallel tests and improving kinetic energy interface

Moving explicit leapfrog step to end of implicit midpoint

lint

fix explicit update; include logdet in kinetic energy

lint

* implementing untested rmhmc turning criterion

* implementing Metric type

* adding test for integrating non-separable potential

* add energy check in non-separable test

* add test for riemannian metric

* Fix typing

* fix test

---------

Co-authored-by: Junpeng Lao <[email protected]>
  • Loading branch information
dfm and junpenglao authored Dec 14, 2023
1 parent 4058971 commit f12fc38
Show file tree
Hide file tree
Showing 12 changed files with 580 additions and 53 deletions.
2 changes: 2 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .mcmc.nuts import nuts
from .mcmc.periodic_orbital import orbital_hmc
from .mcmc.random_walk import additive_step_random_walk, irmh, rmh
from .mcmc.rmhmc import rmhmc
from .optimizers import dual_averaging, lbfgs
from .sgmcmc.csgld import csgld
from .sgmcmc.sghmc import sghmc
Expand All @@ -37,6 +38,7 @@
"lbfgs",
"hmc", # mcmc
"dynamic_hmc",
"rmhmc",
"mala",
"mgrad_gaussian",
"nuts",
Expand Down
2 changes: 2 additions & 0 deletions blackjax/mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
nuts,
periodic_orbital,
random_walk,
rmhmc,
)

__all__ = [
"barker",
"elliptical_slice",
"ghmc",
"hmc",
"rmhmc",
"mala",
"nuts",
"periodic_orbital",
Expand Down
49 changes: 30 additions & 19 deletions blackjax/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.proposal import safe_energy_diff, static_binomial_sampling
from blackjax.mcmc.trajectory import hmc_energy
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

__all__ = [
"HMCState",
Expand Down Expand Up @@ -101,7 +101,8 @@ def build_kernel(
integrator
The symplectic integrator to use to integrate the Hamiltonian dynamics.
divergence_threshold
Value of the difference in energy above which we consider that the transition is divergent.
Value of the difference in energy above which we consider that the transition is
divergent.

Returns
-------
Expand All @@ -116,18 +117,16 @@ def kernel(
state: HMCState,
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: Array,
inverse_mass_matrix: metrics.MetricTypes,
num_integration_steps: int,
) -> tuple[HMCState, HMCInfo]:
"""Generate a new sample with the HMC kernel."""

momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean(
inverse_mass_matrix
)
symplectic_integrator = integrator(logdensity_fn, kinetic_energy_fn)
metric = metrics.default_metric(inverse_mass_matrix)
symplectic_integrator = integrator(logdensity_fn, metric.kinetic_energy)
proposal_generator = hmc_proposal(
symplectic_integrator,
kinetic_energy_fn,
metric.kinetic_energy,
step_size,
num_integration_steps,
divergence_threshold,
Expand All @@ -136,7 +135,7 @@ def kernel(
key_momentum, key_integrator = jax.random.split(rng_key, 2)

position, logdensity, logdensity_grad = state
momentum = momentum_generator(key_momentum, position)
momentum = metric.sample_momentum(key_momentum, position)

integrator_state = integrators.IntegratorState(
position, momentum, logdensity, logdensity_grad
Expand All @@ -154,10 +153,10 @@ def kernel(
class hmc:
"""Implements the (basic) user interface for the HMC kernel.

The general hmc kernel builder (:meth:`blackjax.mcmc.hmc.build_kernel`, alias `blackjax.hmc.build_kernel`) can be
cumbersome to manipulate. Since most users only need to specify the kernel
parameters at initialization time, we provide a helper function that
specializes the general kernel.
The general hmc kernel builder (:meth:`blackjax.mcmc.hmc.build_kernel`, alias
`blackjax.hmc.build_kernel`) can be cumbersome to manipulate. Since most users only
need to specify the kernel parameters at initialization time, we provide a helper
function that specializes the general kernel.

We also add the general kernel and state generator as an attribute to this class so
users only need to pass `blackjax.hmc` to SMC, adaptation, etc. algorithms.
Expand All @@ -169,7 +168,9 @@ class hmc:

.. code::

hmc = blackjax.hmc(logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps)
hmc = blackjax.hmc(
logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps
)
state = hmc.init(position)
new_state, info = hmc.step(rng_key, state)

Expand All @@ -188,7 +189,14 @@ class hmc:

kernel = blackjax.hmc.build_kernel(integrators.mclachlan)
state = blackjax.hmc.init(position, logdensity_fn)
state, info = kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps)
state, info = kernel(
rng_key,
state,
logdensity_fn,
step_size,
inverse_mass_matrix,
num_integration_steps,
)

Parameters
----------
Expand All @@ -198,7 +206,9 @@ class hmc:
The value to use for the step size in the symplectic integrator.
inverse_mass_matrix
The value to use for the inverse mass matrix when drawing a value for
the momentum and computing the kinetic energy.
the momentum and computing the kinetic energy. This argument will be
passed to the ``metrics.default_metric`` function so it supports the
full interface presented there.
num_integration_steps
The number of steps we take with the symplectic integrator at each
sample step before returning a sample.
Expand All @@ -207,7 +217,8 @@ class hmc:
which we say that the transition is divergent. The default value is
commonly found in other libraries, and yet is arbitrary.
integrator
(algorithm parameter) The symplectic integrator to use to integrate the trajectory.\
(algorithm parameter) The symplectic integrator to use to integrate the
trajectory.

Returns
-------
Expand All @@ -221,7 +232,7 @@ def __new__( # type: ignore[misc]
cls,
logdensity_fn: Callable,
step_size: float,
inverse_mass_matrix: Array,
inverse_mass_matrix: metrics.MetricTypes,
num_integration_steps: int,
*,
divergence_threshold: int = 1000,
Expand All @@ -248,7 +259,7 @@ def step_fn(rng_key: PRNGKey, state):

def hmc_proposal(
integrator: Callable,
kinetic_energy: Callable,
kinetic_energy: metrics.KineticEnergy,
step_size: Union[float, ArrayLikeTree],
num_integration_steps: int = 1,
divergence_threshold: float = 1000,
Expand Down
121 changes: 117 additions & 4 deletions blackjax/mcmc/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Symplectic, time-reversible, integrators for Hamiltonian trajectories."""
from typing import Callable, NamedTuple
from typing import Any, Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.mcmc.metrics import EuclideanKineticEnergy
from blackjax.mcmc.metrics import KineticEnergy
from blackjax.types import ArrayTree

__all__ = [
"mclachlan",
"velocity_verlet",
"yoshida",
"implicit_midpoint",
"noneuclidean_leapfrog",
"noneuclidean_mclachlan",
"noneuclidean_yoshida",
Expand Down Expand Up @@ -170,7 +171,7 @@ def update(
return update


def euclidean_momentum_update_fn(kinetic_energy_fn: EuclideanKineticEnergy):
def euclidean_momentum_update_fn(kinetic_energy_fn: KineticEnergy):
kinetic_energy_grad_fn = jax.grad(kinetic_energy_fn)

def update(
Expand Down Expand Up @@ -216,7 +217,7 @@ def generate_euclidean_integrator(cofficients):
"""

def euclidean_integrator(
logdensity_fn: Callable, kinetic_energy_fn: EuclideanKineticEnergy
logdensity_fn: Callable, kinetic_energy_fn: KineticEnergy
) -> Integrator:
position_update_fn = euclidean_position_update_fn(logdensity_fn)
momentum_update_fn = euclidean_momentum_update_fn(kinetic_energy_fn)
Expand Down Expand Up @@ -366,3 +367,115 @@ def noneuclidean_integrator(
noneuclidean_leapfrog = generate_noneuclidean_integrator(velocity_verlet_cofficients)
noneuclidean_yoshida = generate_noneuclidean_integrator(yoshida_cofficients)
noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients)

# Signature expected of a fixed point solver. It is called with a function that
# maps a guess to a ``(next_guess, aux)`` pair and with an initial guess, and it
# returns ``(solution, aux, info)``; the type of ``info`` is left to the solver.
FixedPointSolver = Callable[
    [Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], ArrayTree],
    Tuple[ArrayTree, ArrayTree, Any],
]


class FixedPointIterationInfo(NamedTuple):
    """Diagnostics reported by :func:`solve_fixed_point_iteration`."""

    # True when the final update norm is finite and within the convergence
    # tolerance.
    success: bool
    # Norm of the difference between the last two iterates.
    norm: float
    # Number of passes through the iteration loop (the initial evaluation of
    # the function that seeds the loop is not counted).
    iters: int


def solve_fixed_point_iteration(
    func: Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]],
    x0: ArrayTree,
    *,
    convergence_tol: float = 1e-6,
    divergence_tol: float = 1e10,
    max_iters: int = 100,
    norm_fn: Callable[[ArrayTree], float] = lambda x: jnp.max(jnp.abs(x)),
) -> Tuple[ArrayTree, ArrayTree, FixedPointIterationInfo]:
    """Solve ``x = func(x)`` by repeatedly applying ``func``.

    Parameters
    ----------
    func
        Function mapping the current guess to a ``(next_guess, aux)`` pair.
    x0
        Initial guess; a pytree with the same structure as the solution.
    convergence_tol
        Iteration stops once the norm of the last update is no larger than
        this value.
    divergence_tol
        Iteration stops once the norm of the last update reaches this value.
    max_iters
        Maximum number of passes through the iteration loop. ``func`` is
        evaluated once more, beforehand, to seed the loop.
    norm_fn
        Norm applied to the flattened difference between successive iterates.
        The default is the maximum absolute value.

    Returns
    -------
    A tuple ``(x, aux, info)``: the last iterate, the auxiliary output of the
    ``func`` evaluation that produced it, and a
    :class:`FixedPointIterationInfo` with the convergence flag, the final
    update norm and the number of loop iterations.

    """

    def update_norm(new: ArrayTree, old: ArrayTree) -> float:
        # Flatten the leaf-wise difference so ``norm_fn`` sees a single array.
        delta = jax.tree_util.tree_map(jnp.subtract, new, old)
        flat_delta, _ = ravel_pytree(delta)
        return norm_fn(flat_delta)

    def keep_going(carry: Tuple[int, ArrayTree, ArrayTree, float]) -> bool:
        count, _, _, residual = carry
        within_budget = count < max_iters
        well_behaved = jnp.isfinite(residual) & (residual < divergence_tol)
        unconverged = residual > convergence_tol
        return within_budget & well_behaved & unconverged

    def iterate(
        carry: Tuple[int, ArrayTree, ArrayTree, float]
    ) -> Tuple[int, ArrayTree, ArrayTree, float]:
        count, current, _, _ = carry
        proposal, proposal_aux = func(current)
        return count + 1, proposal, proposal_aux, update_norm(proposal, current)

    # One evaluation up front provides both the initial norm and an ``aux``
    # value of the right structure for the loop carry.
    first, first_aux = func(x0)
    init_carry = (0, first, first_aux, update_norm(first, x0))
    iters, x, aux, norm = jax.lax.while_loop(keep_going, iterate, init_carry)

    success = jnp.isfinite(norm) & (norm <= convergence_tol)
    return x, aux, FixedPointIterationInfo(success, norm, iters)


def implicit_midpoint(
    logdensity_fn: Callable,
    kinetic_energy_fn: KineticEnergy,
    *,
    solver: FixedPointSolver = solve_fixed_point_iteration,
    **solver_kwargs: Any,
) -> Integrator:
    """Build an implicit midpoint integrator for position-dependent kinetic energies.

    This integrator is based on :cite:t:`brofos2021evaluating` and supports
    kinetic energies that depend on the position as well as the momentum.
    ``kinetic_energy_fn`` is called as
    ``kinetic_energy_fn(momentum, position=position)``.

    Parameters
    ----------
    logdensity_fn
        Log-density of the target; differentiated with respect to position.
    kinetic_energy_fn
        Kinetic energy, differentiated with respect to both position and
        momentum.
    solver
        Fixed point solver used to find the midpoint. The default is a simple
        fixed point iteration.
    **solver_kwargs
        Extra keyword arguments forwarded to ``solver`` on every step.

    Returns
    -------
    A function ``one_step(state, step_size)`` that advances an integrator
    state by one step.

    """
    logdensity_and_grad_fn = jax.value_and_grad(logdensity_fn)

    def _kinetic_energy(q: ArrayTree, p: ArrayTree):
        return kinetic_energy_fn(p, position=q)

    # Gradients with respect to (position, momentum), in that order.
    kinetic_energy_grad_fn = jax.grad(_kinetic_energy, argnums=(0, 1))

    def one_step(state: IntegratorState, step_size: float) -> IntegratorState:
        position, momentum, _, _ = state
        half_step = 0.5 * step_size

        def _half_update(
            start: Tuple[ArrayTree, ArrayTree],
            guess: Tuple[ArrayTree, ArrayTree],
            logdensity_grad: ArrayTree,
        ) -> Tuple[ArrayTree, ArrayTree]:
            # Move half a step away from ``start`` using the gradients of the
            # Hamiltonian evaluated at ``guess``.
            start_q, start_p = start
            guess_q, guess_p = guess
            dTdq, dHdp = kinetic_energy_grad_fn(guess_q, guess_p)
            # The potential is minus the log-density, hence the subtraction.
            dHdq = jax.tree_util.tree_map(jnp.subtract, dTdq, logdensity_grad)

            new_q = jax.tree_util.tree_map(
                lambda q_, d_: q_ + half_step * d_, start_q, dHdp
            )
            new_p = jax.tree_util.tree_map(
                lambda p_, d_: p_ - half_step * d_, start_p, dHdq
            )
            return new_q, new_p

        def _midpoint_map(
            guess: ArrayTree,
        ) -> Tuple[Tuple[ArrayTree, ArrayTree], ArrayTree]:
            # Fixed point map: always step from the initial coordinates, with
            # gradients taken at the current guess for the midpoint.
            guess_q, guess_p = guess
            _, logdensity_grad = logdensity_and_grad_fn(guess_q)
            updated = _half_update(
                (position, momentum), (guess_q, guess_p), logdensity_grad
            )
            return updated, logdensity_grad

        # Solve for the midpoint numerically. The solver's auxiliary output
        # and diagnostics are discarded.
        # TODO: Track the returned info
        (q, p), _, _ = solver(_midpoint_map, (position, momentum), **solver_kwargs)

        # Finish with an explicit half step from the midpoint, as recommended
        # by Brofos & Lederman.
        _, logdensity_grad = logdensity_and_grad_fn(q)
        q, p = _half_update((q, p), (q, p), logdensity_grad)

        return IntegratorState(q, p, *logdensity_and_grad_fn(q))

    return one_step
Loading

0 comments on commit f12fc38

Please sign in to comment.