Enable batch matmul for result sizes > 2**32 when the tensor can be split along the batch axis (pytorch#133430)

Fixes pytorch#131865.

Addresses the issue seen when running the llama v3.1 8B parameter model on the MPS backend, where the batch matmul output size can go over the 32-bit indexing limit of MPS tensors, causing an assert.

Test case to reproduce the issue with the dimensions encountered in llama v3.1 and to verify that this fix works around it:

```
import torch

device = 'mps'
a = torch.randn([32, 20064, 128], dtype=torch.float32, device=device)
b = torch.randn([32, 128, 20064], dtype=torch.float32, device=device)
res = torch.bmm(a, b)
```

Notably, the current change only works as long as each individual output matrix in the bmm does not exceed 2**32 elements; that is what lets us split the computation along the batch axis to stay under the limit. Added a TORCH_CHECK that raises an error when the individual matrix dimensions are too large for this op, until a more general workaround that tiles the matmuls is available.

Pull Request resolved: pytorch#133430
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <[email protected]>
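The batch-axis splitting idea described above can be sketched in Python. This is not the actual MPS backend change (which lives in the Objective-C++ kernels); it is a hypothetical illustration of the same strategy, with `MPS_INDEX_LIMIT` and `bmm_batch_split` as assumed names: keep each sub-bmm's output under the 2**32-element indexing limit by slicing along the batch dimension, and fail early (like the added TORCH_CHECK) when a single output matrix is already too large.

```
import torch

MPS_INDEX_LIMIT = 2 ** 32  # assumed constant: max elements addressable with 32-bit indexing

def bmm_batch_split(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the batch-splitting strategy: run bmm in
    batch slices so each partial result stays under the element limit."""
    batch, m, _ = a.shape
    n = b.shape[-1]
    per_matrix = m * n  # elements in one output matrix of the bmm
    # Mirrors the TORCH_CHECK in the commit: if a single output matrix is
    # already over the limit, batch splitting alone cannot help.
    if per_matrix > MPS_INDEX_LIMIT:
        raise RuntimeError("single output matrix exceeds 2**32 elements")
    if batch * per_matrix <= MPS_INDEX_LIMIT:
        return torch.bmm(a, b)  # whole result fits, no splitting needed
    # Largest batch slice whose output still fits under the limit.
    chunk = max(1, MPS_INDEX_LIMIT // per_matrix)
    out = torch.empty(batch, m, n, dtype=a.dtype, device=a.device)
    for start in range(0, batch, chunk):
        end = min(start + chunk, batch)
        out[start:end] = torch.bmm(a[start:end], b[start:end])
    return out
```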
1 parent 50efbb9, commit 92f282c
Showing 2 changed files with 134 additions and 0 deletions.