Skip to content

Commit

Permalink
RENAME STD_MAT
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed May 24, 2024
1 parent 6bacb6c commit 9c2fea7
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 29 deletions.
18 changes: 9 additions & 9 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple):
The momentum decoherent rate for the MCLMC algorithm.
step_size
The step size used for the MCLMC algorithm.
sqrt_diag_cov_mat
sqrt_diag_cov
A matrix used for preconditioning.
"""

L: float
step_size: float
sqrt_diag_cov_mat: float
sqrt_diag_cov: float


def mclmc_find_L_and_step_size(
Expand Down Expand Up @@ -104,7 +104,7 @@ def mclmc_find_L_and_step_size(
"""
dim = pytree_size(state.position)
params = MCLMCAdaptationState(
jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov_mat=jnp.ones((dim,))
jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,))
)
part1_key, part2_key = jax.random.split(rng_key, 2)

Expand All @@ -121,7 +121,7 @@ def mclmc_find_L_and_step_size(

if frac_tune3 != 0:
state, params = make_adaptation_L(
mclmc_kernel(params.sqrt_diag_cov_mat), frac=frac_tune3, Lfactor=0.4
mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4
)(state, params, num_steps, part2_key)

return state, params
Expand All @@ -148,7 +148,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
time, x_average, step_size_max = adaptive_state

# dynamics
next_state, info = kernel(params.sqrt_diag_cov_mat)(
next_state, info = kernel(params.sqrt_diag_cov)(
rng_key=rng_key,
state=previous_state,
L=params.L,
Expand Down Expand Up @@ -242,15 +242,15 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):

L = params.L
# determine L
sqrt_diag_cov_mat = params.sqrt_diag_cov_mat
sqrt_diag_cov = params.sqrt_diag_cov
if num_steps2 != 0.0:
x_average, x_squared_average = average[0], average[1]
variances = x_squared_average - jnp.square(x_average)
L = jnp.sqrt(jnp.sum(variances))

if diagonal_preconditioning:
sqrt_diag_cov_mat = jnp.sqrt(variances)
params = params._replace(sqrt_diag_cov_mat=sqrt_diag_cov_mat)
sqrt_diag_cov = jnp.sqrt(variances)
params = params._replace(sqrt_diag_cov=sqrt_diag_cov)
L = jnp.sqrt(dim)

# readjust the stepsize
Expand All @@ -260,7 +260,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
xs=(jnp.ones(steps), keys), state=state, params=params
)

return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov_mat)
return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov)

return L_step_size_adaptation

Expand Down
10 changes: 5 additions & 5 deletions blackjax/mcmc/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def _normalized_flatten_array(x, tol=1e-13):
return jnp.where(norm > tol, x / norm, x), norm


def esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0):
def esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0):
def update(
momentum: ArrayTree,
logdensity_grad: ArrayTree,
Expand All @@ -313,7 +313,7 @@ def update(

logdensity_grad = logdensity_grad
flatten_grads, unravel_fn = ravel_pytree(logdensity_grad)
flatten_grads = flatten_grads * sqrt_diag_cov_mat
flatten_grads = flatten_grads * sqrt_diag_cov
flatten_momentum, _ = ravel_pytree(momentum)
dims = flatten_momentum.shape[0]
normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads)
Expand All @@ -325,7 +325,7 @@ def update(
+ 2 * zeta * flatten_momentum
)
new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw)
gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov_mat)
gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov)
next_momentum = unravel_fn(new_momentum_normalized)
kinetic_energy_change = (
delta
Expand Down Expand Up @@ -359,10 +359,10 @@ def generate_isokinetic_integrator(coefficients):
def isokinetic_integrator(
logdensity_fn: Callable, *args, **kwargs
) -> GeneralIntegrator:
sqrt_diag_cov_mat = kwargs.get("sqrt_diag_cov_mat", 1.0)
sqrt_diag_cov = kwargs.get("sqrt_diag_cov", 1.0)
position_update_fn = euclidean_position_update_fn(logdensity_fn)
one_step = generalized_two_stage_integrator(
esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat),
esh_dynamics_momentum_update_one_step(sqrt_diag_cov),
position_update_fn,
coefficients,
format_output_fn=format_isokinetic_state_output,
Expand Down
8 changes: 4 additions & 4 deletions blackjax/mcmc/mclmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key):
)


def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator):
def build_kernel(logdensity_fn, sqrt_diag_cov, integrator):
"""Build a HMC kernel.
Parameters
Expand All @@ -80,7 +80,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator):
"""

step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov_mat))
step = with_isokinetic_maruyama(integrator(logdensity_fn, sqrt_diag_cov))

def kernel(
rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float
Expand All @@ -105,7 +105,7 @@ def as_top_level_api(
L,
step_size,
integrator=isokinetic_mclachlan,
sqrt_diag_cov_mat=1.0,
sqrt_diag_cov=1.0,
) -> SamplingAlgorithm:
"""The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be
cumbersome to manipulate. Since most users only need to specify the kernel
Expand Down Expand Up @@ -153,7 +153,7 @@ def as_top_level_api(
A ``SamplingAlgorithm``.
"""

kernel = build_kernel(logdensity_fn, sqrt_diag_cov_mat, integrator)
kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator)

def init_fn(position: ArrayLike, rng_key: PRNGKey):
return init(position, logdensity_fn, rng_key)
Expand Down
4 changes: 2 additions & 2 deletions tests/mcmc/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def test_esh_momentum_update(self, dims):

# Efficient implementation
update_stable = self.variant(
esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0)
esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0)
)
next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0)
np.testing.assert_array_almost_equal(next_momentum, next_momentum1)
Expand All @@ -260,7 +260,7 @@ def test_isokinetic_leapfrog(self):
next_state, kinetic_energy_change = step(initial_state, step_size)

# explicit integration
op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov_mat=1.0)
op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0)
op2 = integrators.euclidean_position_update_fn(logdensity_fn)
position, momentum, _, logdensity_grad = initial_state
momentum, kinetic_grad, kinetic_energy_change0 = op1(
Expand Down
18 changes: 9 additions & 9 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def run_mclmc(
position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
)

kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel(
kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan,
sqrt_diag_cov_mat=sqrt_diag_cov_mat,
sqrt_diag_cov=sqrt_diag_cov,
)

(
Expand All @@ -132,7 +132,7 @@ def run_mclmc(
logdensity_fn,
L=blackjax_mclmc_sampler_params.L,
step_size=blackjax_mclmc_sampler_params.step_size,
sqrt_diag_cov_mat=blackjax_mclmc_sampler_params.sqrt_diag_cov_mat,
sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov,
)

_, samples, _ = run_inference_algorithm(
Expand Down Expand Up @@ -300,7 +300,7 @@ def __init__(self, d, condition_number):

integrator = isokinetic_mclachlan

def get_sqrt_diag_cov_mat():
def get_sqrt_diag_cov():
init_key, tune_key = jax.random.split(key)

initial_position = model.sample_init(init_key)
Expand All @@ -311,10 +311,10 @@ def get_sqrt_diag_cov_mat():
rng_key=init_key,
)

kernel = lambda sqrt_diag_cov_mat: blackjax.mcmc.mclmc.build_kernel(
kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=model.logdensity_fn,
integrator=integrator,
sqrt_diag_cov_mat=sqrt_diag_cov_mat,
sqrt_diag_cov=sqrt_diag_cov,
)

(
Expand All @@ -328,13 +328,13 @@ def get_sqrt_diag_cov_mat():
diagonal_preconditioning=True,
)

return blackjax_mclmc_sampler_params.sqrt_diag_cov_mat
return blackjax_mclmc_sampler_params.sqrt_diag_cov

sqrt_diag_cov_mat = get_sqrt_diag_cov_mat()
sqrt_diag_cov = get_sqrt_diag_cov()
assert (
jnp.abs(
jnp.dot(
(sqrt_diag_cov_mat**2) / jnp.linalg.norm(sqrt_diag_cov_mat**2),
(sqrt_diag_cov**2) / jnp.linalg.norm(sqrt_diag_cov**2),
eigs / jnp.linalg.norm(eigs),
)
- 1
Expand Down

0 comments on commit 9c2fea7

Please sign in to comment.