From 9c2fea78f43d81d1b4cf594dbe348899245fffbc Mon Sep 17 00:00:00 2001 From: = Date: Fri, 24 May 2024 18:05:03 +0200 Subject: [PATCH] RENAME STD_MAT --- blackjax/adaptation/mclmc_adaptation.py | 18 +++++++++--------- blackjax/mcmc/integrators.py | 10 +++++----- blackjax/mcmc/mclmc.py | 8 ++++---- tests/mcmc/test_integrators.py | 4 ++-- tests/mcmc/test_sampling.py | 18 +++++++++--------- 5 files changed, 29 insertions(+), 29 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 73fa6a327..b1b012c70 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple): The momentum decoherent rate for the MCLMC algorithm. step_size The step size used for the MCLMC algorithm. - sqrt_diag_cov_mat + sqrt_diag_cov A matrix used for preconditioning. """ L: float step_size: float - sqrt_diag_cov_mat: float + sqrt_diag_cov: float def mclmc_find_L_and_step_size( @@ -104,7 +104,7 @@ def mclmc_find_L_and_step_size( """ dim = pytree_size(state.position) params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov_mat=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -121,7 +121,7 @@ def mclmc_find_L_and_step_size( if frac_tune3 != 0: state, params = make_adaptation_L( - mclmc_kernel(params.sqrt_diag_cov_mat), frac=frac_tune3, Lfactor=0.4 + mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) return state, params @@ -148,7 +148,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): time, x_average, step_size_max = adaptive_state # dynamics - next_state, info = kernel(params.sqrt_diag_cov_mat)( + next_state, info = kernel(params.sqrt_diag_cov)( rng_key=rng_key, state=previous_state, L=params.L, @@ -242,15 +242,15 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = params.L # determine L - sqrt_diag_cov_mat = params.sqrt_diag_cov_mat + sqrt_diag_cov = params.sqrt_diag_cov if num_steps2 != 0.0: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) if diagonal_preconditioning: - sqrt_diag_cov_mat = jnp.sqrt(variances) - params = params._replace(sqrt_diag_cov_mat=sqrt_diag_cov_mat) + sqrt_diag_cov = jnp.sqrt(variances) + params = params._replace(sqrt_diag_cov=sqrt_diag_cov) L = jnp.sqrt(dim) # readjust the stepsize @@ -260,7 +260,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): xs=(jnp.ones(steps), keys), state=state, params=params ) - return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov_mat) + return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov) return L_step_size_adaptation diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 2dce5671e..1cc698e8f 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -294,7 +294,7 @@ def _normalized_flatten_array(x, tol=1e-13): return jnp.where(norm > tol, x / norm, x), norm -def esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0): +def esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0): def update( momentum: ArrayTree, logdensity_grad: ArrayTree, @@ -313,7 +313,7 @@ def update( logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) - flatten_grads = flatten_grads * sqrt_diag_cov_mat + flatten_grads = flatten_grads * sqrt_diag_cov flatten_momentum, _ = ravel_pytree(momentum) dims = flatten_momentum.shape[0] normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) @@ -325,7 +325,7 @@ def update( + 2 * zeta * flatten_momentum ) new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov_mat) + gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta @@ -359,10 +359,10 @@ def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( logdensity_fn: Callable, *args, **kwargs ) -> GeneralIntegrator: - sqrt_diag_cov_mat = kwargs.get("sqrt_diag_cov_mat", 1.0) + sqrt_diag_cov = kwargs.get("sqrt_diag_cov", 1.0) position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat), + esh_dynamics_momentum_update_one_step(sqrt_diag_cov), position_update_fn, coefficients, format_output_fn=format_isokinetic_state_output, diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index d841f64e3..27b5c2e9c 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator): +def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """Build a HMC kernel. Parameters @@ -80,7 +80,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator): """ - step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov_mat)) + step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov)) def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float @@ -105,7 +105,7 @@ def as_top_level_api( L, step_size, integrator=isokinetic_mclachlan, - sqrt_diag_cov_mat=1.0, + sqrt_diag_cov=1.0, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -153,7 +153,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ - kernel = build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator) + kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index 3439f52e6..937339aaf 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -235,7 +235,7 @@ def test_esh_momentum_update(self, dims): # Efficient implementation update_stable = self.variant( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0) + esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) ) next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @@ -260,7 +260,7 @@ def test_isokinetic_leapfrog(self): next_state, kinetic_energy_change = step(initial_state, step_size) # explicit integration - op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0) + op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) op2 = integrators.euclidean_position_update_fn(logdensity_fn) position, momentum, _, logdensity_grad = initial_state momentum, kinetic_grad, kinetic_energy_change0 = op1( diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index fb272ae7a..63eada8ac 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -111,10 +111,10 @@ def run_mclmc( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan, - sqrt_diag_cov_mat=sqrt_diag_cov_mat, + sqrt_diag_cov=sqrt_diag_cov, ) ( @@ -132,7 +132,7 @@ def run_mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, - sqrt_diag_cov_mat=blackjax_mclmc_sampler_params.sqrt_diag_cov_mat, + sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, ) _, samples, _ = run_inference_algorithm( @@ -300,7 +300,7 @@ def __init__(self, d, condition_number): integrator = isokinetic_mclachlan - def get_sqrt_diag_cov_mat(): + def get_sqrt_diag_cov(): init_key, tune_key = jax.random.split(key) initial_position = model.sample_init(init_key) @@ -311,10 +311,10 @@ def get_sqrt_diag_cov_mat(): rng_key=init_key, ) - kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=model.logdensity_fn, integrator=integrator, - sqrt_diag_cov_mat=sqrt_diag_cov_mat, + sqrt_diag_cov=sqrt_diag_cov, ) ( @@ -328,13 +328,13 @@ def get_sqrt_diag_cov_mat(): diagonal_preconditioning=True, ) - return blackjax_mclmc_sampler_params.sqrt_diag_cov_mat + return blackjax_mclmc_sampler_params.sqrt_diag_cov - sqrt_diag_cov_mat = get_sqrt_diag_cov_mat() + sqrt_diag_cov = get_sqrt_diag_cov() assert ( jnp.abs( jnp.dot( - (sqrt_diag_cov_mat**2) / jnp.linalg.norm(sqrt_diag_cov_mat**2), + (sqrt_diag_cov**2) / jnp.linalg.norm(sqrt_diag_cov**2), eigs / jnp.linalg.norm(eigs), ) - 1