diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py
index 27321321a..73fa6a327 100644
--- a/blackjax/adaptation/mclmc_adaptation.py
+++ b/blackjax/adaptation/mclmc_adaptation.py
@@ -20,7 +20,7 @@
 from jax.flatten_util import ravel_pytree
 
 from blackjax.diagnostics import effective_sample_size
-from blackjax.util import pytree_size, streaming_average
+from blackjax.util import pytree_size, streaming_average_update
 
 
 class MCLMCAdaptationState(NamedTuple):
@@ -30,13 +30,13 @@
         The momentum decoherent rate for the MCLMC algorithm.
     step_size
         The step size used for the MCLMC algorithm.
-    std_mat
-        A matrix used for preconditioning.
+    sqrt_diag_cov_mat
+        The square root of the diagonal of the covariance matrix, used for preconditioning.
     """
 
     L: float
     step_size: float
-    std_mat: float
+    sqrt_diag_cov_mat: float
 
 
 def mclmc_find_L_and_step_size(
@@ -81,10 +81,30 @@
     Returns
     -------
     A tuple containing the final state of the MCMC algorithm and the final hyperparameters.
+
+    Examples
+    --------
+    .. code::
+        kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel(
+            logdensity_fn=logdensity_fn,
+            integrator=integrator,
+            sqrt_diag_cov_mat=sqrt_diag_cov_mat,
+        )
+
+        (
+            blackjax_state_after_tuning,
+            blackjax_mclmc_sampler_params,
+        ) = blackjax.mclmc_find_L_and_step_size(
+            mclmc_kernel=kernel,
+            num_steps=num_steps,
+            state=initial_state,
+            rng_key=tune_key,
+            diagonal_preconditioning=preconditioning,
+        )
     """
 
     dim = pytree_size(state.position)
     params = MCLMCAdaptationState(
-        jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, std_mat=jnp.ones((dim,))
+        jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov_mat=jnp.ones((dim,))
     )
     part1_key, part2_key = jax.random.split(rng_key, 2)
@@ -101,7 +121,7 @@
 
     if frac_tune3 != 0:
         state, params = make_adaptation_L(
-            mclmc_kernel(params.std_mat), frac=frac_tune3, Lfactor=0.4
+            mclmc_kernel(params.sqrt_diag_cov_mat), frac=frac_tune3, Lfactor=0.4
         )(state, params, num_steps, part2_key)
 
     return state, params
@@ -128,7 +148,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
         time, x_average, step_size_max = adaptive_state
 
         # dynamics
-        next_state, info = kernel(params.std_mat)(
+        next_state, info = kernel(params.sqrt_diag_cov_mat)(
             rng_key=rng_key,
             state=previous_state,
             L=params.L,
@@ -179,7 +199,7 @@ def step(iteration_state, weight_and_key):
         x = ravel_pytree(state.position)[0]
 
         # update the running average of x, x^2
-        streaming_avg = streaming_average(
+        streaming_avg = streaming_average_update(
             expectation=jnp.array([x, jnp.square(x)]),
             streaming_avg=streaming_avg,
             weight=(1 - mask) * success * params.step_size,
@@ -188,6 +208,17 @@
 
         return (state, params, adaptive_state, streaming_avg), None
 
+    run_steps = lambda xs, state, params: jax.lax.scan(
+        step,
+        init=(
+            state,
+            params,
+            (0.0, 0.0, jnp.inf),
+            (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
+        ),
+        xs=xs,
+    )[0]
+
     def L_step_size_adaptation(state, params, num_steps, rng_key):
         num_steps1, num_steps2 = (
             int(num_steps * frac_tune1) + 1,
@@ -205,45 +236,31 @@
         mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
 
         # run the steps
-        state, params, _, (_, average) = jax.lax.scan(
-            step,
-            init=(
-                state,
-                params,
-                (0.0, 0.0, jnp.inf),
-                (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
-            ),
-            xs=(mask, L_step_size_adaptation_keys),
-        )[0]
+        state, params, _, (_, average) = run_steps(
+            xs=(mask, L_step_size_adaptation_keys), state=state, params=params
+        )
 
         L = params.L
         # determine L
-        std_mat = params.std_mat
+        sqrt_diag_cov_mat = params.sqrt_diag_cov_mat
         if num_steps2 != 0.0:
             x_average, x_squared_average = average[0], average[1]
             variances = x_squared_average - jnp.square(x_average)
             L = jnp.sqrt(jnp.sum(variances))
             if diagonal_preconditioning:
-                std_mat = jnp.sqrt(variances)
-                params = params._replace(std_mat=std_mat)
+                sqrt_diag_cov_mat = jnp.sqrt(variances)
+                params = params._replace(sqrt_diag_cov_mat=sqrt_diag_cov_mat)
                 L = jnp.sqrt(dim)
 
             # readjust the stepsize
             steps = num_steps2 // 3  # we do some small number of steps
             keys = jax.random.split(final_key, steps)
-            state, params, _, (_, average) = jax.lax.scan(
-                step,
-                init=(
-                    state,
-                    params,
-                    (0.0, 0.0, jnp.inf),
-                    (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
-                ),
-                xs=(jnp.ones(steps), keys),
-            )[0]
-
-        return state, MCLMCAdaptationState(L, params.step_size, std_mat)
+            state, params, _, (_, average) = run_steps(
+                xs=(jnp.ones(steps), keys), state=state, params=params
+            )
+
+        return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov_mat)
 
     return L_step_size_adaptation
diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py
index ed11fb1a0..2dce5671e 100644
--- a/blackjax/mcmc/integrators.py
+++ b/blackjax/mcmc/integrators.py
@@ -294,7 +294,7 @@ def _normalized_flatten_array(x, tol=1e-13):
     return jnp.where(norm > tol, x / norm, x), norm
 
 
-def esh_dynamics_momentum_update_one_step(std_mat):
+def esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0):
     def update(
         momentum: ArrayTree,
         logdensity_grad: ArrayTree,
@@ -313,7 +313,7 @@ def update(
         logdensity_grad = logdensity_grad
         flatten_grads, unravel_fn = ravel_pytree(logdensity_grad)
-        flatten_grads = flatten_grads * std_mat
+        flatten_grads = flatten_grads * sqrt_diag_cov_mat
         flatten_momentum, _ = ravel_pytree(momentum)
         dims = flatten_momentum.shape[0]
         normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads)
@@ -325,7 +325,7 @@
             + 2 * zeta * flatten_momentum
         )
         new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw)
-        gr = unravel_fn(new_momentum_normalized * std_mat)
+        gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov_mat)
         next_momentum = unravel_fn(new_momentum_normalized)
         kinetic_energy_change = (
             delta
@@ -357,11 +357,11 @@ def format_isokinetic_state_output(
 
 def generate_isokinetic_integrator(coefficients):
     def isokinetic_integrator(
-        logdensity_fn: Callable, std_mat: ArrayTree = 1.0, *args, **kwargs
+        logdensity_fn: Callable, sqrt_diag_cov_mat: ArrayTree = 1.0, *args, **kwargs
     ) -> GeneralIntegrator:
         position_update_fn = euclidean_position_update_fn(logdensity_fn)
         one_step = generalized_two_stage_integrator(
-            esh_dynamics_momentum_update_one_step(std_mat),
+            esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat),
             position_update_fn,
             coefficients,
             format_output_fn=format_isokinetic_state_output,
diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py
index 62a6da735..d841f64e3 100644
--- a/blackjax/mcmc/mclmc.py
+++ b/blackjax/mcmc/mclmc.py
@@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key):
     )
 
 
-def build_kernel(logdensity_fn, std_mat, integrator):
+def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator):
     """Build a HMC kernel.
 
     Parameters
@@ -80,7 +80,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator):
 
     """
 
-    step = with_isokinetic_maruyama(integrator(logdensity_fn, std_mat))
+    step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov_mat))
 
     def kernel(
         rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float
@@ -105,7 +105,7 @@ def as_top_level_api(
     L,
     step_size,
     integrator=isokinetic_mclachlan,
-    std_mat=1.0,
+    sqrt_diag_cov_mat=1.0,
 ) -> SamplingAlgorithm:
     """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias
     `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel
@@ -153,7 +153,7 @@
     A ``SamplingAlgorithm``.
     """
 
-    kernel = build_kernel(logdensity_fn, std_mat, integrator)
+    kernel = build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator)
 
     def init_fn(position: ArrayLike, rng_key: PRNGKey):
         return init(position, logdensity_fn, rng_key)
diff --git a/blackjax/util.py b/blackjax/util.py
index 02c27e51c..71d7345fb 100644
--- a/blackjax/util.py
+++ b/blackjax/util.py
@@ -209,7 +209,7 @@ def one_step(average_and_state, xs, return_state):
         _, rng_key = xs
         average, state = average_and_state
         state, info = inference_algorithm.step(rng_key, state)
-        average = streaming_average(expectation(transform(state)), average)
+        average = streaming_average_update(expectation(transform(state)), average)
         if return_state:
             return (average, state), (transform(state), info)
         else:
@@ -232,7 +232,9 @@
 
     return transform(final_state), state_history, info_history
 
 
-def streaming_average(expectation, streaming_avg, weight=1.0, zero_prevention=0.0):
+def streaming_average_update(
+    expectation, streaming_avg, weight=1.0, zero_prevention=0.0
+):
     """Compute the streaming average of a function O(x) using a weight.
     Parameters:
     ----------
diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py
index 68c12c499..3439f52e6 100644
--- a/tests/mcmc/test_integrators.py
+++ b/tests/mcmc/test_integrators.py
@@ -234,7 +234,9 @@ def test_esh_momentum_update(self, dims):
         ) / (jnp.cosh(delta) + jnp.dot(gradient_normalized, momentum * jnp.sinh(delta)))
 
         # Efficient implementation
-        update_stable = self.variant(esh_dynamics_momentum_update_one_step(std_mat=1.0))
+        update_stable = self.variant(
+            esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0)
+        )
         next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0)
         np.testing.assert_array_almost_equal(next_momentum, next_momentum1)
 
@@ -258,7 +260,7 @@ def test_isokinetic_leapfrog(self):
         next_state, kinetic_energy_change = step(initial_state, step_size)
 
         # explicit integration
-        op1 = esh_dynamics_momentum_update_one_step(std_mat=1.0)
+        op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0)
         op2 = integrators.euclidean_position_update_fn(logdensity_fn)
         position, momentum, _, logdensity_grad = initial_state
         momentum, kinetic_grad, kinetic_energy_change0 = op1(
diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index 604316e48..fb272ae7a 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -111,10 +111,10 @@ def run_mclmc(
         position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
     )
 
-    kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel(
+    kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel(
         logdensity_fn=logdensity_fn,
         integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan,
-        std_mat=std_mat,
+        sqrt_diag_cov_mat=sqrt_diag_cov_mat,
     )
 
     (
@@ -132,7 +132,7 @@ def run_mclmc(
         logdensity_fn,
         L=blackjax_mclmc_sampler_params.L,
         step_size=blackjax_mclmc_sampler_params.step_size,
-        std_mat=blackjax_mclmc_sampler_params.std_mat,
+        sqrt_diag_cov_mat=blackjax_mclmc_sampler_params.sqrt_diag_cov_mat,
     )
 
     _, samples, _ = run_inference_algorithm(
@@ -300,7 +300,7 @@ def __init__(self, d, condition_number):
 
         integrator = isokinetic_mclachlan
 
-        def get_std_mat():
+        def get_sqrt_diag_cov_mat():
             init_key, tune_key = jax.random.split(key)
 
             initial_position = model.sample_init(init_key)
@@ -311,10 +311,10 @@ def get_std_mat():
                 rng_key=init_key,
             )
 
-            kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel(
+            kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel(
                 logdensity_fn=model.logdensity_fn,
                 integrator=integrator,
-                std_mat=std_mat,
+                sqrt_diag_cov_mat=sqrt_diag_cov_mat,
             )
 
             (
@@ -328,13 +328,13 @@
                 diagonal_preconditioning=True,
             )
 
-            return blackjax_mclmc_sampler_params.std_mat
+            return blackjax_mclmc_sampler_params.sqrt_diag_cov_mat
 
-        std_mat = get_std_mat()
+        sqrt_diag_cov_mat = get_sqrt_diag_cov_mat()
         assert (
             jnp.abs(
                 jnp.dot(
-                    (std_mat**2) / jnp.linalg.norm(std_mat**2),
+                    (sqrt_diag_cov_mat**2) / jnp.linalg.norm(sqrt_diag_cov_mat**2),
                     eigs / jnp.linalg.norm(eigs),
                 )
                 - 1
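
Usage sketch (reviewer note, not part of the diff): the snippet below pieces the renamed API together, mirroring the updated test in tests/mcmc/test_sampling.py. The toy log-density, the 10-dimensional start position, and the step count are illustrative assumptions only.

    import jax
    import jax.numpy as jnp

    import blackjax

    # Illustrative target: a standard Gaussian in 10 dimensions.
    logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))

    init_key, tune_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3)

    initial_state = blackjax.mcmc.mclmc.init(
        position=jnp.ones((10,)), logdensity_fn=logdensity_fn, rng_key=init_key
    )

    # The tuner takes a factory that builds a kernel from a given preconditioner.
    kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel(
        logdensity_fn=logdensity_fn,
        integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan,
        sqrt_diag_cov_mat=sqrt_diag_cov_mat,
    )

    # Adapt L, step_size, and sqrt_diag_cov_mat, then sample with the tuned values.
    state, params = blackjax.mclmc_find_L_and_step_size(
        mclmc_kernel=kernel,
        num_steps=1000,
        state=initial_state,
        rng_key=tune_key,
        diagonal_preconditioning=True,
    )

    sampling_alg = blackjax.mclmc(
        logdensity_fn,
        L=params.L,
        step_size=params.step_size,
        sqrt_diag_cov_mat=params.sqrt_diag_cov_mat,
    )
    new_state, info = sampling_alg.step(run_key, state)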