Skip to content

Commit

Permalink
Bug fix (#724)
Browse files Browse the repository at this point in the history
* bug fix; first part

* bug fix; first part

* further debug

* remove print statements
  • Loading branch information
reubenharry committed Aug 24, 2024
1 parent 4a11236 commit 072cc81
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
7 changes: 4 additions & 3 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
L=params.L,
step_size=params.step_size,
)

# step updating
success, state, step_size_max, energy_change = handle_nans(
previous_state,
Expand Down Expand Up @@ -203,7 +204,6 @@ def step(iteration_state, weight_and_key):
expectation=jnp.array([x, jnp.square(x)]),
incremental_val=streaming_avg,
weight=(1 - mask) * success * params.step_size,
zero_prevention=mask,
)

return (state, params, adaptive_state, streaming_avg), None
Expand Down Expand Up @@ -243,7 +243,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
L = params.L
# determine L
sqrt_diag_cov = params.sqrt_diag_cov
if num_steps2 != 0.0:
if num_steps2 > 1:
x_average, x_squared_average = average[0], average[1]
variances = x_squared_average - jnp.square(x_average)
L = jnp.sqrt(jnp.sum(variances))
Expand Down Expand Up @@ -304,7 +304,8 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch

reduced_step_size = 0.8
p, unravel_fn = ravel_pytree(next_state.position)
nonans = jnp.all(jnp.isfinite(p))
q, unravel_fn = ravel_pytree(next_state.momentum)
nonans = jnp.logical_and(jnp.all(jnp.isfinite(p)), jnp.all(jnp.isfinite(q)))
state, step_size, kinetic_change = jax.tree_util.tree_map(
lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
(next_state, step_size_max, kinetic_change),
Expand Down
1 change: 1 addition & 0 deletions blackjax/mcmc/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key):
)
# one step of the deterministic dynamics
state, info = integrator(state, step_size)

# partial refreshment
state = state._replace(
momentum=partially_refresh_momentum(
Expand Down
9 changes: 7 additions & 2 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ def transform(state_and_incremental_val, info):
return SamplingAlgorithm(init_fn, update_fn), transform


def safediv(x, y):
return jnp.where(x == 0.0, 0.0, x / y)


def incremental_value_update(
expectation, incremental_val, weight=1.0, zero_prevention=0.0
):
Expand All @@ -302,8 +306,9 @@ def incremental_value_update(

total, average = incremental_val
average = tree_map(
lambda exp, av: (total * av + weight * exp)
/ (total + weight + zero_prevention),
lambda exp, av: safediv(
total * av + weight * exp, (total + weight + zero_prevention)
),
expectation,
average,
)
Expand Down

0 comments on commit 072cc81

Please sign in to comment.