In first checks, PyTorch 2.0 is a bit faster than PyTorch 1, but doesn't reach the performance of JAX, e.g., on the Transformer models. I noticed that it can give quite a boost in inference, though. Still, PyTorch currently fails to compile the CNNs (some issue in the compilation, both locally and on Colab), so I will wait until that is stable and then redo the speed comparisons.
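For reference, here is a minimal sketch of how such a comparison could be set up with `torch.compile`; the model, sizes, and iteration counts are placeholders, not the notebook's actual benchmark code:

```python
import time
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder Transformer encoder standing in for the tutorial models.
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True),
    num_layers=4,
).eval().to(device)
x = torch.randn(64, 128, 256, device=device)

# PyTorch 2.x compilation of the eager model.
compiled_model = torch.compile(model)

@torch.no_grad()
def benchmark(fn, warmup=3, iters=20):
    # Warm-up runs trigger compilation so it is not counted in the timing.
    for _ in range(warmup):
        fn(x)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

print(f"eager:    {benchmark(model) * 1e3:.2f} ms/iter")
print(f"compiled: {benchmark(compiled_model) * 1e3:.2f} ms/iter")
```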
In https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.html
you claim JAX is faster than PyTorch 1. Is this still true when using torch.compile from PyTorch 2?