Sampling from the Negative Binomial distribution with the JAX substrate is very slow compared to a JAX-only implementation, especially with a small `total_count` parameter.
TFP version: `'0.23.0'`
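The JAX-only implementation used for comparison is not shown. A minimal sketch of one, assuming the Gamma-Poisson mixture construction and TFP's parameterization (count of successes before `total_count` failures, success probability `probs`, so the mean is `total_count * probs / (1 - probs)`), could look like this. The function name `sample_negative_binomial` is hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_negative_binomial(key, total_count, probs, shape):
    """Hypothetical JAX-only Negative Binomial sampler (Gamma-Poisson mixture).

    Assumes TFP's convention: the mean is total_count * probs / (1 - probs).
    """
    key_gamma, key_poisson = jax.random.split(key)
    # Gamma(shape=total_count, scale=1) draws; JAX handles total_count < 1.
    g = jax.random.gamma(key_gamma, total_count, shape=shape)
    # Rescale to Gamma(shape=total_count, scale=probs / (1 - probs)).
    rate = g * probs / (1.0 - probs)
    # Poisson draw with the Gamma-distributed rate gives a Negative Binomial sample.
    return jax.random.poisson(key_poisson, rate, shape=shape)


x = sample_negative_binomial(jax.random.PRNGKey(0), 0.5, 0.3, (200_000,))
```

Wrapping the function in `jax.jit` (with `shape` marked static) compiles the whole draw into a single XLA computation.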
6 s ± 38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
235 ms ± 710 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
119 ms ± 87.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
747 µs ± 65.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
604 µs ± 38 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
532 µs ± 867 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
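The code cells behind these timings are not included. One way such a comparison might be reproduced is sketched below. It assumes a hypothetical Gamma-Poisson helper `sample_nb_jax` as the JAX-only baseline, illustrative parameters (`total_count=0.5`, `probs=0.3`, 100,000 draws), and it calls `block_until_ready()` so that JAX's asynchronous dispatch does not hide the real cost. The TFP timing runs only if `tensorflow_probability` is installed.

```python
import timeit

import jax


def sample_nb_jax(key, total_count, probs, n):
    # Hypothetical JAX-only baseline: Gamma-Poisson mixture.
    k1, k2 = jax.random.split(key)
    rate = jax.random.gamma(k1, total_count, shape=(n,)) * probs / (1.0 - probs)
    return jax.random.poisson(k2, rate, shape=(n,))


def best_time(fn, repeats=7):
    # Warm-up call so compilation time is excluded from the measurement.
    fn().block_until_ready()
    times = timeit.repeat(lambda: fn().block_until_ready(), number=1, repeat=repeats)
    return min(times)


key = jax.random.PRNGKey(0)
n = 100_000

# n determines the output shape, so it must be static under jit.
jax_sampler = jax.jit(sample_nb_jax, static_argnums=3)
t_jax = best_time(lambda: jax_sampler(key, 0.5, 0.3, n))
print(f"JAX-only: {t_jax * 1e3:.3f} ms")

try:
    from tensorflow_probability.substrates import jax as tfp
except ImportError:
    tfp = None

if tfp is not None:
    dist = tfp.distributions.NegativeBinomial(total_count=0.5, probs=0.3)
    t_tfp = best_time(lambda: dist.sample(n, seed=key))
    print(f"TFP (JAX substrate): {t_tfp * 1e3:.3f} ms")
else:
    print("tensorflow_probability not installed; skipping TFP timing")
```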
I'm using TFP 0.23.0 with Python 3.10.13, because sampling with TFP 0.24.0 under Python 3.12 does not work for me (I see behavior similar to #1838).