Feature request: redo timing comparisons in Tutorial 5 (DenseNet) comparing JAX with PyTorch 2.0 #85

Open
murphyk opened this issue Mar 19, 2023 · 1 comment
Labels: enhancement (New feature or request), JAX (Notebook in JAX/Flax), PyTorch (Notebook in PyTorch/Lightning)

murphyk commented Mar 19, 2023

In https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.html
you claim JAX is faster than PyTorch 1. Is this still true when using torch.compile from PyTorch 2?

phlippe added the labels enhancement (New feature or request), JAX (Notebook in JAX/Flax), and PyTorch (Notebook in PyTorch/Lightning) on Jun 1, 2023
phlippe (Owner) commented Jun 1, 2023

In initial checks, PyTorch 2.0 is a bit faster than PyTorch 1, but it doesn't reach the performance of JAX, e.g., on the Transformer models. I noticed that it can give quite a boost for inference, though. However, PyTorch currently fails to compile the CNNs (some issue in the compilation, both locally and on Colab), so I will wait until that is stable and then redo the speed comparisons.
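When the speed comparison is redone, both jax.jit and torch.compile compile on the first call, and both frameworks dispatch device work asynchronously, so a fair timing should skip warm-up calls and wait for the result before stopping the clock. Below is a minimal sketch of such a timing harness, assuming a hypothetical helper `benchmark` (not part of the tutorial notebooks or of either library) that takes the model's forward/train step as a plain Python callable plus a `sync` function that blocks until the output is ready.

```python
import statistics
import time
from typing import Any, Callable, Dict


def benchmark(
    fn: Callable[..., Any],
    *args: Any,
    warmup: int = 3,
    iters: int = 20,
    sync: Callable[[Any], Any] = lambda out: out,
) -> Dict[str, float]:
    """Time fn(*args) after warm-up calls; returns per-call stats in milliseconds.

    Hypothetical helper. Warm-up calls absorb one-time costs such as
    tracing/compilation, so they are excluded from the measurement.
    `sync` must block until any asynchronous work behind `out` has finished.
    """
    if warmup < 0 or iters < 1:
        raise ValueError("need warmup >= 0 and iters >= 1")

    # Warm-up: trigger compilation and caching, results are discarded.
    for _ in range(warmup):
        sync(fn(*args))

    # Timed runs: the clock stops only after sync() confirms completion.
    times_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        sync(fn(*args))
        times_ms.append((time.perf_counter() - start) * 1e3)

    return {
        "mean_ms": statistics.mean(times_ms),
        "median_ms": statistics.median(times_ms),
        "min_ms": min(times_ms),
        "iters": iters,
    }


if __name__ == "__main__":
    # Plain-Python stand-in for a compiled model call.
    print(benchmark(sum, range(100_000), warmup=2, iters=5))
```

With JAX, one would pass something like `sync=lambda out: out.block_until_ready()` for an array output; with PyTorch on a GPU, `sync=lambda out: torch.cuda.synchronize()`. The median is reported alongside the mean because a few slow iterations (for example, a recompilation triggered by a new input shape) can skew the mean.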
