Skip to content

Commit

Permalink
Update index.md (#711)
Browse files Browse the repository at this point in the history
The jitted step remained unused, leading to the example running with an uncompiled nuts.step. 

Changing this reduces the execution time by a factor of 30 on my system and showcases blackjax' speed.
  • Loading branch information
johannahaffner authored Jul 31, 2024
1 parent f8db9aa commit 441412a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ rng_key = jax.random.key(0)
step = jax.jit(nuts.step)
for i in range(1_000):
nuts_key = jax.random.fold_in(rng_key, i)
state, _ = nuts.step(nuts_key, state)
state, _ = step(nuts_key, state)
```

:::{note}
Expand Down

0 comments on commit 441412a

Please sign in to comment.