From 92f282ca520377f054d3c06a873ac093f797145b Mon Sep 17 00:00:00 2001
From: Joona Havukainen
Date: Fri, 30 Aug 2024 14:08:43 +0000
Subject: [PATCH] Enable batch matmul for result sizes > 2**32 when the tensor
 can be split along the batch axis (#133430)

Fixes #131865. Addresses the issue seen when running the llama v3.1 8B
parameter model on the MPS backend, where the batch matmul output size can go
over the 32-bit indexing limit of MPS tensors, causing an assert.

Test case to reproduce the issue with the dimensions encountered in llama v3.1
and to verify that this fix works around it:
```
import torch
device='mps'
a = torch.randn([32, 20064, 128], dtype=torch.float32, device=device)
b = torch.randn([32, 128, 20064], dtype=torch.float32, device=device)
res = torch.bmm(a, b)
```

Note that the current change only works as long as each individual output
matrix in the bmm stays below 2**32 elements; that is what lets us split the
computation along the batch axis to avoid going over the limit. Added a
TORCH_CHECK that raises an error when the individual matrix dimensions are too
large for this op to handle, until a more general workaround that tiles the
matmuls is available. (A Python sketch of the batch-splitting idea is included
after the diff for reference.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133430
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
---
 .../native/mps/operations/LinearAlgebra.mm | 120 ++++++++++++++++++
 test/test_mps.py                           |  14 ++
 2 files changed, 134 insertions(+)

diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index b1158865c4e68..e40454307ac97 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -5,6 +5,7 @@
 #include 
 #include  // For MTLLanguageVersion_3_1
+#include 
 #include 
 #include 
@@ -509,11 +510,123 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
   return output;
 }
 
+static Tensor& tiled_bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) {
+  if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) {
+    using namespace mps;
+
+    id<MTLBuffer> aBuffer = getMTLBufferStorage(batch1);
+    id<MTLBuffer> bBuffer = getMTLBufferStorage(batch2);
+    id<MTLBuffer> resBuffer = getMTLBufferStorage(result);
+
+    MPSStream* mpsStream = getCurrentMPSStream();
+    id<MTLDevice> device = MPSDevice::getInstance()->device();
+    id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+
+    dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
+      @autoreleasepool {
+        mpsStream->endKernelCoalescing();
+
+        uint64_t originalBatchSize = batch1.sizes().size() > 2 ? batch1.size(0) : 1;
+        uint64_t aRows = batch1.size(-2);
+        uint64_t bRows = batch2.size(-2);
+        uint64_t resRows = result.size(-2);
+        uint64_t aCols = batch1.size(-1);
+        uint64_t bCols = batch2.size(-1);
+        uint64_t resCols = result.size(-1);
+        uint64_t aElemSize = batch1.element_size();
+        uint64_t bElemSize = batch2.element_size();
+        uint64_t resElemSize = result.element_size();
+        MPSDataType dtype = getMPSDataType(batch1);
+
+        uint64_t elemInMatrix = resRows * resCols;
+        uint64_t largestSupportedBatchSize = floor(pow(2, 32) / elemInMatrix);
+        uint64_t batchSize = std::min(largestSupportedBatchSize, originalBatchSize);
+        uint64_t lastBatchSize = originalBatchSize % batchSize;
+
+        id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
+
+        auto matmul = [[MPSNDArrayMatrixMultiplication alloc] initWithDevice:device sourceCount:2];
+
+        MPSShape* aShape = @[ @(batchSize), @(aRows), @(aCols) ];
+        MPSShape* bShape = @[ @(batchSize), @(bRows), @(bCols) ];
+        MPSShape* resShape = @[ @(batchSize), @(resRows), @(resCols) ];
+        auto aDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:aShape];
+        aDesc_.preferPackedRows = true;
+        auto bDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:bShape];
+        bDesc_.preferPackedRows = true;
+
+        auto resDesc_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:resShape];
+        resDesc_.preferPackedRows = true;
+
+        getMPSProfiler().beginProfileKernel(matmul, " tiled_bmm_mps", {batch1, batch2});
+
+        // Descriptors to use for the last batch, if it exists.
+        // .matrices is a readonly property, so we need a separate descriptor.
+        MPSNDArrayDescriptor *aDescLastBatch_, *bDescLastBatch_, *resDescLastBatch_;
+        if (lastBatchSize != 0) {
+          aDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype
+                                                                   shape:@[ @(lastBatchSize), @(aRows), @(aCols) ]];
+          aDescLastBatch_.preferPackedRows = true;
+          bDescLastBatch_ = [MPSNDArrayDescriptor descriptorWithDataType:dtype
+                                                                   shape:@[ @(lastBatchSize), @(bRows), @(bCols) ]];
+          bDescLastBatch_.preferPackedRows = true;
+          resDescLastBatch_ =
+              [MPSNDArrayDescriptor descriptorWithDataType:dtype shape:@[ @(lastBatchSize), @(resRows), @(resCols) ]];
+          resDescLastBatch_.preferPackedRows = true;
+        }
+
+        uint64_t requiredIterations = ceil(float(originalBatchSize) / batchSize);
+        auto aDesc = aDesc_;
+        auto bDesc = bDesc_;
+        auto resDesc = resDesc_;
+        for (const auto i : c10::irange(requiredIterations)) {
+          if (i == requiredIterations - 1 && lastBatchSize != 0) {
+            aDesc = aDescLastBatch_;
+            bDesc = bDescLastBatch_;
+            resDesc = resDescLastBatch_;
+          }
+          const uint64_t aArrayOffset = i * batchSize * aRows * aCols;
+          const uint64_t bArrayOffset = i * batchSize * bRows * bCols;
+          const uint64_t resArrayOffset = i * batchSize * resRows * resCols;
+
+          auto aMatrix = [[[MPSNDArray alloc] initWithBuffer:aBuffer
+                                                      offset:(batch1.storage_offset() + aArrayOffset) * aElemSize
+                                                  descriptor:aDesc] autorelease];
+          auto bMatrix = [[[MPSNDArray alloc] initWithBuffer:bBuffer
+                                                      offset:(batch2.storage_offset() + bArrayOffset) * bElemSize
+                                                  descriptor:bDesc] autorelease];
+          auto resMatrix = [[[MPSNDArray alloc] initWithBuffer:resBuffer
+                                                        offset:(result.storage_offset() + resArrayOffset) * resElemSize
+                                                    descriptor:resDesc] autorelease];
+
+          [matmul encodeToCommandEncoder:computeEncoder
+                           commandBuffer:commandBuffer
+                            sourceArrays:@[ aMatrix, bMatrix ]
+                        destinationArray:resMatrix];
+        }
+      }
+    });
+    return result;
+  } else {
+    TORCH_CHECK(false, "Tiling of batch matmul for larger than 2**32 entries only available from MacOS15 onwards");
+  }
+}
+
 static 
Tensor& bmm_out_mps_impl(const Tensor& batch1, const Tensor& batch2, Tensor& result) { using namespace mps; TORCH_CHECK(supportedFloatingOrComplexType(batch1), "MPS device does not support bmm for non-float inputs"); + // Currently unsupported if the matmul output goes over the 32-bit indexing limit + TORCH_CHECK( + batch1.size(1) * batch2.size(2) <= pow(2, 32), + "Output size of the matrix multiplication is larger than currently supported by the MPS backend: ", + batch1.size(1), + ",", + batch2.size(2), + ", needs to be less than 2**32 elements.", + "File a feature request for this use case against the MPS backend at https://github.com/pytorch/pytorch/issues"); + if (batch1.numel() == 0 || batch2.numel() == 0) { result.zero_(); return result; @@ -543,6 +656,13 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L } } + // Check if we need to split the batch to do the computation + uint64_t resultSize = batch1.size(0) * batch1.size(1) * batch2.size(2); + if (resultSize > pow(2, 32)) { + result = tiled_bmm_out_mps_impl(batch1, batch2, result); + return result; + } + MPSStream* stream = getCurrentMPSStream(); struct CachedGraph : public mps::MPSCachedGraph { diff --git a/test/test_mps.py b/test/test_mps.py index 281c68b42cae1..f7f36e57c82ad 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -1914,6 +1914,20 @@ def test_bmm(self): self.assertEqual(output_cpu, output_mps) self.assertEqual(output_cpu.size(), output_mps.size()) + @xfailIf(product_version < 15.0) + @parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_large_bmm(self, dtype): + batch1 = torch.randn(11, 20064, 128, dtype=dtype, device='mps') + batch2 = torch.randn(11, 128, 20064, dtype=dtype, device='mps') + output_cpu = torch.bmm(batch1.cpu(), batch2.cpu()) + output_mps = torch.bmm(batch1, batch2) + + # Using the low precision comparison for FP16 + tol = 1e-2 if dtype == torch.float16 else None + self.assertEqual(output_cpu, output_mps, atol=tol, rtol=tol) + self.assertEqual(output_cpu.size(), output_mps.size()) + + def test_addr(self): A = torch.ones(5, 10).to("mps") B = torch.ones(5).to("mps")
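
For reference, and not part of the patch: a minimal Python sketch of the batch-splitting idea that tiled_bmm_out_mps_impl implements on the Metal side. The helper name `chunked_bmm` and its `limit` argument are illustrative; the chunk size is computed the same way as `largestSupportedBatchSize` in the kernel above, and the final smaller chunk plays the role of `lastBatchSize`.
```
import torch

def chunked_bmm(a, b, limit=2**32):
    # Illustrative helper (not part of the patch): split a bmm along the
    # batch axis so each partial result stays under `limit` elements,
    # mirroring how tiled_bmm_out_mps_impl chooses its batch chunks.
    rows, cols = a.size(-2), b.size(-1)
    elems_per_matrix = rows * cols
    # Same precondition the TORCH_CHECK in bmm_out_mps_impl enforces.
    assert elems_per_matrix <= limit, "a single output matrix already exceeds the limit"
    max_batch = max(1, limit // elems_per_matrix)  # like largestSupportedBatchSize
    parts = [torch.bmm(a_chunk, b_chunk)
             for a_chunk, b_chunk in zip(a.split(max_batch), b.split(max_batch))]
    return torch.cat(parts)

# Small sanity check with an artificially low limit so the batch really gets
# split (7 = 3 + 3 + 1, i.e. a non-trivial last-chunk case).
a = torch.randn(7, 16, 8)
b = torch.randn(7, 8, 16)
torch.testing.assert_close(chunked_bmm(a, b, limit=3 * 16 * 16), torch.bmm(a, b))
```
Splitting along the batch axis keeps each encoded matmul under the 2**32-element limit while still writing into a single result buffer, which is the same trade-off the kernel makes: more kernel launches in exchange for staying within the per-op indexing limit.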