From 178b452b77548b349140339dfb88af0fd80b380b Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 May 2024 17:49:15 +0200 Subject: [PATCH 1/2] RENAME O --- blackjax/util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 070ca8687..600a7a961 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -209,7 +209,7 @@ def one_step(average_and_state, xs, return_state): _, rng_key = xs average, state = average_and_state state, info = inference_algorithm.step(rng_key, state) - average = streaming_average(expectation, state, average) + average = streaming_average(expectation(state), average) if return_state: return (average, state), (transform(state), info) else: @@ -232,7 +232,7 @@ def one_step(average_and_state, xs, return_state): return transform(final_state), state_history, info_history -def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): +def streaming_average(expectation, streaming_avg, weight=1.0, zero_prevention=0.0): """Compute the streaming average of a function O(x) using a weight. Parameters: ---------- @@ -251,7 +251,6 @@ def streaming_average(O, x, streaming_avg, weight=1.0, zero_prevention=0.0): new streaming average """ - expectation = O(x) flat_expectation, unravel_fn = ravel_pytree(expectation) total, average = streaming_avg flat_average, _ = ravel_pytree(average) From a26d4a002b85c75a4df353e04b27282cc9fa6dbc Mon Sep 17 00:00:00 2001 From: = Date: Sun, 19 May 2024 17:53:22 +0200 Subject: [PATCH 2/2] UPDATE STREAMING AVG --- blackjax/adaptation/mclmc_adaptation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index dc33eb21c..27321321a 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -177,10 +177,10 @@ def step(iteration_state, weight_and_key): state, params, adaptive_state, rng_key ) + x = ravel_pytree(state.position)[0] # update the running average of x, x^2 streaming_avg = streaming_average( - O=lambda x: jnp.array([x, jnp.square(x)]), - x=ravel_pytree(state.position)[0], + expectation=jnp.array([x, jnp.square(x)]), streaming_avg=streaming_avg, weight=(1 - mask) * success * params.step_size, zero_prevention=mask,