Skip to content

Commit

Permalink
Add low-rank-modified metric
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed May 24, 2024
1 parent e0a7f9e commit b323a77
Showing 1 changed file with 125 additions and 6 deletions.
131 changes: 125 additions & 6 deletions blackjax/mcmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`.
"""
from typing import Callable, NamedTuple, Optional, Protocol, Union

from typing import Any, Callable, NamedTuple, Optional, Protocol, Union

import jax.numpy as jnp
import jax.scipy as jscipy
Expand All @@ -38,14 +39,18 @@
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.util import generate_gaussian_noise

__all__ = ["default_metric", "gaussian_euclidean", "gaussian_riemannian"]
__all__ = [
"default_metric",
"gaussian_euclidean",
"gaussian_riemannian",
"gaussian_euclidean_low_rank",
]


class KineticEnergy(Protocol):
def __call__(
self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
...
) -> float: ...


class CheckTurning(Protocol):
Expand All @@ -56,14 +61,14 @@ def __call__(
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
) -> bool:
...
) -> bool: ...


class Metric(NamedTuple):
sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree]
kinetic_energy: KineticEnergy
check_turning: CheckTurning
data: Any = None


MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]]
Expand Down Expand Up @@ -208,6 +213,120 @@ def is_turning(
return Metric(momentum_generator, kinetic_energy, is_turning)


def gaussian_euclidean_low_rank(
diagonal_scale_std: Array,
eigenvectors: Array,
eigenvalues: Array,
) -> Metric:
r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum
:cite:p:`betancourt2013general`.
The gaussian euclidean metric is a euclidean metric further characterized
by setting the conditional probability density :math:`\pi(momentum|position)`
to follow a standard gaussian distribution. A Newtonian hamiltonian
dynamics is assumed.
This uses the mass matrix $(D^{-1}(V(\Sigma - I)V^T + I)D^{-1})^{-1}$.
Parameters
----------
diagonal_scale_std
The diagonal $D^{-1}$. This should for instance correspond to the standard deviation
of the posterior.
eigenvectors
An arbitrary number of eigenvectors
eigenvalues
The corresponding eigenvalues
Returns
-------
momentum_generator
A function that generates a value for the momentum at random.
kinetic_energy
A function that returns the kinetic energy given the momentum.
is_turning
A function that determines whether a trajectory is turning back on
itself given the values of the momentum along the trajectory.
"""
(ndim,) = jnp.shape(diagonal_scale_std)
(ndim_, n_eigs) = jnp.shape(eigenvectors)
if ndim != ndim_:
raise ValueError("Shape mismatch in metric.")

(n_eigs_,) = jnp.shape(eigenvalues)
if n_eigs != n_eigs_:
raise ValueError("Shape mismatch in metric.")

# Compute (V(\Sigma - I)V^T + I)x
def inner_matrix_mult(vals, vecs, x):
projected = x @ vecs
scaled = (vals - 1) * projected
projected_back = vecs @ scaled
return projected_back + x

def inv_mass_matrix_mult(x):
scaled = x * diagonal_scale_std
product = inner_matrix_mult(eigenvalues, eigenvectors, scaled)
return product * diagonal_scale_std

def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree:
unit_draws = generate_gaussian_noise(rng_key, position)
sqrt_vals = jnp.sqrt(jnp.reciprocal(eigenvalues))
sqrt_inv_diag = jnp.sqrt(jnp.reciprocal(diagonal_scale_std))
return inner_matrix_mult(sqrt_vals, eigenvectors, unit_draws) * sqrt_inv_diag

def kinetic_energy(
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None
) -> float:
del position
momentum, _ = ravel_pytree(momentum)
velocity = inv_mass_matrix_mult(momentum)
kinetic_energy_val = 0.5 * jnp.dot(velocity, momentum)
return kinetic_energy_val

def is_turning(
momentum_left: ArrayLikeTree,
momentum_right: ArrayLikeTree,
momentum_sum: ArrayLikeTree,
position_left: Optional[ArrayLikeTree] = None,
position_right: Optional[ArrayLikeTree] = None,
) -> bool:
"""Generalized U-turn criterion :cite:p:`betancourt2013generalizing,nuts_uturn`.
Parameters
----------
momentum_left
Momentum of the leftmost point of the trajectory.
momentum_right
Momentum of the rightmost point of the trajectory.
momentum_sum
Sum of the momenta along the trajectory.
"""
del position_left, position_right

m_left, _ = ravel_pytree(momentum_left)
m_right, _ = ravel_pytree(momentum_right)
m_sum, _ = ravel_pytree(momentum_sum)

velocity_left = inv_mass_matrix_mult(m_left)
velocity_right = inv_mass_matrix_mult(m_right)

# rho = m_sum
rho = m_sum - (m_right + m_left) / 2
turning_at_left = jnp.dot(velocity_left, rho) <= 0
turning_at_right = jnp.dot(velocity_right, rho) <= 0
return turning_at_left | turning_at_right

return Metric(
momentum_generator,
kinetic_energy,
is_turning,
data=(diagonal_scale_std, eigenvalues, eigenvectors),
)


def gaussian_riemannian(
mass_matrix_fn: Callable,
) -> Metric:
Expand Down

0 comments on commit b323a77

Please sign in to comment.