Execution of replica 0 failed: INVALID_ARGUMENT #24727
jaxengodfrey asked this question in Q&A
Hello,
I am using GPU-enabled `numpyro` and `jax`. I am getting an `Execution of replica 0 failed: INVALID_ARGUMENT` error after my chain finishes sampling but before `numpyro.infer.MCMC.run()` has finished compiling the results.

I only get this error with large warmup and sample sizes (~100k). I think it started when I began saving a large number of deterministic variables during sampling. I have done this in the past with older versions of numpyro/jax without issue: same sample sizes, same deterministic variables, same GPU, etc.

I'm using the NUTS kernel and an 80 GB NVIDIA A100 GPU.
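One quick sanity check for the deterministic-variable hypothesis is to estimate how much memory the stored draws need on their own. The sketch below assumes float32 storage (JAX's default) and a hypothetical set of site shapes. `stored_sample_bytes` is an illustrative helper, not a numpyro function. Peak usage while sampling and post-processing may be higher than this figure, because intermediate copies can exist.

```python
import math


def stored_sample_bytes(num_samples, site_shapes, num_chains=1, bytes_per_elem=4):
    """Rough size of the arrays kept for every post-warmup draw.

    site_shapes maps each stored site (latent or deterministic) to its
    per-draw shape. bytes_per_elem=4 assumes float32 (an assumption).
    """
    # math.prod(()) == 1, so scalar sites count as one element per draw.
    elems_per_draw = sum(math.prod(shape) for shape in site_shapes.values())
    return num_chains * num_samples * elems_per_draw * bytes_per_elem


# Hypothetical example: 50 deterministic sites, each of shape (1000,).
sites = {f"det_{i}": (1000,) for i in range(50)}
total = stored_sample_bytes(100_000, sites)
print(f"{total / 1e9:.1f} GB")  # 20.0 GB
```

With ~100k draws, even moderately sized deterministic sites reach tens of gigabytes. This is a sizeable fraction of an 80 GB card before any sampler overhead is counted.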
Because I don't know much about how `numpyro` uses `jax` under the hood, I'm not sure how I could isolate this issue within `jax` itself to make troubleshooting easier. Any ideas/suggestions would be appreciated!