diff --git a/blackjax/util.py b/blackjax/util.py index 600a7a961..02c27e51c 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -209,7 +209,7 @@ def one_step(average_and_state, xs, return_state): _, rng_key = xs average, state = average_and_state state, info = inference_algorithm.step(rng_key, state) - average = streaming_average(expectation(state), average) + average = streaming_average(expectation(transform(state)), average) if return_state: return (average, state), (transform(state), info) else: @@ -222,7 +222,7 @@ def one_step(average_and_state, xs, return_state): xs = (jnp.arange(num_steps), keys) ((_, average), final_state), history = lax.scan( - one_step, ((0, expectation(initial_state)), initial_state), xs + one_step, ((0, expectation(transform(initial_state))), initial_state), xs ) if not return_state_history: @@ -236,10 +236,8 @@ def streaming_average(expectation, streaming_avg, weight=1.0, zero_prevention=0. """Compute the streaming average of a function O(x) using a weight. Parameters: ---------- - O - function to be averaged - x - current state + expectation + the value of the expectation at the current timestep streaming_avg tuple of (total, average) where total is the sum of weights and average is the current average weight diff --git a/tests/test_util.py b/tests/test_util.py index 83955acd7..1f03498dd 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -55,7 +55,7 @@ def logdensity_fn(x): inference_algorithm=alg, num_steps=50, progress_bar=False, - expectation=lambda x: x.position, + expectation=lambda x: x, transform=lambda x: x.position, return_state_history=True, ) @@ -66,7 +66,7 @@ def logdensity_fn(x): inference_algorithm=alg, num_steps=50, progress_bar=False, - expectation=lambda x: x.position, + expectation=lambda x: x, transform=lambda x: x.position, return_state_history=False, )