Slow sampling of NegativeBinomial distribution #1843

Open
stefanheyder opened this issue Sep 28, 2024 · 0 comments

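Sampling from the NegativeBinomial distribution with the JAX substrate is very slow compared to a pure-JAX implementation, especially when `total_count` is small. With `total_count=0.1`, drawing 1000 samples takes about 6 s, versus under 1 ms for the pure-JAX version below.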

```python
import tensorflow_probability as tfp
tfp.__version__
```

```
'0.23.0'
```

```python
from tensorflow_probability.substrates.jax.distributions import (
    NegativeBinomial as NBinom,
)
from jax import numpy as jnp, random as jrn, config
config.update("jax_enable_x64", True)

N = 1000
mu = 1e4
small_r = 0.1
middle_r = 10
large_r = 1000
key = jrn.PRNGKey(342354234)

# logits = log(mu / total_count), so the mean total_count * exp(logits) equals mu
nbinom_small_r = NBinom(total_count=small_r, logits=jnp.log(mu) - jnp.log(small_r))
nbinom_middle_r = NBinom(total_count=middle_r, logits=jnp.log(mu) - jnp.log(middle_r))
nbinom_large_r = NBinom(total_count=large_r, logits=jnp.log(mu) - jnp.log(large_r))

%timeit nbinom_small_r.sample(seed=key, sample_shape=(N,)).block_until_ready()
%timeit nbinom_middle_r.sample(seed=key, sample_shape=(N,)).block_until_ready()
%timeit nbinom_large_r.sample(seed=key, sample_shape=(N,)).block_until_ready()
```

```
6 s ± 38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
235 ms ± 710 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
119 ms ± 87.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
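For reference on the parameters above: TFP's `NegativeBinomial` counts successes before `total_count` failures, with success probability $p = \sigma(\text{logits})$, so $p/(1-p) = e^{\text{logits}}$. Writing $r$ for `total_count`, the choice `logits = log(mu) - log(r)` fixes the mean at $\mu$ for all three values of $r$, while the variance grows as $r$ shrinks:

```latex
\mathbb{E}[Y] = r\,\frac{p}{1-p} = r\,e^{\log\mu - \log r} = \mu,
\qquad
\operatorname{Var}[Y] = r\,\frac{p}{(1-p)^2} = \mu\left(1 + \frac{\mu}{r}\right) = \mu + \frac{\mu^2}{r}.
```

With $\mu = 10^4$, the variance is roughly $10^9$ for $r = 0.1$ and about $1.1 \times 10^5$ for $r = 1000$.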

```python
def sample_nbinom(key, r, mu, shp):
    # Gamma-Poisson mixture: rate ~ Gamma(r, scale=mu / r), count ~ Poisson(rate)
    key, sk_g, sk_p = jrn.split(key, 3)
    gamma_sample = mu / r * jrn.gamma(sk_g, r, shp)
    return jrn.poisson(sk_p, gamma_sample)

%timeit sample_nbinom(jrn.PRNGKey(0), small_r, mu, (N,)).block_until_ready()
%timeit sample_nbinom(jrn.PRNGKey(0), middle_r, mu, (N,)).block_until_ready()
%timeit sample_nbinom(jrn.PRNGKey(0), large_r, mu, (N,)).block_until_ready()
```

```
747 µs ± 65.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
604 µs ± 38 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
532 µs ± 867 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
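The pure-JAX `sample_nbinom` above draws a rate from Gamma(r, scale μ/r) and then a Poisson count given that rate, which is the standard gamma-Poisson construction of a negative binomial with mean μ and variance μ + μ²/r. Neither benchmark above wraps its call in `jax.jit`. As a minimal sketch, assuming that same construction, the code below jits the sampler (one compilation per sample shape) and checks its empirical moments against those formulas. `sample_nbinom_jit` and `check_moments` are hypothetical helper names, not TFP APIs, and no timings for this version are claimed here.

```python
# Minimal sketch, assuming the gamma-Poisson construction from the issue.
# `sample_nbinom_jit` and `check_moments` are hypothetical helper names.
import timeit
from functools import partial

import jax
from jax import numpy as jnp, random as jrn

jax.config.update("jax_enable_x64", True)


@partial(jax.jit, static_argnums=3)  # `shape` must be static (a hashable tuple)
def sample_nbinom_jit(key, r, mu, shape):
    """NegativeBinomial(mean=mu, total_count=r) draws via a gamma-Poisson mixture."""
    k_gamma, k_poisson = jrn.split(key)
    # Gamma(shape=r, scale=mu / r) has mean mu and variance mu**2 / r.
    rate = mu / r * jrn.gamma(k_gamma, r, shape)
    # Given the gamma draw, the count is Poisson; the output shape follows `rate`.
    return jrn.poisson(k_poisson, rate)


def check_moments(key, r, mu, n=200_000):
    """Return (empirical mean, empirical variance, theoretical variance mu + mu**2 / r)."""
    draws = sample_nbinom_jit(key, r, mu, (n,)).astype(jnp.float64)
    return float(draws.mean()), float(draws.var()), mu + mu**2 / r


if __name__ == "__main__":
    key = jrn.PRNGKey(0)
    sample = lambda: sample_nbinom_jit(key, 0.1, 1e4, (1000,)).block_until_ready()
    sample()  # the first call compiles; keep it out of the timing
    per_call = timeit.timeit(sample, number=100) / 100
    print(f"{per_call * 1e6:.1f} µs per call (r=0.1, mu=1e4, N=1000)")
    print(check_moments(key, 10.0, 100.0))
```

Because `r` and `mu` are traced rather than static, changing them does not trigger recompilation. Only a new sample shape does.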

I'm using tfp 0.23.0 under Python 3.10.13, because sampling with tfp 0.24.0 under Python 3.12 does not work for me (I see behavior similar to #1838).
