Commit 6bacb6c: UPDATE PR
reubenharry committed May 24, 2024 · 1 parent 4e2b7c0
Showing 6 changed files with 76 additions and 54 deletions.
81 changes: 49 additions & 32 deletions blackjax/adaptation/mclmc_adaptation.py
@@ -20,7 +20,7 @@
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size
from blackjax.util import pytree_size, streaming_average
from blackjax.util import pytree_size, streaming_average_update


class MCLMCAdaptationState(NamedTuple):
@@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple):
The momentum decoherence rate for the MCLMC algorithm.
step_size
The step size used for the MCLMC algorithm.
-    std_mat
+    sqrt_diag_cov_mat
        The square root of the diagonal of the covariance matrix, used for preconditioning.
"""

L: float
step_size: float
-    std_mat: float
+    sqrt_diag_cov_mat: float


def mclmc_find_L_and_step_size(
@@ -81,10 +81,30 @@ def mclmc_find_L_and_step_size(
    Returns
    -------
    A tuple containing the final state of the MCMC algorithm and the final hyperparameters.
+    Example
+    -------
+    .. code::
+
+        kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel(
+            logdensity_fn=logdensity_fn,
+            integrator=integrator,
+            sqrt_diag_cov_mat=sqrt_diag_cov_mat,
+        )
+
+        (
+            blackjax_state_after_tuning,
+            blackjax_mclmc_sampler_params,
+        ) = blackjax.mclmc_find_L_and_step_size(
+            mclmc_kernel=kernel,
+            num_steps=num_steps,
+            state=initial_state,
+            rng_key=tune_key,
+            diagonal_preconditioning=preconditioning,
+        )
    """
dim = pytree_size(state.position)
params = MCLMCAdaptationState(
-        jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, std_mat=jnp.ones((dim,))
+        jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov_mat=jnp.ones((dim,))
)
part1_key, part2_key = jax.random.split(rng_key, 2)

@@ -101,7 +121,7 @@ def mclmc_find_L_and_step_size(

if frac_tune3 != 0:
state, params = make_adaptation_L(
-            mclmc_kernel(params.std_mat), frac=frac_tune3, Lfactor=0.4
+            mclmc_kernel(params.sqrt_diag_cov_mat), frac=frac_tune3, Lfactor=0.4
)(state, params, num_steps, part2_key)

return state, params
@@ -128,7 +148,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
time, x_average, step_size_max = adaptive_state

# dynamics
-        next_state, info = kernel(params.std_mat)(
+        next_state, info = kernel(params.sqrt_diag_cov_mat)(
rng_key=rng_key,
state=previous_state,
L=params.L,
@@ -179,7 +199,7 @@ def step(iteration_state, weight_and_key):

x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
-        streaming_avg = streaming_average(
+        streaming_avg = streaming_average_update(
expectation=jnp.array([x, jnp.square(x)]),
streaming_avg=streaming_avg,
weight=(1 - mask) * success * params.step_size,
@@ -188,6 +208,17 @@

        return (state, params, adaptive_state, streaming_avg), None

+    run_steps = lambda xs, state, params: jax.lax.scan(
+        step,
+        init=(
+            state,
+            params,
+            (0.0, 0.0, jnp.inf),
+            (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
+        ),
+        xs=xs,
+    )[0]

def L_step_size_adaptation(state, params, num_steps, rng_key):
num_steps1, num_steps2 = (
int(num_steps * frac_tune1) + 1,
@@ -205,45 +236,31 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
        mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

        # run the steps
-        state, params, _, (_, average) = jax.lax.scan(
-            step,
-            init=(
-                state,
-                params,
-                (0.0, 0.0, jnp.inf),
-                (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
-            ),
-            xs=(mask, L_step_size_adaptation_keys),
-        )[0]
+        state, params, _, (_, average) = run_steps(
+            xs=(mask, L_step_size_adaptation_keys), state=state, params=params
+        )

        L = params.L
        # determine L
-        std_mat = params.std_mat
+        sqrt_diag_cov_mat = params.sqrt_diag_cov_mat
        if num_steps2 != 0.0:
            x_average, x_squared_average = average[0], average[1]
            variances = x_squared_average - jnp.square(x_average)
            L = jnp.sqrt(jnp.sum(variances))

            if diagonal_preconditioning:
-                std_mat = jnp.sqrt(variances)
-                params = params._replace(std_mat=std_mat)
+                sqrt_diag_cov_mat = jnp.sqrt(variances)
+                params = params._replace(sqrt_diag_cov_mat=sqrt_diag_cov_mat)
                L = jnp.sqrt(dim)

            # readjust the stepsize
            steps = num_steps2 // 3  # we do some small number of steps
            keys = jax.random.split(final_key, steps)
-            state, params, _, (_, average) = jax.lax.scan(
-                step,
-                init=(
-                    state,
-                    params,
-                    (0.0, 0.0, jnp.inf),
-                    (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
-                ),
-                xs=(jnp.ones(steps), keys),
-            )[0]
+            state, params, _, (_, average) = run_steps(
+                xs=(jnp.ones(steps), keys), state=state, params=params
+            )

-        return state, MCLMCAdaptationState(L, params.step_size, std_mat)
+        return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov_mat)

return L_step_size_adaptation
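The `run_steps` helper introduced in this file is an instance of the standard `jax.lax.scan` accumulation pattern: a step function threads a carry through a sequence of inputs, and the trailing `[0]` keeps only the final carry. A minimal, self-contained sketch of the pattern, with a toy carry and step function rather than the blackjax ones:

    import jax
    import jax.numpy as jnp

    def step(carry, x):
        # carry = (running total, iteration count); x is this step's input
        total, count = carry
        return (total + x, count + 1), None  # None: no per-step output is stacked

    run_steps = lambda xs, total, count: jax.lax.scan(
        step, init=(total, count), xs=xs
    )[0]  # [0] selects the final carry, discarding per-step outputs

    print(run_steps(jnp.arange(5.0), 0.0, 0))  # (10.0, 5)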

11 changes: 6 additions & 5 deletions blackjax/mcmc/integrators.py
@@ -294,7 +294,7 @@ def _normalized_flatten_array(x, tol=1e-13):
return jnp.where(norm > tol, x / norm, x), norm


-def esh_dynamics_momentum_update_one_step(std_mat):
+def esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0):
def update(
momentum: ArrayTree,
logdensity_grad: ArrayTree,
@@ -313,7 +313,7 @@ def update(

logdensity_grad = logdensity_grad
flatten_grads, unravel_fn = ravel_pytree(logdensity_grad)
-        flatten_grads = flatten_grads * std_mat
+        flatten_grads = flatten_grads * sqrt_diag_cov_mat
flatten_momentum, _ = ravel_pytree(momentum)
dims = flatten_momentum.shape[0]
normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads)
@@ -325,7 +325,7 @@ def update(
+ 2 * zeta * flatten_momentum
)
new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw)
-        gr = unravel_fn(new_momentum_normalized * std_mat)
+        gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov_mat)
next_momentum = unravel_fn(new_momentum_normalized)
kinetic_energy_change = (
delta
@@ -357,11 +357,12 @@ def format_isokinetic_state_output(

def generate_isokinetic_integrator(coefficients):
def isokinetic_integrator(
-        logdensity_fn: Callable, std_mat: ArrayTree = 1.0, *args, **kwargs
+        logdensity_fn: Callable, *args, **kwargs
    ) -> GeneralIntegrator:
+        sqrt_diag_cov_mat = kwargs.get("sqrt_diag_cov_mat", 1.0)
        position_update_fn = euclidean_position_update_fn(logdensity_fn)
        one_step = generalized_two_stage_integrator(
-            esh_dynamics_momentum_update_one_step(std_mat),
+            esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat),
position_update_fn,
coefficients,
format_output_fn=format_isokinetic_state_output,
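Note that with the new `*args, **kwargs` signature above, `sqrt_diag_cov_mat` is only honoured when passed by keyword: a positional value lands in `*args` and the `kwargs.get` lookup silently falls back to the default of `1.0`. A toy illustration of the pitfall, using a hypothetical `make_integrator` rather than the real function:

    def make_integrator(logdensity_fn, *args, **kwargs):
        # mirrors the kwargs.get(...) lookup in isokinetic_integrator above
        return kwargs.get("sqrt_diag_cov_mat", 1.0)

    print(make_integrator(None, 2.0))                    # 1.0 -- positional value is ignored
    print(make_integrator(None, sqrt_diag_cov_mat=2.0))  # 2.0 -- keyword form is required

For this reason, the call site in blackjax/mcmc/mclmc.py below passes the argument by keyword.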
8 changes: 4 additions & 4 deletions blackjax/mcmc/mclmc.py
@@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key):
)


-def build_kernel(logdensity_fn, std_mat, integrator):
+def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator):
"""Build a HMC kernel.
Parameters
@@ -80,7 +80,7 @@ def build_kernel(logdensity_fn, std_mat, integrator):
"""

-    step = with_isokinetic_maruyama(integrator(logdensity_fn, std_mat))
+    step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov_mat=sqrt_diag_cov_mat))

def kernel(
rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float
@@ -105,7 +105,7 @@ def as_top_level_api(
L,
step_size,
integrator=isokinetic_mclachlan,
-    std_mat=1.0,
+    sqrt_diag_cov_mat=1.0,
) -> SamplingAlgorithm:
"""The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be
cumbersome to manipulate. Since most users only need to specify the kernel
@@ -153,7 +153,7 @@ def as_top_level_api(
A ``SamplingAlgorithm``.
"""

-    kernel = build_kernel(logdensity_fn, std_mat, integrator)
+    kernel = build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator)

def init_fn(position: ArrayLike, rng_key: PRNGKey):
return init(position, logdensity_fn, rng_key)
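After the rename, a typical call to the top-level API reads as follows. This is a minimal usage sketch, assuming the `blackjax.mclmc` alias for `as_top_level_api` and placeholder (untuned) values for `L` and `step_size`:

    import jax
    import jax.numpy as jnp
    import blackjax

    logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))  # standard Gaussian in 2D

    sampling_alg = blackjax.mclmc(
        logdensity_fn,
        L=1.0,                          # placeholder, normally set by tuning
        step_size=0.1,                  # placeholder, normally set by tuning
        sqrt_diag_cov_mat=jnp.ones(2),  # identity preconditioning (the default)
    )

    init_key, step_key = jax.random.split(jax.random.PRNGKey(0))
    state = sampling_alg.init(jnp.zeros(2), init_key)
    state, info = sampling_alg.step(step_key, state)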
6 changes: 4 additions & 2 deletions blackjax/util.py
@@ -209,7 +209,7 @@ def one_step(average_and_state, xs, return_state):
_, rng_key = xs
average, state = average_and_state
state, info = inference_algorithm.step(rng_key, state)
-        average = streaming_average(expectation(transform(state)), average)
+        average = streaming_average_update(expectation(transform(state)), average)
if return_state:
return (average, state), (transform(state), info)
else:
@@ -232,7 +232,9 @@ def one_step(average_and_state, xs, return_state):
return transform(final_state), state_history, info_history


-def streaming_average(expectation, streaming_avg, weight=1.0, zero_prevention=0.0):
+def streaming_average_update(
+    expectation, streaming_avg, weight=1.0, zero_prevention=0.0
+):
"""Compute the streaming average of a function O(x) using a weight.
Parameters:
----------
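The renamed `streaming_average_update` maintains a weighted running mean: with `streaming_avg` a pair `(total_weight, running_mean)`, one update computes roughly `mean' = (total * mean + weight * O(x)) / (total + weight + zero_prevention)` and `total' = total + weight`. A minimal sketch of that update rule, simplified to flat arrays (not the exact library implementation):

    import jax.numpy as jnp

    def streaming_average_update(expectation, streaming_avg, weight=1.0, zero_prevention=0.0):
        # streaming_avg is (total_weight, running_mean)
        total, average = streaming_avg
        average = (total * average + weight * expectation) / (total + weight + zero_prevention)
        return total + weight, average

    avg = (0.0, jnp.zeros(2))
    for x in (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])):
        avg = streaming_average_update(x, avg)
    print(avg)  # (2.0, [2., 3.])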
Expand Down
6 changes: 4 additions & 2 deletions tests/mcmc/test_integrators.py
Original file line number Diff line number Diff line change
@@ -234,7 +234,9 @@ def test_esh_momentum_update(self, dims):
) / (jnp.cosh(delta) + jnp.dot(gradient_normalized, momentum * jnp.sinh(delta)))

# Efficient implementation
-        update_stable = self.variant(esh_dynamics_momentum_update_one_step(std_mat=1.0))
+        update_stable = self.variant(
+            esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0)
+        )
next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0)
np.testing.assert_array_almost_equal(next_momentum, next_momentum1)
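The "efficient implementation" checked by this test avoids the overflow-prone `cosh`/`sinh` terms of the explicit update by multiplying numerator and denominator through by `2 * exp(-delta)`, so all quantities stay bounded as `delta` grows. A scalar toy check of the equivalence (illustrative only, not the test's code):

    import jax.numpy as jnp

    delta = 3.0
    u, e = 0.6, 0.8  # scalar stand-ins for the momentum and normalized gradient

    naive = (u + e * (jnp.sinh(delta) + e * u * (jnp.cosh(delta) - 1))) / (
        jnp.cosh(delta) + e * u * jnp.sinh(delta)
    )

    zeta = jnp.exp(-delta)  # the naive form, multiplied through by 2 * zeta
    stable = (e * (1 - zeta) * (1 + zeta + e * u * (1 - zeta)) + 2 * zeta * u) / (
        (1 + zeta**2) + e * u * (1 - zeta**2)
    )

    print(naive, stable)  # both ~0.81, identical up to floating point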

@@ -258,7 +260,7 @@ def test_isokinetic_leapfrog(self):
next_state, kinetic_energy_change = step(initial_state, step_size)

# explicit integration
-        op1 = esh_dynamics_momentum_update_one_step(std_mat=1.0)
+        op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0)
op2 = integrators.euclidean_position_update_fn(logdensity_fn)
position, momentum, _, logdensity_grad = initial_state
momentum, kinetic_grad, kinetic_energy_change0 = op1(
18 changes: 9 additions & 9 deletions tests/mcmc/test_sampling.py
@@ -111,10 +111,10 @@ def run_mclmc(
position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
)

-    kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel(
+    kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel(
        logdensity_fn=logdensity_fn,
        integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan,
-        std_mat=std_mat,
+        sqrt_diag_cov_mat=sqrt_diag_cov_mat,
    )

(
@@ -132,7 +132,7 @@ def run_mclmc(
logdensity_fn,
L=blackjax_mclmc_sampler_params.L,
step_size=blackjax_mclmc_sampler_params.step_size,
-        std_mat=blackjax_mclmc_sampler_params.std_mat,
+        sqrt_diag_cov_mat=blackjax_mclmc_sampler_params.sqrt_diag_cov_mat,
)

_, samples, _ = run_inference_algorithm(
@@ -300,7 +300,7 @@ def __init__(self, d, condition_number):

integrator = isokinetic_mclachlan

-    def get_std_mat():
+    def get_sqrt_diag_cov_mat():
init_key, tune_key = jax.random.split(key)

initial_position = model.sample_init(init_key)
@@ -311,10 +311,10 @@ def get_std_mat():
rng_key=init_key,
)

-        kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel(
+        kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel(
            logdensity_fn=model.logdensity_fn,
            integrator=integrator,
-            std_mat=std_mat,
+            sqrt_diag_cov_mat=sqrt_diag_cov_mat,
        )

(
@@ -328,13 +328,13 @@ def get_std_mat():
diagonal_preconditioning=True,
)

-        return blackjax_mclmc_sampler_params.std_mat
+        return blackjax_mclmc_sampler_params.sqrt_diag_cov_mat

-    std_mat = get_std_mat()
+    sqrt_diag_cov_mat = get_sqrt_diag_cov_mat()
assert (
jnp.abs(
jnp.dot(
-                (std_mat**2) / jnp.linalg.norm(std_mat**2),
+                (sqrt_diag_cov_mat**2) / jnp.linalg.norm(sqrt_diag_cov_mat**2),
eigs / jnp.linalg.norm(eigs),
)
- 1