Thanks for creating this library! I was very excited after hearing @jakevdp's wonderful talk. I tried using JAX for MLE estimation of a Negative Binomial random variable and compared it against a baseline, but JAX came out slower. It's likely I missed something, and I could use some help. A notebook that reproduces the figure above is on Colab.
-
I just wrote a new draft "FAQ" entry on "Benchmarking JAX code": https://github.com/google/jax/pull/5879/files
Please take a look and let me know if it helps answer your question.
In this particular case, I think part of the issue is that your example using JIT only compiles the gradient function, not the entire thing. Ideally you would JIT the entire function that you're benchmarking, including the call to jax.scipy.optimize.minimize:

@jax.jit
def fit_nbinom_bfgs_autograd_jit(y, mu):
    ...
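To make the suggestion concrete, here is a minimal, self-contained sketch of jitting the whole fit end to end. The negative-binomial log-likelihood, the parametrization (log of the dispersion r, logit of p), and the function names below are my own illustration, not the OP's notebook code, so treat them as an assumed stand-in for the real objective:

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jax.scipy.special import gammaln


def nbinom_nll(params, y):
    # Negative log-likelihood of a Negative Binomial distribution.
    # Hypothetical parametrization: params = [log r, logit p], so the
    # optimization is unconstrained (r > 0, 0 < p < 1 by construction).
    log_r, logit_p = params
    r = jnp.exp(log_r)
    p = jax.nn.sigmoid(logit_p)
    # log pmf: log C(y + r - 1, y) + r*log(1 - p) + y*log(p)
    ll = (gammaln(y + r) - gammaln(r) - gammaln(y + 1.0)
          + r * jnp.log1p(-p) + y * jnp.log(p))
    return -jnp.sum(ll)


@jax.jit
def fit_nbinom(y, x0):
    # JIT the *entire* fit, including the call to minimize, so the whole
    # BFGS loop is compiled rather than only the gradient evaluations.
    return minimize(nbinom_nll, x0, args=(y,), method="BFGS").x
```

When benchmarking a function like `fit_nbinom`, remember to call it once first (to trigger compilation) and to call `.block_until_ready()` on the result, as the linked FAQ entry describes.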