Commit
Merge branch 'main' into ciguaran_split_tempered_from_mcmc_construction
ciguaran committed Aug 27, 2024
2 parents 0942087 + 8a9b546 commit a9a4d0c
Showing 5 changed files with 34 additions and 14 deletions.
29 changes: 22 additions & 7 deletions blackjax/adaptation/mclmc_adaptation.py
@@ -20,7 +20,7 @@
 from jax.flatten_util import ravel_pytree

 from blackjax.diagnostics import effective_sample_size
-from blackjax.util import incremental_value_update, pytree_size
+from blackjax.util import generate_unit_vector, incremental_value_update, pytree_size


 class MCLMCAdaptationState(NamedTuple):
@@ -147,20 +147,24 @@ def predictor(previous_state, params, adaptive_state, rng_key):

         time, x_average, step_size_max = adaptive_state

+        rng_key, nan_key = jax.random.split(rng_key)
+
         # dynamics
         next_state, info = kernel(params.sqrt_diag_cov)(
             rng_key=rng_key,
             state=previous_state,
             L=params.L,
             step_size=params.step_size,
         )

         # step updating
         success, state, step_size_max, energy_change = handle_nans(
             previous_state,
             next_state,
             params.step_size,
             step_size_max,
             info.energy_change,
+            nan_key,
         )

         # Warning: var = 0 if there were nans, but we will give it a very small weight
@@ -202,8 +206,7 @@ def step(iteration_state, weight_and_key):
         streaming_avg = incremental_value_update(
             expectation=jnp.array([x, jnp.square(x)]),
             incremental_val=streaming_avg,
-            weight=(1 - mask) * success * params.step_size,
-            zero_prevention=mask,
+            weight=mask * success * params.step_size,
         )

         return (state, params, adaptive_state, streaming_avg), None
@@ -233,7 +236,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
         )

         # we use the last num_steps2 steps to compute the diagonal preconditioner
-        mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
+        mask = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

         # run the steps
         state, params, _, (_, average) = run_steps(
@@ -243,7 +246,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
         L = params.L
         # determine L
         sqrt_diag_cov = params.sqrt_diag_cov
-        if num_steps2 != 0.0:
+        if num_steps2 > 1:
             x_average, x_squared_average = average[0], average[1]
             variances = x_squared_average - jnp.square(x_average)
             L = jnp.sqrt(jnp.sum(variances))
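
[Editorial aside, not part of the commit: the guarded block estimates per-coordinate variances from the streaming first and second moments and sets L to the root of the total variance. With fewer than two weighted samples, E[x^2] - E[x]^2 degenerates to zero, which is presumably why the guard tightened from `num_steps2 != 0.0` to `num_steps2 > 1`. A toy computation:]

import jax.numpy as jnp

# Streaming moments for a 2-D target (illustrative values).
x_average = jnp.array([0.0, 1.0])
x_squared_average = jnp.array([4.0, 2.0])

# Per-coordinate variances, then L as the root of the trace of the covariance.
variances = x_squared_average - jnp.square(x_average)  # [4.0, 1.0]
L = jnp.sqrt(jnp.sum(variances))
print(L)  # sqrt(5.0) ~= 2.236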
@@ -298,16 +301,28 @@ def step(state, key):
     return adaptation_L


-def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change):
+def handle_nans(
+    previous_state, next_state, step_size, step_size_max, kinetic_change, key
+):
     """if there are nans, let's reduce the stepsize, and not update the state. The
     function returns the old state in this case."""

     reduced_step_size = 0.8
     p, unravel_fn = ravel_pytree(next_state.position)
-    nonans = jnp.all(jnp.isfinite(p))
+    q, unravel_fn = ravel_pytree(next_state.momentum)
+    nonans = jnp.logical_and(jnp.all(jnp.isfinite(p)), jnp.all(jnp.isfinite(q)))
     state, step_size, kinetic_change = jax.tree_util.tree_map(
         lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
         (next_state, step_size_max, kinetic_change),
         (previous_state, step_size * reduced_step_size, 0.0),
     )

+    state = jax.lax.cond(
+        jnp.isnan(next_state.logdensity),
+        lambda: state._replace(
+            momentum=generate_unit_vector(key, previous_state.position)
+        ),
+        lambda: state,
+    )
+
     return nonans, state, step_size, kinetic_change
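
[A minimal sketch of the rollback pattern handle_nans implements (names below are illustrative, not part of blackjax): flatten the candidate values, check that every entry is finite, and select between new and previous values leaf by leaf.]

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def rollback_on_nans(new_vals, old_vals):
    # Keep new_vals only if every flattened entry is finite; otherwise fall back.
    flat, _ = ravel_pytree(new_vals)
    nonans = jnp.all(jnp.isfinite(flat))
    return nonans, jax.tree_util.tree_map(
        lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
        new_vals,
        old_vals,
    )

nonans, vals = rollback_on_nans(
    (jnp.array([1.0, jnp.nan]), jnp.asarray(0.5)),
    (jnp.array([0.0, 0.0]), jnp.asarray(0.4)),
)
print(nonans, vals)  # False, (Array([0., 0.]), Array(0.4))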
1 change: 1 addition & 0 deletions blackjax/mcmc/integrators.py
@@ -435,6 +435,7 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key):
         )
         # one step of the deterministic dynamics
         state, info = integrator(state, step_size)
+
         # partial refreshment
         state = state._replace(
             momentum=partially_refresh_momentum(
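
[For context, a hedged sketch of partial momentum refreshment, not the library's implementation: in MCLMC-style samplers the momentum lives on the unit sphere, so the refreshment perturbs it with scaled Gaussian noise and renormalizes. The helper name and noise_scale parameter here are assumptions.]

import jax
import jax.numpy as jnp

def partial_refresh(momentum, rng_key, noise_scale):
    # Perturb with Gaussian noise, then project back onto the unit sphere.
    z = noise_scale * jax.random.normal(rng_key, shape=momentum.shape)
    perturbed = momentum + z
    return perturbed / jnp.linalg.norm(perturbed)

key = jax.random.PRNGKey(0)
m = jnp.array([1.0, 0.0, 0.0])
print(partial_refresh(m, key, noise_scale=0.1))  # close to m, still unit norm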
6 changes: 2 additions & 4 deletions blackjax/smc/waste_free.py
@@ -52,10 +52,8 @@ def update(rng_key, position, step_parameters):
         # step particles is (num_resampled, num_mcmc_steps, dimension_of_variable);
         # we want to reshape it into (num_resampled * num_mcmc_steps, dimension_of_variable)
         def reshape_step_particles(x):
-            if len(x.shape) > 2:
-                return x.reshape((x.shape[0] * x.shape[1], -1))
-            else:
-                return x.flatten()
+            _num_resampled, num_mcmc_steps, *dimension_of_variable = x.shape
+            return x.reshape((_num_resampled * num_mcmc_steps, *dimension_of_variable))

         step_particles = jax.tree.map(reshape_step_particles, states.position)
         new_particles = jax.tree.map(
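
[The new reshape handles both ranks uniformly via star-unpacking: a vector-valued variable keeps its trailing dimensions, and for a scalar variable *rest is empty so the result is one-dimensional, matching what the removed flatten() branch produced. A quick check with illustrative shapes:]

import jax.numpy as jnp

x = jnp.zeros((10, 5, 4))                # (num_resampled, num_mcmc_steps, dim)
n, m, *rest = x.shape
print(x.reshape((n * m, *rest)).shape)   # (50, 4)

y = jnp.zeros((10, 5))                   # scalar variable: no trailing dims
n, m, *rest = y.shape
print(y.reshape((n * m, *rest)).shape)   # (50,)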
9 changes: 7 additions & 2 deletions blackjax/util.py
@@ -281,6 +281,10 @@ def transform(state_and_incremental_val, info):
     return SamplingAlgorithm(init_fn, update_fn), transform


+def safediv(x, y):
+    return jnp.where(x == 0.0, 0.0, x / y)
+
+
 def incremental_value_update(
     expectation, incremental_val, weight=1.0, zero_prevention=0.0
 ):
@@ -302,8 +306,9 @@ def incremental_value_update(

     total, average = incremental_val
     average = tree_map(
-        lambda exp, av: (total * av + weight * exp)
-        / (total + weight + zero_prevention),
+        lambda exp, av: safediv(
+            total * av + weight * exp, (total + weight + zero_prevention)
+        ),
         expectation,
         average,
     )
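
[A sketch of the streaming weighted average that safediv now protects (streaming_average below is an illustrative helper, not the library function): the update is new_avg = (total * avg + weight * x) / (total + weight), and when no weight has accumulated yet the numerator is zero, so safediv returns 0 instead of the nan that 0/0 would produce.]

import jax.numpy as jnp

def safediv(x, y):
    # 0 when the numerator is 0, avoiding 0/0 -> nan on fully masked steps.
    return jnp.where(x == 0.0, 0.0, x / y)

def streaming_average(total, average, x, weight):
    new_average = safediv(total * average + weight * x, total + weight)
    return total + weight, new_average

total, avg = 0.0, jnp.asarray(0.0)
for x, w in [(1.0, 0.0), (2.0, 1.0), (4.0, 1.0)]:  # first sample masked out
    total, avg = streaming_average(total, avg, jnp.asarray(x), w)
print(total, avg)  # 2.0 3.0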
3 changes: 2 additions & 1 deletion tests/smc/test_smc.py
@@ -10,7 +10,8 @@

 import blackjax
 import blackjax.smc.resampling as resampling
-from blackjax.smc.base import extend_params, init, step, update_and_take_last
+from blackjax.smc.base import extend_params, init, step
+from blackjax.smc.tempered import update_and_take_last
 from blackjax.smc.waste_free import update_waste_free
