Skip to content

Commit

Permalink
Merge branch 'new_integrator' into adjusted_mclmc
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed May 19, 2024
2 parents 0a11a0f + 0ff1d24 commit 0ab0694
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
4 changes: 2 additions & 2 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,10 @@ def step(iteration_state, weight_and_key):
state, params, adaptive_state, rng_key
)

x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
streaming_avg = streaming_average(
O=lambda x: jnp.array([x, jnp.square(x)]),
x=ravel_pytree(state.position)[0],
expectation=jnp.array([x, jnp.square(x)]),
streaming_avg=streaming_avg,
weight=(1 - mask) * success * params.step_size,
zero_prevention=mask,
Expand Down
5 changes: 2 additions & 3 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def one_step(average_and_state, xs, return_state):
_, rng_key = xs
average, state = average_and_state
state, info = inference_algorithm.step(rng_key, state)
average = streaming_average(expectation, state, average)
average = streaming_average(expectation(state), average)
if return_state:
return (average, state), (transform(state), info)
else:
Expand All @@ -232,7 +232,7 @@ def one_step(average_and_state, xs, return_state):
return transform(final_state), state_history, info_history


def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0):
def streaming_average(expectation, streaming_avg, weight=1.0, zero_prevention=0.0):
"""Compute the streaming average of a function O(x) using a weight.
Parameters:
----------
Expand All @@ -251,7 +251,6 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0):
new streaming average
"""

expectation = O(x)
flat_expectation, unravel_fn = ravel_pytree(expectation)
total, average = streaming_avg
flat_average, _ = ravel_pytree(average)
Expand Down

0 comments on commit 0ab0694

Please sign in to comment.