Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NaN Handling #727

Merged
merged 6 commits into from
Aug 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size
from blackjax.util import incremental_value_update, pytree_size
from blackjax.util import generate_unit_vector, incremental_value_update, pytree_size


class MCLMCAdaptationState(NamedTuple):
Expand Down Expand Up @@ -147,6 +147,8 @@ def predictor(previous_state, params, adaptive_state, rng_key):

time, x_average, step_size_max = adaptive_state

rng_key, nan_key = jax.random.split(rng_key)

# dynamics
next_state, info = kernel(params.sqrt_diag_cov)(
rng_key=rng_key,
Expand All @@ -162,6 +164,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
params.step_size,
step_size_max,
info.energy_change,
nan_key,
)

# Warning: var = 0 if there were nans, but we will give it a very small weight
Expand Down Expand Up @@ -203,7 +206,7 @@ def step(iteration_state, weight_and_key):
streaming_avg = incremental_value_update(
expectation=jnp.array([x, jnp.square(x)]),
incremental_val=streaming_avg,
weight=(1 - mask) * success * params.step_size,
weight=mask * success * params.step_size,
)

return (state, params, adaptive_state, streaming_avg), None
Expand Down Expand Up @@ -233,7 +236,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
)

# we use the last num_steps2 to compute the diagonal preconditioner
mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
mask = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

# run the steps
state, params, _, (_, average) = run_steps(
Expand Down Expand Up @@ -298,7 +301,9 @@ def step(state, key):
return adaptation_L


def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change):
def handle_nans(
previous_state, next_state, step_size, step_size_max, kinetic_change, key
):
"""if there are nans, let's reduce the stepsize, and not update the state. The
function returns the old state in this case."""

Expand All @@ -311,4 +316,13 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch
(next_state, step_size_max, kinetic_change),
(previous_state, step_size * reduced_step_size, 0.0),
)

state = jax.lax.cond(
jnp.isnan(next_state.logdensity),
lambda: state._replace(
momentum=generate_unit_vector(key, previous_state.position)
),
lambda: state,
)

return nonans, state, step_size, kinetic_change
Loading