TFP JAX: The transition kernel drastically decreases speed. #1807
Comments
Hi - It looks like the Colab is locked down, so I cannot access it.
Does this link allow access? I made a simulation instead of using real data, as it lets us evaluate how the models perform as the data size increases. I can update it in the next few days.
Note that the data is not saved with the Colab, so I cannot run this, but it looks as though the problem is with your use of `Downstream`. I think this will lead to a rather wild posterior, and so TFP NUTS is (correctly) exhausting its tree doublings and doing ~10x as much work.
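One quick way to check this is to trace `leapfrogs_taken` from the NUTS kernel results. A rough, self-contained sketch (the toy target here is only for illustration, and the field path assumes a particular kernel nesting, so adjust it to match your stack):

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Toy target with one positive parameter, purely to show the plumbing.
def target_log_prob_fn(scale):
  return tfd.HalfNormal(1.).log_prob(scale)

nuts = tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=0.1,
                               max_tree_depth=10)
kernel = tfp.mcmc.TransformedTransitionKernel(nuts, bijector=tfb.Softplus())
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(kernel,
                                                  num_adaptation_steps=400)

def trace_fn(_, pkr):
  # Assumes DualAveragingStepSizeAdaptation -> TransformedTransitionKernel -> NUTS;
  # adjust the path if your kernels are nested differently.
  return pkr.inner_results.inner_results.leapfrogs_taken

@jax.jit
def run(seed):
  return tfp.mcmc.sample_chain(
      num_results=500,
      num_burnin_steps=500,
      current_state=jnp.ones([]),
      kernel=kernel,
      trace_fn=trace_fn,
      seed=seed)

_, leapfrogs = run(jax.random.PRNGKey(0))
# If most steps take close to 2**max_tree_depth - 1 = 1023 leapfrogs, NUTS is
# exhausting its tree doublings and each step is doing far more gradient work.
print(jnp.mean(leapfrogs), jnp.max(leapfrogs))
```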
Hi, I have modified the code accordingly and added simulated data to make reproduction easy. Unfortunately, the processing time remains very high.
Dear all,
I am currently learning Bayesian analysis and using `tensorflow_probability.substrates.jax`, but I've encountered some issues. Using `jax` with `jit` for NUTS alone, the performance is quite fast. However, when NUTS is combined with `TransformedTransitionKernel`, the speed decreases drastically. Here's a summary of the time taken:

I've conducted speed tests comparing against NumPyro, and essentially, NumPyro with dual averaging step size adaptation and parameter constraints is equivalent in speed to `tensorflow_probability` NUTS alone. Could there be something I've missed? Is there room for optimization in this process?
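For reference, this is roughly the kernel stack I am timing: a minimal sketch with a stand-in model (not the model from the notebook, and the parameter values are placeholders). The real code is in the attached Colab.

```python
import time
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Stand-in model: normal likelihood with an unconstrained location and a
# positive scale (the constrained parameter is why the bijector is needed).
data = 3.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(1), [100])

def target_log_prob_fn(loc, scale):
  lp = tfd.Normal(0., 10.).log_prob(loc)
  lp += tfd.HalfNormal(5.).log_prob(scale)
  lp += jnp.sum(tfd.Normal(loc, scale).log_prob(data))
  return lp

nuts = tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=0.1)

# Sample `scale` in unconstrained space; the Jacobian correction is applied
# automatically by TransformedTransitionKernel.
transformed = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=nuts, bijector=[tfb.Identity(), tfb.Softplus()])

adaptive = tfp.mcmc.DualAveragingStepSizeAdaptation(
    transformed, num_adaptation_steps=400, target_accept_prob=0.8)

@jax.jit
def run_chain(seed):
  return tfp.mcmc.sample_chain(
      num_results=1000,
      num_burnin_steps=500,
      current_state=[jnp.zeros([]), jnp.ones([])],  # constrained-space init
      kernel=adaptive,
      trace_fn=None,
      seed=seed)

start = time.time()
loc_samples, scale_samples = run_chain(jax.random.PRNGKey(0))
scale_samples.block_until_ready()
print(f'transformed + adaptive NUTS: {time.time() - start:.1f}s')
```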
Please find the data and code enclosed for reproducibility (the .txt file needs to be renamed to .ipynb):
data.csv
gitissue.txt
Google Colab
Please note that I'm only using the first 100 lines of the data.
Additionally, as a potential cause, I observed a similar speed loss when using the LKJ distribution in other models. (I could post one of them if needed.)
Thank you in advance for your assistance.
Sebastian