Jax for ML estimation - Slower than scipy? #5876

Answered by shoyer
saketkc asked this question in Q&A

I just wrote a new draft "FAQ" entry on "Benchmarking JAX code": https://github.com/google/jax/pull/5879/files

Please take a look and let me know if it helps answer your question.
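As a rough illustration of two common practices when timing JAX code (not necessarily the exact wording of the linked draft), the sketch below times the first call, which includes tracing and XLA compilation, separately from later calls. It also calls `block_until_ready()`, because JAX dispatches work asynchronously and a timer would otherwise stop before the computation finishes. The function `f` and the input size are hypothetical stand-ins for the code being benchmarked.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Hypothetical workload standing in for the function being benchmarked.
    return jnp.sum(jnp.sin(x) ** 2)


x = jnp.arange(1_000_000, dtype=jnp.float32)

# First call: includes tracing + XLA compilation, so time it separately.
t0 = time.perf_counter()
f(x).block_until_ready()
compile_and_run = time.perf_counter() - t0

# Later calls reuse the compiled executable. block_until_ready() waits for
# the asynchronously dispatched computation to actually finish.
n_runs = 10
t0 = time.perf_counter()
for _ in range(n_runs):
    out = f(x).block_until_ready()
steady_state = (time.perf_counter() - t0) / n_runs

print(f"first call: {compile_and_run:.4f}s, steady state: {steady_state:.6f}s")
```

Comparing only the steady-state number against SciPy gives a fairer picture of per-call cost, while the first-call number shows how much compilation overhead a one-off fit pays.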

In this particular case, I think part of the issue is that your JIT example only compiles the gradient function, not the whole fit. Ideally, you would JIT the entire function you're benchmarking, including the call to jax.scipy.optimize.minimize:

import jax

@jax.jit  # compile the whole fit, including the minimize call
def fit_nbinom_bfgs_autograd_jit(y, mu):
  ...
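Since the body of the original function isn't shown here, the following is only a hedged sketch of what "JIT the entire thing" could look like for a negative binomial fit. It assumes the goal is to estimate the NB size (dispersion) parameter `r` given counts `y` and known means `mu`. The names `nbinom_neg_log_lik` and `fit_nbinom_bfgs_jit`, the log-`r` parameterization, and the simulated data are all hypothetical. The key point is that `jax.scipy.optimize.minimize(..., method="BFGS")` is called inside the `@jax.jit` function, so the whole optimization loop is compiled once, and gradients come from autodiff inside `minimize`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jax.scipy.special import gammaln


def nbinom_neg_log_lik(params, y, mu):
    # Hypothetical parameterization: params[0] = log(r), keeping r > 0.
    r = jnp.exp(params[0])
    log_lik = (
        gammaln(y + r) - gammaln(r) - gammaln(y + 1.0)
        + r * jnp.log(r / (r + mu))
        + y * jnp.log(mu / (r + mu))
    )
    # Mean instead of sum: same argmin, better-scaled gradients for BFGS.
    return -jnp.mean(log_lik)


@jax.jit
def fit_nbinom_bfgs_jit(y, mu):
    # Everything here, including the BFGS loop inside minimize, is traced
    # and compiled as a single XLA computation.
    x0 = jnp.zeros(1)
    result = minimize(nbinom_neg_log_lik, x0, args=(y, mu), method="BFGS")
    return jnp.exp(result.x[0])


# Simulated data: a Gamma-Poisson mixture yields NB(size=r_true, mean=mu_true).
key = jax.random.PRNGKey(0)
k_gamma, k_pois = jax.random.split(key)
r_true, mu_true, n = 2.0, 5.0, 5000
lam = jax.random.gamma(k_gamma, r_true, shape=(n,)) * (mu_true / r_true)
y = jax.random.poisson(k_pois, lam).astype(jnp.float32)
mu = jnp.full((n,), mu_true)

r_hat = fit_nbinom_bfgs_jit(y, mu)
print(r_hat)
```

When benchmarking this against SciPy, calling `fit_nbinom_bfgs_jit` once to warm it up (and then timing subsequent calls with `block_until_ready()`) keeps compilation time out of the per-fit comparison.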

Answer selected by saketkc