Skip to content

Commit

Permalink
MERGE
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed May 20, 2024
2 parents a26d4a0 + e0a7f9e commit 4e2b7c0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
10 changes: 4 additions & 6 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def one_step(average_and_state, xs, return_state):
_, rng_key = xs
average, state = average_and_state
state, info = inference_algorithm.step(rng_key, state)
average = streaming_average(expectation(state), average)
average = streaming_average(expectation(transform(state)), average)
if return_state:
return (average, state), (transform(state), info)
else:
Expand All @@ -222,7 +222,7 @@ def one_step(average_and_state, xs, return_state):

xs = (jnp.arange(num_steps), keys)
((_, average), final_state), history = lax.scan(
one_step, ((0, expectation(initial_state)), initial_state), xs
one_step, ((0, expectation(transform(initial_state))), initial_state), xs
)

if not return_state_history:
Expand All @@ -236,10 +236,8 @@ def streaming_average(expectation, streaming_avg, weight=1.0, zero_prevention=0.
"""Compute the streaming average of a function O(x) using a weight.
Parameters:
----------
O
function to be averaged
x
current state
expectation
the value of the expectation at the current timestep
streaming_avg
tuple of (total, average) where total is the sum of weights and average is the current average
weight
Expand Down
4 changes: 2 additions & 2 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def logdensity_fn(x):
inference_algorithm=alg,
num_steps=50,
progress_bar=False,
expectation=lambda x: x.position,
expectation=lambda x: x,
transform=lambda x: x.position,
return_state_history=True,
)
Expand All @@ -66,7 +66,7 @@ def logdensity_fn(x):
inference_algorithm=alg,
num_steps=50,
progress_bar=False,
expectation=lambda x: x.position,
expectation=lambda x: x,
transform=lambda x: x.position,
return_state_history=False,
)
Expand Down

0 comments on commit 4e2b7c0

Please sign in to comment.