Skip to content

Commit

Permalink
UPDATE STREAMING AVG
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed May 19, 2024
1 parent 0ab0694 commit 9dd740f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,10 @@ def step(iteration_state, weight_and_key):
adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size))
# step_size = 1e-3

x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
streaming_avg = streaming_average(
O=lambda x: jnp.array([x, jnp.square(x)]),
x=ravel_pytree(state.position)[0],
expectation=jnp.array([x, jnp.square(x)]),
streaming_avg=streaming_avg,
weight=(1 - mask) * success * step_size,
zero_prevention=mask,
Expand Down

0 comments on commit 9dd740f

Please sign in to comment.