diff --git a/blackjax/progress_bar.py b/blackjax/progress_bar.py index 303e319d9..bd41a9582 100644 --- a/blackjax/progress_bar.py +++ b/blackjax/progress_bar.py @@ -17,12 +17,14 @@ from fastprogress.fastprogress import progress_bar from jax import lax from jax.experimental import io_callback +from threading import Lock def progress_bar_scan(num_samples, print_rate=None): "Progress bar for a JAX scan" progress_bars = {} idx_map = {} + lock = Lock() if print_rate is None: if num_samples > 20: @@ -42,14 +44,16 @@ def _calc_chain_idx(iter_num): return idx def _update_bar(arg): - idx = _calc_chain_idx(arg) - if arg == 0: - progress_bars[idx] = progress_bar(range(num_samples)) - progress_bars[idx].update(0) + with lock: + idx = _calc_chain_idx(arg) + if arg == 0: + progress_bars[idx] = progress_bar(range(num_samples)) + progress_bars[idx].update(0) progress_bars[idx].update_bar(arg + 1) def _close_bar(arg): - idx = _calc_chain_idx(arg) + with lock: + idx = _calc_chain_idx(arg) progress_bars[idx].on_iter_end() def _update_progress_bar(iter_num):