From 072cc81a67154c3bc4601b75ecec66da25da0899 Mon Sep 17 00:00:00 2001 From: Reuben Date: Sat, 24 Aug 2024 19:21:05 -0400 Subject: [PATCH] Bug fix (#724) * bug fix; first part * bug fix; first part * further debug * remove print statements --- blackjax/adaptation/mclmc_adaptation.py | 7 ++++--- blackjax/mcmc/integrators.py | 1 + blackjax/util.py | 9 +++++++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 7645a890b..3365526b3 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -154,6 +154,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): L=params.L, step_size=params.step_size, ) + # step updating success, state, step_size_max, energy_change = handle_nans( previous_state, @@ -203,7 +204,6 @@ def step(iteration_state, weight_and_key): expectation=jnp.array([x, jnp.square(x)]), incremental_val=streaming_avg, weight=(1 - mask) * success * params.step_size, - zero_prevention=mask, ) return (state, params, adaptive_state, streaming_avg), None @@ -243,7 +243,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = params.L # determine L sqrt_diag_cov = params.sqrt_diag_cov - if num_steps2 != 0.0: + if num_steps2 > 1: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) @@ -304,7 +304,8 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch reduced_step_size = 0.8 p, unravel_fn = ravel_pytree(next_state.position) - nonans = jnp.all(jnp.isfinite(p)) + q, unravel_fn = ravel_pytree(next_state.momentum) + nonans = jnp.logical_and(jnp.all(jnp.isfinite(p)), jnp.all(jnp.isfinite(q))) state, step_size, kinetic_change = jax.tree_util.tree_map( lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), (next_state, step_size_max, kinetic_change), diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 1d4b95a09..e9d19e3dc 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -435,6 +435,7 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): ) # one step of the deterministic dynamics state, info = integrator(state, step_size) + # partial refreshment state = state._replace( momentum=partially_refresh_momentum( diff --git a/blackjax/util.py b/blackjax/util.py index b6c5367b5..8cdcd45ee 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -281,6 +281,10 @@ def transform(state_and_incremental_val, info): return SamplingAlgorithm(init_fn, update_fn), transform +def safediv(x, y): + return jnp.where(x == 0.0, 0.0, x / y) + + def incremental_value_update( expectation, incremental_val, weight=1.0, zero_prevention=0.0 ): @@ -302,8 +306,9 @@ def incremental_value_update( total, average = incremental_val average = tree_map( - lambda exp, av: (total * av + weight * exp) - / (total + weight + zero_prevention), + lambda exp, av: safediv( + total * av + weight * exp, (total + weight + zero_prevention) + ), expectation, average, )