Printing iteration number in jitted functions/loops using host_callback.id_tap
#4763
-
Hello! I'm writing MCMC samplers and optimisers in JAX, and I would like to print the iteration number inside jitted JAX functions. Related issues (such as #196) explain that you can't simply use Python's `print` function, but there's the experimental `host_callback.id_tap` function. So I wrote a progress-bar function:

```python
def progress_bar(i, n_iter, print_rate):
    # Only print every `print_rate` iterations
    if i % print_rate == 0:
        print(f"Iteration {i}/{n_iter}")
```

If I run it in a normal Python loop, it works as expected:

```python
from jax.experimental import host_callback

def my_python_loop():
    """
    Python loop that increments `a` 100 times
    """
    n_iter, print_rate = 100, 10
    a = 0
    for i in range(n_iter):
        host_callback.id_tap(progress_bar, i, n_iter=n_iter, print_rate=print_rate)
        a += 1
    return a

my_python_loop()
```

This prints the expected output.
However, if I … Furthermore, if I modify the function in the following way:

```python
from jax import jit
from jax.experimental import host_callback

@jit
def my_python_loop2(a):
    """
    Python loop that increments `a` 100 times
    """
    n_iter, print_rate = 100, 10
    for i in range(n_iter):
        a = host_callback.id_tap(progress_bar, a, n_iter=n_iter, print_rate=print_rate)
        a += 1
    return a
```

Then this function prints the iterations. But this is only because I pass … So my question: is there a practical way to print the iteration number in jitted functions (in particular using the …
-
You definitely need to include a data dependency in order to ensure the callback runs at the correct time. Here are two versions that work with `jit`, with and without `scan`:

```python
import jax
from jax import jit
from jax.experimental import host_callback

def progress_bar(arg, transforms):
    # New-style tap function: receives the tapped value and the applied transforms
    i, n_iter, print_rate = arg
    if i % print_rate == 0:
        print(f"Iteration {i}/{n_iter}")

@jit
def my_python_loop2(a):
    """
    Python loop that increments `a` 100 times
    """
    n_iter, print_rate = 100, 10
    for i in range(n_iter):
        # `result=a` ties the tap into the data flow of `a`, so it runs at the right point
        a = host_callback.id_tap(progress_bar, (i, n_iter, print_rate), result=a)
        a += 1
    return a

@jit
def my_python_loop3(a):
    """
    lax.scan loop that increments `a` 100 times
    """
    n_iter, print_rate = 100, 10
    a, _ = jax.lax.scan(
        lambda a, i: (
            host_callback.id_tap(progress_bar, (i, n_iter, print_rate), result=a) + 1,
            None),
        init=a,
        xs=jax.numpy.arange(n_iter),
    )
    return a
```

(Note that this uses the newer `host_callback` interface; you may need to upgrade to the latest version of JAX to use it.)
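In more recent JAX releases `jax.experimental.host_callback` is deprecated, and `jax.debug.callback` is the suggested replacement for side-effecting host calls like this one. Below is a minimal sketch of the `scan` version above ported to that API, assuming a JAX version that provides `jax.debug.callback` and `jax.effects_barrier`. Under `jit` the debug callback is kept as a side effect, so in this sketch the value is not threaded through `result=`. The `reported` list and the name `scan_loop` are hypothetical, added only so the output can be checked.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical list recording which iterations were reported (for checking only).
reported = []


def progress_bar(i, *, n_iter, print_rate):
    """Host-side callback: `i` arrives as a concrete array, not a tracer."""
    i = int(i)
    if i % print_rate == 0:
        reported.append(i)
        print(f"Iteration {i}/{n_iter}")


@jax.jit
def scan_loop(a):
    """Increments `a` 100 times inside lax.scan, reporting via jax.debug.callback."""
    n_iter, print_rate = 100, 10
    # Static settings are bound on the Python side; only `i` crosses to the host.
    tap = functools.partial(progress_bar, n_iter=n_iter, print_rate=print_rate)

    def body(carry, i):
        # No `result=` threading: the debug callback is kept as an effect under jit.
        jax.debug.callback(tap, i)
        return carry + 1, None

    a, _ = jax.lax.scan(body, a, jnp.arange(n_iter))
    return a


out = scan_loop(0)
jax.effects_barrier()  # wait for pending host callbacks before reading `reported`
print(int(out))
```

As in the answer above, the host function still receives every iteration and does the filtering itself.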
-
Hi Stephan, thanks very much for your answer! It hadn't "clicked" for me how to use … Also, I made the following modification to the solution so that the Python `print` function isn't called at every iteration (only when it actually needs to print). This makes the loop run faster:

```python
import jax
from jax.experimental import host_callback

def _print_consumer(arg, transforms):
    i, n_iter = arg
    print(f"Iteration {i}/{n_iter}")

@jax.jit
def progress_bar(arg, result):
    "Print progress of loop only if iteration number is a multiple of the print_rate"
    i, n_iter, print_rate = arg
    # Decide on the device whether to call back to the host at all
    result = jax.lax.cond(
        i % print_rate == 0,
        lambda _: host_callback.id_tap(_print_consumer, (i, n_iter), result=result),
        lambda _: result,
        operand=None)
    return result

@jax.jit
def jax_loop(a):
    """
    Jax loop that increments `a` 100 times
    """
    n_iter, print_rate = 100, 10
    def body(carry, x):
        carry = progress_bar((x, n_iter, print_rate), carry)
        carry += 1
        return carry, None
    carry, _ = jax.lax.scan(body, a, jax.numpy.arange(n_iter))
    return carry
```

Also, was this question more appropriate for the Discussions section rather than the issue tracker? Thanks again.
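The gated variant (deciding on the device whether to call back at all) can be sketched with `jax.debug.callback` inside `jax.lax.cond`, again assuming a JAX version where `host_callback` has been superseded by `jax.debug.callback`. Because the callback is not tied to the carry, this hypothetical `progress_bar` no longer needs a `result` argument. The `calls` list and the function names are illustrative only.

```python
import jax

# Hypothetical list recording which iterations actually reached the host.
calls = []


def _print_consumer(i, n_iter):
    # Host side: only invoked for iterations selected by lax.cond on the device.
    i = int(i)
    calls.append(i)
    print(f"Iteration {i}/{n_iter}")


def progress_bar(i, n_iter, print_rate):
    """Call back to the host only when `i` is a multiple of `print_rate`."""
    jax.lax.cond(
        i % print_rate == 0,
        lambda: jax.debug.callback(lambda j: _print_consumer(j, n_iter), i),
        lambda: None,  # no host round-trip on the other iterations
    )


@jax.jit
def jax_loop(a):
    """Increments `a` 100 times, reporting every `print_rate` iterations."""
    n_iter, print_rate = 100, 10

    def body(i, carry):
        progress_bar(i, n_iter, print_rate)
        return carry + 1

    return jax.lax.fori_loop(0, n_iter, body, a)


out = jax_loop(0)
jax.effects_barrier()  # flush pending callbacks before inspecting `calls`
print(int(out))
```

Only 10 of the 100 iterations trigger a host call here, which is the same saving the modification above aims for.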