diff --git a/README.md b/README.md index 9590b4cd6..06d5b46cf 100644 --- a/README.md +++ b/README.md @@ -76,8 +76,8 @@ state = nuts.init(initial_position) # Iterate rng_key = jax.random.key(0) step = jax.jit(nuts.step) -for step in range(100): - nuts_key = jax.random.fold_in(rng_key, step) +for i in range(100): + nuts_key = jax.random.fold_in(rng_key, i) state, _ = step(nuts_key, state) ```