Preconditioned mclmc (#673)
* TESTS

* TESTS

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* ADD PRECONDITIONING TO MCLMC

* ADD PRECONDITIONING TO TUNING FOR MCLMC

* UPDATE GITIGNORE

* UPDATE GITIGNORE

* UPDATE TESTS

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* STREAMING AVERAGE

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* GITIGNORE

* PRECOMMIT CLEAN UP

* ADD INITIAL_POSITION

* FIX TEST

* ADD TEST

* REMOVE BENCHMARKS

* BUG FIX

* CHANGE PRECISION

* CHANGE PRECISION

* RENAME O

* UPDATE STREAMING AVG

* UPDATE PR

* RENAME STD_MAT
reubenharry authored May 25, 2024
1 parent e0a7f9e commit 5831740
Showing 7 changed files with 302 additions and 156 deletions.
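For orientation, here is a minimal sketch of how the tuner added by this commit is meant to be driven, assembled from the updated docstring in the diff below. The toy target, the `isokinetic_mclachlan` integrator choice, and the `init` call are assumptions for illustration, not part of this commit.

```python
import jax
import jax.numpy as jnp
import blackjax

# toy standard-normal target (assumption, for illustration only)
logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))

init_key, tune_key = jax.random.split(jax.random.key(0))
initial_state = blackjax.mcmc.mclmc.init(
    position=jnp.ones((10,)), logdensity_fn=logdensity_fn, rng_key=init_key
)

# the tuner now takes a kernel *factory*: preconditioner in, kernel out
kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel(
    logdensity_fn=logdensity_fn,
    integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
    std_mat=std_mat,
)

state_after_tuning, sampler_params = blackjax.mclmc_find_L_and_step_size(
    mclmc_kernel=kernel,
    num_steps=1000,
    state=initial_state,
    rng_key=tune_key,
    diagonal_preconditioning=True,  # new flag introduced by this PR
)
```

With `diagonal_preconditioning=True`, the returned `sampler_params` carries the learned `sqrt_diag_cov` alongside `L` and `step_size`.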
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,6 +1,8 @@
 # Created by https://www.gitignore.io/api/python
 # Edit at https://www.gitignore.io/?templates=python
 
+explore.py
+
 ### Python ###
 # Byte-compiled / optimized / DLL files
 __pycache__/
152 changes: 82 additions & 70 deletions blackjax/adaptation/mclmc_adaptation.py
@@ -20,7 +20,7 @@
 from jax.flatten_util import ravel_pytree
 
 from blackjax.diagnostics import effective_sample_size
-from blackjax.util import pytree_size
+from blackjax.util import pytree_size, streaming_average_update
 
 
 class MCLMCAdaptationState(NamedTuple):
@@ -30,10 +30,13 @@ class MCLMCAdaptationState(NamedTuple):
         The momentum decoherent rate for the MCLMC algorithm.
     step_size
         The step size used for the MCLMC algorithm.
+    sqrt_diag_cov
+        A matrix used for preconditioning.
     """
 
     L: float
     step_size: float
+    sqrt_diag_cov: float
 
 
 def mclmc_find_L_and_step_size(
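A note on the new field: although the docstring calls it a matrix, `sqrt_diag_cov` is stored as a vector, the diagonal of the square root of the parameter covariance. A minimal self-contained sketch, mirroring the initialization further down in this diff:

```python
import jax.numpy as jnp
from typing import NamedTuple

class MCLMCAdaptationState(NamedTuple):
    L: float
    step_size: float
    sqrt_diag_cov: float

dim = 10
# identity preconditioner to start with, as in mclmc_find_L_and_step_size below
params = MCLMCAdaptationState(
    L=jnp.sqrt(dim), step_size=jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,))
)
```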
@@ -47,6 +50,7 @@ def mclmc_find_L_and_step_size(
     desired_energy_var=5e-4,
     trust_in_estimate=1.5,
     num_effective_samples=150,
+    diagonal_preconditioning=True,
 ):
     """
     Finds the optimal value of the parameters for the MCLMC algorithm.
@@ -78,38 +82,30 @@
     -------
     A tuple containing the final state of the MCMC algorithm and the final hyperparameters.
 
-    Examples
+    Example
     -------
 
     .. code::
 
-        # Define the kernel function
-        def kernel(x):
-            return x ** 2
-
-        # Define the initial state
-        initial_state = MCMCState(position=0, momentum=1)
-
-        # Generate a random number generator key
-        rng_key = jax.random.key(0)
-
-        # Find the optimal parameters for the MCLMC algorithm
-        final_state, final_params = mclmc_find_L_and_step_size(
+        kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel(
+            logdensity_fn=logdensity_fn,
+            integrator=integrator,
+            std_mat=std_mat,
+        )
+
+        (
+            blackjax_state_after_tuning,
+            blackjax_mclmc_sampler_params,
+        ) = blackjax.mclmc_find_L_and_step_size(
             mclmc_kernel=kernel,
-            num_steps=1000,
+            num_steps=num_steps,
             state=initial_state,
-            rng_key=rng_key,
-            frac_tune1=0.2,
-            frac_tune2=0.3,
-            frac_tune3=0.1,
-            desired_energy_var=1e-4,
-            trust_in_estimate=2.0,
-            num_effective_samples=200,
+            rng_key=tune_key,
+            diagonal_preconditioning=preconditioning,
         )
     """
     dim = pytree_size(state.position)
-    params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25)
+    params = MCLMCAdaptationState(
+        jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,))
+    )
     part1_key, part2_key = jax.random.split(rng_key, 2)
 
     state, params = make_L_step_size_adaptation(
@@ -120,12 +116,13 @@ def kernel(x):
         desired_energy_var=desired_energy_var,
         trust_in_estimate=trust_in_estimate,
         num_effective_samples=num_effective_samples,
+        diagonal_preconditioning=diagonal_preconditioning,
     )(state, params, num_steps, part1_key)
 
     if frac_tune3 != 0:
-        state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)(
-            state, params, num_steps, part2_key
-        )
+        state, params = make_adaptation_L(
+            mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4
+        )(state, params, num_steps, part2_key)
 
     return state, params
@@ -135,6 +132,7 @@ def make_L_step_size_adaptation(
     dim,
     frac_tune1,
     frac_tune2,
+    diagonal_preconditioning,
     desired_energy_var=1e-3,
     trust_in_estimate=1.5,
     num_effective_samples=150,
@@ -150,7 +148,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
         time, x_average, step_size_max = adaptive_state
 
         # dynamics
-        next_state, info = kernel(
+        next_state, info = kernel(params.sqrt_diag_cov)(
             rng_key=rng_key,
             state=previous_state,
             L=params.L,
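This line is why the tuner now takes a kernel factory: `kernel` maps a diagonal preconditioner to a concrete transition kernel, so the predictor can re-bind the current `params.sqrt_diag_cov` on every step. A toy illustration of the pattern (the names and the stand-in dynamics are hypothetical, not the real MCLMC update):

```python
import jax
import jax.numpy as jnp

def mclmc_kernel(sqrt_diag_cov):
    """Factory: bind a diagonal preconditioner, return a kernel."""
    def kernel(rng_key, state, L, step_size):
        # stand-in for the real dynamics; shows where the
        # preconditioner rescales the update
        return state + step_size * sqrt_diag_cov, None
    return kernel

bound = mclmc_kernel(jnp.ones(3))  # identity preconditioning
new_state, info = bound(jax.random.key(0), jnp.zeros(3), L=1.0, step_size=0.1)
```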
@@ -185,68 +183,84 @@ def predictor(previous_state, params, adaptive_state, rng_key):
         ) * step_size_max  # if the proposed stepsize is above the stepsize where we have seen divergences
         params_new = params._replace(step_size=step_size)
 
-        return state, params_new, params_new, (time, x_average, step_size_max), success
-
-    def update_kalman(x, state, outer_weight, success, step_size):
-        """kalman filter to estimate the size of the posterior"""
-        time, x_average, x_squared_average = state
-        weight = outer_weight * step_size * success
-        zero_prevention = 1 - outer_weight
-        x_average = (time * x_average + weight * x) / (
-            time + weight + zero_prevention
-        )  # Update <f(x)> with a Kalman filter
-        x_squared_average = (time * x_squared_average + weight * jnp.square(x)) / (
-            time + weight + zero_prevention
-        )  # Update <f(x)> with a Kalman filter
-        time += weight
-        return (time, x_average, x_squared_average)
+        adaptive_state = (time, x_average, step_size_max)
 
-    adap0 = (0.0, 0.0, jnp.inf)
+        return state, params_new, adaptive_state, success
 
     def step(iteration_state, weight_and_key):
         """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
 
-        outer_weight, rng_key = weight_and_key
-        state, params, adaptive_state, kalman_state = iteration_state
-        state, params, params_final, adaptive_state, success = predictor(
+        mask, rng_key = weight_and_key
+        state, params, adaptive_state, streaming_avg = iteration_state
+
+        state, params, adaptive_state, success = predictor(
             state, params, adaptive_state, rng_key
         )
-        position, _ = ravel_pytree(state.position)
-        kalman_state = update_kalman(
-            position, kalman_state, outer_weight, success, params.step_size
+
+        x = ravel_pytree(state.position)[0]
+        # update the running average of x, x^2
+        streaming_avg = streaming_average_update(
+            expectation=jnp.array([x, jnp.square(x)]),
+            streaming_avg=streaming_avg,
+            weight=(1 - mask) * success * params.step_size,
+            zero_prevention=mask,
         )
 
-        return (state, params_final, adaptive_state, kalman_state), None
+        return (state, params, adaptive_state, streaming_avg), None
+
+    run_steps = lambda xs, state, params: jax.lax.scan(
+        step,
+        init=(
+            state,
+            params,
+            (0.0, 0.0, jnp.inf),
+            (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
+        ),
+        xs=xs,
+    )[0]
 
     def L_step_size_adaptation(state, params, num_steps, rng_key):
-        num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
-            num_steps * frac_tune2
+        num_steps1, num_steps2 = (
+            int(num_steps * frac_tune1) + 1,
+            int(num_steps * frac_tune2) + 1,
         )
-        L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2)
+        L_step_size_adaptation_keys = jax.random.split(
+            rng_key, num_steps1 + num_steps2 + 1
+        )
+        L_step_size_adaptation_keys, final_key = (
+            L_step_size_adaptation_keys[:-1],
+            L_step_size_adaptation_keys[-1],
+        )
 
         # we use the last num_steps2 to compute the diagonal preconditioner
-        outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
-
-        # initial state of the kalman filter
-        kalman_state = (0.0, jnp.zeros(dim), jnp.zeros(dim))
+        mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
 
         # run the steps
-        kalman_state, *_ = jax.lax.scan(
-            step,
-            init=(state, params, adap0, kalman_state),
-            xs=(outer_weights, L_step_size_adaptation_keys),
-            length=num_steps1 + num_steps2,
+        state, params, _, (_, average) = run_steps(
+            xs=(mask, L_step_size_adaptation_keys), state=state, params=params
         )
-        state, params, _, kalman_state_output = kalman_state
 
         L = params.L
         # determine L
+        sqrt_diag_cov = params.sqrt_diag_cov
         if num_steps2 != 0.0:
-            _, F1, F2 = kalman_state_output
-            variances = F2 - jnp.square(F1)
+            x_average, x_squared_average = average[0], average[1]
+            variances = x_squared_average - jnp.square(x_average)
             L = jnp.sqrt(jnp.sum(variances))
-
-        return state, MCLMCAdaptationState(L, params.step_size)
+            if diagonal_preconditioning:
+                sqrt_diag_cov = jnp.sqrt(variances)
+                params = params._replace(sqrt_diag_cov=sqrt_diag_cov)
+                L = jnp.sqrt(dim)
+
+                # readjust the stepsize
+                steps = num_steps2 // 3  # we do some small number of steps
+                keys = jax.random.split(final_key, steps)
+                state, params, _, (_, average) = run_steps(
+                    xs=(jnp.ones(steps), keys), state=state, params=params
+                )
+
+        return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov)
 
     return L_step_size_adaptation
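The `streaming_average_update` imported at the top replaces the deleted `update_kalman` helper. Below is a sketch of the update it performs, written to match the deleted helper's arithmetic and the call sites in this diff; the real function lives in `blackjax.util`, so treat this exact signature as an assumption.

```python
import jax.numpy as jnp

def streaming_average_update(expectation, streaming_avg, weight=1.0, zero_prevention=0.0):
    """Weighted running mean: streaming_avg is (total_weight, mean)."""
    total, average = streaming_avg
    # masked steps pass weight=0; zero_prevention keeps the denominator
    # nonzero while the accumulated total is still 0
    average = (total * average + weight * expectation) / (
        total + weight + zero_prevention
    )
    total += weight
    return total, average
```

Stacking `x` and `jnp.square(x)` into one `expectation` array is what lets `L_step_size_adaptation` recover `variances = x_squared_average - jnp.square(x_average)` from a single scan.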

@@ -258,7 +272,6 @@ def adaptation_L(state, params, num_steps, key):
         num_steps = int(num_steps * frac)
         adaptation_L_keys = jax.random.split(key, num_steps)
 
-        # run kernel in the normal way
         def step(state, key):
             next_state, _ = kernel(
                 rng_key=key,
@@ -297,5 +310,4 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch
         (next_state, step_size_max, kinetic_change),
         (previous_state, step_size * reduced_step_size, 0.0),
     )
-
     return nonans, state, step_size, kinetic_change
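The visible tail of `handle_nans` is the divergence-rejection idiom used by the tuner: if a proposal produced non-finite values, keep the previous state and shrink the step size. A self-contained sketch of the idea; the finiteness check on the flattened proposal is an assumption about the unchanged head of the function.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def handle_nans_sketch(previous_state, next_state, step_size, step_size_max,
                       kinetic_change, reduced_step_size=0.8):
    # reject the step if anything in the proposal went non-finite
    flat_proposal, _ = ravel_pytree(next_state)
    nonans = jnp.all(jnp.isfinite(flat_proposal))
    # select the new values on success, the fallbacks on failure
    state, step_size, kinetic_change = jax.tree_util.tree_map(
        lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
        (next_state, step_size_max, kinetic_change),
        (previous_state, step_size * reduced_step_size, 0.0),
    )
    return nonans, state, step_size, kinetic_change
```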
(Diff of the remaining 5 changed files not shown.)