Enable batch matmul for result sizes > 2**32 where the tensor can be split along the batch axis (pytorch#133430)

Fixes pytorch#131865. Addresses the issue seen when running the llama v3.1 8B parameter model on the MPS backend, where the batch matmul output size can exceed the 32-bit indexing limit of MPS tensors, causing an assert.

Test case to reproduce the issue with the dimensions encountered in llama v3.1 and to verify that this fix works around it:

```
import torch
device = 'mps'
a = torch.randn([32, 20064, 128], dtype=torch.float32, device=device)
b = torch.randn([32, 128, 20064], dtype=torch.float32, device=device)
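# 32 * 20064 * 20064 ≈ 1.29e10 output elements, well above the 2**32 indexing limit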
res = torch.bmm(a, b)
```

Notably, the current change only works as long as each individual output matrix in the bmm does not exceed 2**32 elements. This lets us split the computation along the batch axis to avoid going over the limit.

Added a TORCH_CHECK to raise an error if the individual matrix dimensions are too large for this op to handle, until a more general workaround that tiles the matmuls is available.
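
For intuition, here is a minimal Python sketch of the batch-splitting idea; `chunked_bmm` is a hypothetical helper for illustration only, since the actual change performs the slicing at the MPSNDArray level inside `tiled_bmm_out_mps_impl` rather than by calling `torch.bmm` per chunk:

```
import torch

def chunked_bmm(a, b, limit=2**32):
    # Each output matrix is (a_rows x b_cols); the full bmm output may exceed
    # the 2**32-element limit, but a slice of the batch can stay under it.
    per_matrix = a.size(-2) * b.size(-1)
    assert per_matrix <= limit, "a single output matrix already exceeds the limit"
    max_batch = limit // per_matrix  # largest batch slice whose output still fits
    return torch.cat([torch.bmm(a_part, b_part)
                      for a_part, b_part in zip(a.split(max_batch), b.split(max_batch))])
```

The MPS implementation follows the same arithmetic: it computes the largest batch slice whose output fits in 2**32 elements, encodes one matrix multiplication per slice, and uses a smaller descriptor for the final remainder slice.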
Pull Request resolved: pytorch#133430
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <[email protected]>
2 people authored and pytorchmergebot committed Aug 30, 2024
1 parent 50efbb9 commit 92f282c
Showing 2 changed files with 134 additions and 0 deletions.
120 changes: 120 additions & 0 deletions aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -5,6 +5,7 @@
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Resize.h>
// For MTLLanguageVersion_3_1
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
#include <ATen/native/mps/MPSGraphSonomaOps.h>
#include <ATen/native/mps/OperationUtils.h>

@@ -509,11 +510,123 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
  return output;
}

static Tensor& tiled_bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) {
  if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) {
    using namespace mps;

    id<MTLBuffer> aBuffer = getMTLBufferStorage(batch1);
    id<MTLBuffer> bBuffer = getMTLBufferStorage(batch2);
    id<MTLBuffer> resBuffer = getMTLBufferStorage(result);

    MPSStream* mpsStream = getCurrentMPSStream();
    id<MTLDevice> device = MPSDevice::getInstance()->device();
    id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();

    dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
      @autoreleasepool {
        mpsStream->endKernelCoalescing();

        uint64_t originalBatchSize = batch1.sizes().size() > 2 ? batch1.size(0) : 1;
        uint64_t aRows = batch1.size(-2);
        uint64_t bRows = batch2.size(-2);
        uint64_t resRows = result.size(-2);
        uint64_t aCols = batch1.size(-1);
        uint64_t bCols = batch2.size(-1);
        uint64_t resCols = result.size(-1);
        uint64_t aElemSize = batch1.element_size();
        uint64_t bElemSize = batch2.element_size();
        uint64_t resElemSize = result.element_size();
        MPSDataType dtype = getMPSDataType(batch1);

        uint64_t elemInMatrix = resRows * resCols;
        uint64_t largestSupportedBatchSize = floor(pow(2, 32) / elemInMatrix);
        uint64_t batchSize = std::min(largestSupportedBatchSize, originalBatchSize);
        uint64_t lastBatchSize = originalBatchSize % batchSize;

        id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();

        auto matmul = [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2];

        MPSShape* aShape = @[ @(batchSize), @(aRows), @(aCols) ];
        MPSShape* bShape = @[ @(batchSize), @(bRows), @(bCols) ];
        MPSShape* resShape = @[ @(batchSize), @(resRows), @(resCols) ];
        auto aDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:aShape];
        aDesc_.preferPackedRows = true;
        auto bDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:bShape];
        bDesc_.preferPackedRows = true;

        auto resDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:resShape];
        resDesc_.preferPackedRows = true;

        getMPSProfiler().beginProfileKernel(matmul, " tiled_bmm_mps", {batch1, batch2});

        // Descriptors to use for last batch if it exists
        //.matrices is a readonly property so we need a separate descriptor.
        MPSNDArrayDescriptor *aDescLastBatch_, *bDescLastBatch_, *resDescLastBatch_;
        if (lastBatchSize != 0) {
          aDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype
                                                                   shape:@[ @(lastBatchSize), @(aRows), @(aCols) ]];
          aDescLastBatch_.preferPackedRows = true;
          bDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype
                                                                   shape:@[ @(lastBatchSize), @(bRows), @(bCols) ]];
          bDescLastBatch_.preferPackedRows = true;
          resDescLastBatch_ =
              [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:@[ @(lastBatchSize), @(resRows), @(resCols) ]];
          resDescLastBatch_.preferPackedRows = true;
        }

        uint64_t requiredIterations = ceil(float(originalBatchSize) / batchSize);
        auto aDesc = aDesc_;
        auto bDesc = bDesc_;
        auto resDesc = resDesc_;
        for (const auto i : c10::irange(requiredIterations)) {
          if (i == requiredIterations - 1 && lastBatchSize != 0) {
            aDesc = aDescLastBatch_;
            bDesc = bDescLastBatch_;
            resDesc = resDescLastBatch_;
          }
          const uint64_t aArrayOffset = i * batchSize * aRows * aCols;
          const uint64_t bArrayOffset = i * batchSize * bRows * bCols;
          const uint64_t resArrayOffset = i * batchSize * resRows * resCols;

          auto aMatrix = [[[MPSNDArray alloc] initWithBuffer:aBuffer
                                                      offset:(batch1.storage_offset() + aArrayOffset) * aElemSize
                                                  descriptor:aDesc] autorelease];
          auto bMatrix = [[[MPSNDArray alloc] initWithBuffer:bBuffer
                                                      offset:(batch2.storage_offset() + bArrayOffset) * bElemSize
                                                  descriptor:bDesc] autorelease];
          auto resMatrix = [[[MPSNDArray alloc] initWithBuffer:resBuffer
                                                        offset:(result.storage_offset() + resArrayOffset) * resElemSize
                                                    descriptor:resDesc] autorelease];

          [matmul encodeToCommandEncoder:computeEncoder
                           commandBuffer:commandBuffer
                            sourceArrays:@[ aMatrix, bMatrix ]
                        destinationArray:resMatrix];
        }
      }
    });
    return result;
  } else {
    TORCH_CHECK(false, "Tiling of batch matmul for larger than 2**32 entries only available from MacOS15 onwards");
  }
}

static Tensor& bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) {
  using namespace mps;

  TORCH_CHECK(supportedFloatingOrComplexType(batch1), "MPS device does not support bmm for non-float inputs");

  // Currently unsupported if the matmul output goes over the 32-bit indexing limit
  TORCH_CHECK(
      batch1.size(1) * batch2.size(2) <= pow(2, 32),
      "Output size of the matrix multiplication is larger than currently supported by the MPS backend: ",
      batch1.size(1),
      ",",
      batch2.size(2),
      ", needs to be less than 2**32 elements.",
      "File a feature request for this use case against the MPS backend at https://github.com/pytorch/pytorch/issues");

  if (batch1.numel() == 0 || batch2.numel() == 0) {
    result.zero_();
    return result;
@@ -543,6 +656,13 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
}
}

  // Check if we need to split the batch to do the computation
  uint64_t resultSize = batch1.size(0) * batch1.size(1) * batch2.size(2);
  if (resultSize > pow(2, 32)) {
    result = tiled_bmm_out_mps_impl(batch1, batch2, result);
    return result;
  }

  MPSStream* stream = getCurrentMPSStream();

  struct CachedGraph : public mps::MPSCachedGraph {
14 changes: 14 additions & 0 deletions test/test_mps.py
@@ -1914,6 +1914,20 @@ def test_bmm(self):
        self.assertEqual(output_cpu, output_mps)
        self.assertEqual(output_cpu.size(), output_mps.size())

    @xfailIf(product_version < 15.0)
    @parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_large_bmm(self, dtype):
        batch1 = torch.randn(11, 20064, 128, dtype=dtype, device='mps')
        batch2 = torch.randn(11, 128, 20064, dtype=dtype, device='mps')
        output_cpu = torch.bmm(batch1.cpu(), batch2.cpu())
        output_mps = torch.bmm(batch1, batch2)

        # Using the low precision comparison for FP16
        tol = 1e-2 if dtype == torch.float16 else None
        self.assertEqual(output_cpu, output_mps, atol=tol, rtol=tol)
        self.assertEqual(output_cpu.size(), output_mps.size())


    def test_addr(self):
        A = torch.ones(5, 10).to("mps")
        B = torch.ones(5).to("mps")
