From 834f55d5fa6d5f76c78d31f3ac2c90b3fe2d4e25 Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Tue, 13 Aug 2024 02:26:05 -0400 Subject: [PATCH] Harmonize Quickstart example (#717) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a8d847cf9..9590b4cd6 100644 --- a/README.md +++ b/README.md @@ -75,9 +75,10 @@ state = nuts.init(initial_position) # Iterate rng_key = jax.random.key(0) +step = jax.jit(nuts.step) for step in range(100): nuts_key = jax.random.fold_in(rng_key, step) - state, _ = nuts.step(nuts_key, state) + state, _ = step(nuts_key, state) ``` See [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.