Skip to content

Commit

Permalink
add locking
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdipper committed Aug 5, 2024
1 parent d9ca59d commit 36fdb3f
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions blackjax/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
from fastprogress.fastprogress import progress_bar
from jax import lax
from jax.experimental import io_callback
from threading import Lock


def progress_bar_scan(num_samples, print_rate=None):
"Progress bar for a JAX scan"
progress_bars = {}
idx_map = {}
lock = Lock()

if print_rate is None:
if num_samples > 20:
Expand All @@ -42,14 +44,16 @@ def _calc_chain_idx(iter_num):
return idx

def _update_bar(arg):
idx = _calc_chain_idx(arg)
if arg == 0:
progress_bars[idx] = progress_bar(range(num_samples))
progress_bars[idx].update(0)
with lock:
idx = _calc_chain_idx(arg)
if arg == 0:
progress_bars[idx] = progress_bar(range(num_samples))
progress_bars[idx].update(0)
progress_bars[idx].update_bar(arg + 1)

def _close_bar(arg):
idx = _calc_chain_idx(arg)
with lock:
idx = _calc_chain_idx(arg)
progress_bars[idx].on_iter_end()

def _update_progress_bar(iter_num):
Expand Down

0 comments on commit 36fdb3f

Please sign in to comment.