Skip to content

Commit

Permalink
Merge branch 'ciguaran_waste_free_smc' of github.com:ciguaran/blackja…
Browse files Browse the repository at this point in the history
…x into ciguaran_waste_free_smc
  • Loading branch information
ciguaran committed Aug 16, 2024
2 parents c06b6ab + 02dccbe commit 3eb18cb
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 144 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ state = nuts.init(initial_position)

# Iterate
rng_key = jax.random.key(0)
for step in range(100):
nuts_key = jax.random.fold_in(rng_key, step)
state, _ = nuts.step(nuts_key, state)
step = jax.jit(nuts.step)
for i in range(100):
nuts_key = jax.random.fold_in(rng_key, i)
state, _ = step(nuts_key, state)
```

See [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.
Expand Down
8 changes: 4 additions & 4 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size
from blackjax.util import pytree_size, streaming_average_update
from blackjax.util import incremental_value_update, pytree_size


class MCLMCAdaptationState(NamedTuple):
Expand Down Expand Up @@ -199,9 +199,9 @@ def step(iteration_state, weight_and_key):

x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
streaming_avg = streaming_average_update(
current_value=jnp.array([x, jnp.square(x)]),
previous_weight_and_average=streaming_avg,
streaming_avg = incremental_value_update(
expectation=jnp.array([x, jnp.square(x)]),
incremental_val=streaming_avg,
weight=(1 - mask) * success * params.step_size,
zero_prevention=mask,
)
Expand Down
15 changes: 7 additions & 8 deletions blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
dual_averaging_adaptation,
)
from blackjax.base import AdaptationAlgorithm
from blackjax.progress_bar import progress_bar_scan
from blackjax.progress_bar import gen_scan_fn
from blackjax.types import Array, ArrayLikeTree, PRNGKey
from blackjax.util import pytree_size

Expand Down Expand Up @@ -333,17 +333,16 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):

if progress_bar:
print("Running window adaptation")
one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step))
else:
one_step_ = jax.jit(one_step)

scan_fn = gen_scan_fn(num_steps, progress_bar=progress_bar)
start_state = (init_state, init_adaptation_state)
keys = jax.random.split(rng_key, num_steps)
schedule = build_schedule(num_steps)
last_state, info = jax.lax.scan(
one_step_,
(init_state, init_adaptation_state),
last_state, info = scan_fn(
one_step,
start_state,
(jnp.arange(num_steps), keys, schedule),
)

last_chain_state, last_warmup_state, *_ = last_state

step_size, inverse_mass_matrix = adapt_final(last_warmup_state)
Expand Down
69 changes: 47 additions & 22 deletions blackjax/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,64 @@
"""Progress bar decorators for use with step functions.
Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`.
"""
from threading import Lock

from fastprogress.fastprogress import progress_bar
from jax import lax
from jax.experimental import io_callback
from jax.numpy import array


def progress_bar_scan(num_samples, print_rate=None):
"Progress bar for a JAX scan"
progress_bars = {}
idx_counter = 0
lock = Lock()

if print_rate is None:
if num_samples > 20:
print_rate = int(num_samples / 20)
else:
print_rate = 1 # if you run the sampler for less than 20 iterations

def _define_bar(arg):
del arg
progress_bars[0] = progress_bar(range(num_samples))
progress_bars[0].update(0)
def _calc_chain_idx(iter_num):
nonlocal idx_counter
with lock:
idx = idx_counter
idx_counter += 1
return idx

def _update_bar(arg, chain_id):
chain_id = int(chain_id)
if arg == 0:
chain_id = _calc_chain_idx(arg)
progress_bars[chain_id] = progress_bar(range(num_samples))
progress_bars[chain_id].update(0)

def _update_bar(arg):
progress_bars[0].update_bar(arg + 1)
progress_bars[chain_id].update_bar(arg + 1)
return chain_id

def _close_bar(arg):
del arg
progress_bars[0].on_iter_end()
def _close_bar(arg, chain_id):
progress_bars[int(chain_id)].on_iter_end()

def _update_progress_bar(iter_num):
def _update_progress_bar(iter_num, chain_id):
"Updates progress bar of a JAX scan or loop"
_ = lax.cond(
iter_num == 0,
lambda _: io_callback(_define_bar, None, iter_num),
lambda _: None,
operand=None,
)

_ = lax.cond(
chain_id = lax.cond(
# update every multiple of `print_rate` except at the end
(iter_num % print_rate == 0) | (iter_num == (num_samples - 1)),
lambda _: io_callback(_update_bar, None, iter_num),
lambda _: None,
lambda _: io_callback(_update_bar, array(0), iter_num, chain_id),
lambda _: chain_id,
operand=None,
)

_ = lax.cond(
iter_num == num_samples - 1,
lambda _: io_callback(_close_bar, None, None),
lambda _: io_callback(_close_bar, None, iter_num + 1, chain_id),
lambda _: None,
operand=None,
)
return chain_id

def _progress_bar_scan(func):
"""Decorator that adds a progress bar to `body_fun` used in `lax.scan`.
Expand All @@ -77,9 +85,26 @@ def wrapper_progress_bar(carry, x):
iter_num, *_ = x
else:
iter_num = x
_update_progress_bar(iter_num)
return func(carry, x)
subcarry, chain_id = carry
chain_id = _update_progress_bar(iter_num, chain_id)
subcarry, y = func(subcarry, x)

return (subcarry, chain_id), y

return wrapper_progress_bar

return _progress_bar_scan


def gen_scan_fn(num_samples, progress_bar, print_rate=None):
if progress_bar:

def scan_wrap(f, init, *args, **kwargs):
func = progress_bar_scan(num_samples, print_rate)(f)
carry = (init, -1)
(last_state, _), output = lax.scan(func, carry, *args, **kwargs)
return last_state, output

return scan_wrap
else:
return lax.scan
Loading

0 comments on commit 3eb18cb

Please sign in to comment.