
TFP JAX: The transition kernel drastically decreases speed. #1807

Open · SebastianSosa opened this issue Apr 9, 2024 · 4 comments

SebastianSosa commented Apr 9, 2024:

Dear all,

I am currently learning Bayesian analysis using tensorflow_probability.substrates.jax, but I've run into a performance issue. jit-compiled NUTS alone is quite fast, but once it is wrapped in a TransformedTransitionKernel the speed drops drastically. Here's a summary of the time taken:

  • TFP GPU: NUTS alone took 118.2952 seconds
  • TFP GPU: NUTS + Bijector took 1986.8306 seconds
  • TFP GPU: NUTS + DualAveragingStepSizeAdaptation took 141.0955 seconds
  • TFP GPU: NUTS + Bijector + DualAveragingStepSizeAdaptation took 2397.5875 seconds
  • NumPyro GPU: NUTS + Bijector + DualAveragingStepSizeAdaptation took 180 seconds

I've run speed tests comparing with NumPyro, and essentially NumPyro with dual-averaging step-size adaptation and parameter constraints matches the speed of TFP NUTS alone.
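
For reference, here is a minimal sketch of the kernel composition being timed; the target distribution, bijector, step size, and chain lengths below are placeholders (the actual model is in the attached notebook):

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Placeholder target; the real model lives in the attached notebook.
target = tfd.MultivariateNormalDiag(loc=jnp.zeros(3))

nuts = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target.log_prob, step_size=0.1)

# Wrap NUTS so sampling happens in unconstrained space...
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=nuts, bijector=tfb.Identity())

# ...then adapt the step size during burn-in (outermost, as TFP recommends).
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=kernel, num_adaptation_steps=800)

@jax.jit
def run_chain(seed):
    return tfp.mcmc.sample_chain(
        num_results=1000,
        num_burnin_steps=1000,
        current_state=jnp.zeros(3),
        kernel=kernel,
        trace_fn=None,
        seed=seed)

samples = run_chain(jax.random.PRNGKey(0))
```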

Could there be something I've missed? Is there room for optimization in this process?

Please find the data and code (the .txt file needs to be renamed to .ipynb) enclosed for reproducibility:
data.csv
gitissue.txt
Google Colab

Please note that I'm only using the first 100 lines of the data.

Additionally, as a potential clue, I observed a similar slowdown when using the LKJ distribution in other models. (I can post one of them if needed.)

Thank you in advance for your assistance.

Sebastian

ColCarroll (Collaborator) commented:

Hi - it looks like the Colab is locked down, so I cannot access it.

SebastianSosa (Author) commented:

Does this link allow access?

I made a simulation instead of using real data, since that lets us evaluate how the models scale as the data size grows. I can update it in the next few days.

ColCarroll (Collaborator) commented:

Note that the data is not saved with the Colab, so I cannot run this, but it looks as though the problem is your use of tfp.bijectors.CorrelationCholesky(ni). CorrelationCholesky doesn't take a shape parameter, so ni is silently being accepted as the validate_args argument.

Downstream, I think this will lead to some wild posterior, and so TFP NUTS is (correctly) exhausting its tree doublings and doing ~10x as much work.
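
For illustration, a minimal sketch of the pitfall described above (ni is a stand-in for the dimension variable from the notebook):

```python
from tensorflow_probability.substrates import jax as tfp

tfb = tfp.bijectors

ni = 5  # stand-in for the dimension variable used in the notebook

# Problematic: CorrelationCholesky takes no shape argument, so `ni`
# is silently bound to `validate_args`, its first positional parameter.
bad_bijector = tfb.CorrelationCholesky(ni)

# Intended: construct it with no arguments; the event shape is
# determined by the input it is applied to.
good_bijector = tfb.CorrelationCholesky()
```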

SebastianSosa (Author) commented:

Hi,

I have modified the code accordingly and added simulated data to facilitate easy reproducibility. Unfortunately, the processing time remains very high.
