From 03c744ead1482abde3ee9e70293215c5b557c629 Mon Sep 17 00:00:00 2001
From: Max191 <44243577+Max191@users.noreply.github.com>
Date: Fri, 25 Oct 2024 16:37:27 -0700
Subject: [PATCH] [GPU] Support multiple contraction dims in MmaSchedules
 (#18720)

This adds support for multiple M, N, and K dims in problems when deducing a
GPUMMASchedule. The new heuristic is similar to the old one, but works on
pairs of M and N dims. For example:
```
tensor<M1xM0xK1xK0> * tensor<N1xN0xK1xK0> -> tensor<M1xM0xN1xN0>
```
This will try to distribute the seeded tile counts to `M0` and `N0` (first
attempting to distribute evenly, and then distributing to N followed by M),
and then distribute the residual counts to `M1` and `N1`. The K tile counts
will be partitioned to `K0` first, and then the residual tile counts will be
partitioned to `K1`.

This PR also updates the config selection logic for the TileAndFuse pipeline
to make use of the multiple contraction dimensions in mma schedules.

---------

Signed-off-by: Max Dawkins
---
 .../Codegen/Common/GPU/GPUHeuristics.cpp      | 361 ++++++++++++------
 .../Codegen/Common/GPU/GPUHeuristics.h        |  58 ++-
 .../Dialect/GPU/TargetUtils/ConfigUtils.cpp   |  98 +++--
 .../compiler/Codegen/LLVMGPU/KernelConfig.cpp |  79 ++--
 .../test/ROCDL/config_tile_and_fuse.mlir      |  72 +++-
 .../test/llvmgpu_convolution_to_igemm.mlir    |   2 +-
 .../compiler/Codegen/SPIRV/KernelConfig.cpp   |  24 +-
 .../Preprocessing/Common/PadToIntrinsics.cpp  |  30 +-
 8 files changed, 492 insertions(+), 232 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
index dc3078372f92..790484d2c565 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
@@ -9,6 +9,7 @@ #include
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/Sequence.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
@@ -20,51 +21,106 @@ using llvm::APIntOps::GreatestCommonDivisor;
 
 namespace mlir::iree_compiler {
 
+template <typename T>
 static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
-                                     const GPUMMASchedule &schedule) {
-  os << "mSize: " << schedule.mSize << ", ";
-  os << "nSize: " << schedule.nSize << ", ";
-  os << "kSize: " << schedule.kSize << ", ";
-  os << "mTileCount: " << schedule.mTileCount << ", ";
-  os << "nTileCount: " << schedule.nTileCount << ", ";
-  os << "kTileCount: " << schedule.kTileCount << ", ";
-  os << "mWarpCount: " << schedule.mWarpCount << ", ";
-  os << "nWarpCount: " << schedule.nWarpCount;
+                                     const llvm::SmallVectorImpl<T> &vector) {
+  os << "[";
+  llvm::interleaveComma(vector, os);
+  os << "]";
   return os;
 }
 
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const GPUMMASchedule &schedule) {
+  os << "mSizes: " << schedule.mSize << ", ";
+  os << "nSizes: " << schedule.nSize << ", ";
+  os << "kSizes: " << schedule.kSize << ", ";
+  os << "mTileSizes: " << schedule.mTileSizes << ", ";
+  os << "nTileSizes: " << schedule.nTileSizes << ", ";
+  os << "kTileSizes: " << schedule.kTileSizes << ", ";
+  os << "mSubgroupCounts: " << schedule.mSubgroupCounts << ", ";
+  os << "nSubgroupCounts: " << schedule.nSubgroupCounts;
+  return os;
+}
+
+// Shortened helper to compute the product of `values`.
+static int64_t prod(ArrayRef<int64_t> values) {
+  return ShapedType::getNumElements(values);
+}
+
 static int64_t calculateSharedMemoryUsedInBytes(const GPUMMASchedule &schedule,
                                                 int64_t lhsBitwidth,
                                                 int64_t rhsBitwidth) {
-  int64_t tileM = schedule.mSize * schedule.mTileCount * schedule.mWarpCount;
-  int64_t tileN = schedule.nSize * schedule.nTileCount * schedule.nWarpCount;
-  int64_t tileK = schedule.kSize * schedule.kTileCount;
+
+  int64_t tileM = schedule.mSize * prod(schedule.mTileSizes) *
+                  prod(schedule.mSubgroupCounts);
+  int64_t tileN = schedule.nSize * prod(schedule.nTileSizes) *
+                  prod(schedule.nSubgroupCounts);
+  int64_t tileK = schedule.kSize * prod(schedule.kTileSizes);
   return (tileM * tileK * lhsBitwidth + tileN * tileK * rhsBitwidth) / 8;
 }
 
+/// Check that a GPUMMASchedule fits alignment restrictions. To be aligned,
+/// the problem must be evenly divisible by the number of elements in the
+/// schedule for each dimension. If `mustBeAligned` is false, then the
+/// innermost problem dimension is allowed to be unaligned.
 static bool isScheduleAligned(const GPUMatmulShapeType &problem,
                               const GPUMMASchedule &schedule,
                               bool mustBeAligned) {
-  auto alignedMSize =
-      mustBeAligned
-          ? problem.mSize
-          : llvm::divideCeil(problem.mSize, schedule.mSize) * schedule.mSize;
-  auto alignedNSize =
-      mustBeAligned
-          ? problem.nSize
-          : llvm::divideCeil(problem.nSize, schedule.nSize) * schedule.nSize;
-  auto alignedKSize =
-      mustBeAligned
-          ? problem.kSize
-          : llvm::divideCeil(problem.kSize, schedule.kSize) * schedule.kSize;
-  bool isValidM = (alignedMSize % (schedule.mSize * schedule.mTileCount *
-                                   schedule.mWarpCount)) == 0;
-  bool isValidN = (alignedNSize % (schedule.nSize * schedule.nTileCount *
-                                   schedule.nWarpCount)) == 0;
-  bool isValidK = (alignedKSize % (schedule.kSize * schedule.kTileCount)) == 0;
+  SmallVector<int64_t> alignedMSizes(problem.mSizes);
+  alignedMSizes.back() =
+      mustBeAligned ? problem.mSizes.back()
+                    : llvm::divideCeil(problem.mSizes.back(), schedule.mSize) *
+                          schedule.mSize;
+  SmallVector<int64_t> alignedNSizes(problem.nSizes);
+  alignedNSizes.back() =
+      mustBeAligned ? problem.nSizes.back()
+                    : llvm::divideCeil(problem.nSizes.back(), schedule.nSize) *
+                          schedule.nSize;
+  SmallVector<int64_t> alignedKSizes(problem.kSizes);
+  alignedKSizes.back() =
+      mustBeAligned ? problem.kSizes.back()
+                    : llvm::divideCeil(problem.kSizes.back(), schedule.kSize) *
+                          schedule.kSize;
+  // Returns the number of elements in the schedule for each dimension.
+  auto getScheduleSizes =
+      [&](int64_t size, SmallVector<int64_t> tileCount,
+          std::optional<SmallVector<int64_t>> subgroupCount) {
+        SmallVector<int64_t> sizes = llvm::map_to_vector(
+            llvm::seq<int64_t>(tileCount.size()), [&](int64_t i) {
+              return subgroupCount ? tileCount[i] * subgroupCount.value()[i]
+                                   : tileCount[i];
+            });
+        sizes.back() *= size;
+        return sizes;
+      };
+  // Checks whether the elements of `a` are evenly divisible by the
+  // corresponding elements of `b`.
+ auto areAligned = [](SmallVector a, SmallVector b) { + for (auto [aVal, bVal] : llvm::zip_equal(a, b)) { + if (aVal % bVal != 0) { + return false; + } + } + return true; + }; + bool isValidM = areAligned( + alignedMSizes, getScheduleSizes(schedule.mSize, schedule.mTileSizes, + schedule.mSubgroupCounts)); + bool isValidN = areAligned( + alignedNSizes, getScheduleSizes(schedule.nSize, schedule.nTileSizes, + schedule.nSubgroupCounts)); + bool isValidK = areAligned( + alignedKSizes, + getScheduleSizes(schedule.kSize, schedule.kTileSizes, std::nullopt)); return isValidM && isValidN && isValidK; } +/// Returns whether or not a GPUMMASchedule is valid for the given problem. +/// This checks that: +/// - The problem is aligned to the schedule +/// - the number of threads in the schedule workgroup can be distributed +/// to a corresponding vector.transfer read in VectorDistribute. static bool isValidMMASchedule(const GPUMatmulShapeType &problem, const GPUMMASchedule &schedule, bool mustBeAligned, int64_t subgroupSize, @@ -76,11 +132,13 @@ static bool isValidMMASchedule(const GPUMatmulShapeType &problem, const int64_t kMaxVectorLoadBitWidth = 128; int64_t elemsPerThread = kMaxVectorLoadBitWidth / problem.bType.getIntOrFloatBitWidth(); - int64_t wgThreads = schedule.mWarpCount * schedule.nWarpCount * subgroupSize; - - int64_t mWgSize = schedule.mSize * schedule.mTileCount * schedule.mWarpCount; - int64_t nWgSize = schedule.nSize * schedule.nTileCount * schedule.nWarpCount; - int64_t kWgSize = schedule.kSize * schedule.kTileCount; + int64_t wgThreads = subgroupSize * prod(schedule.mSubgroupCounts) * + prod(schedule.nSubgroupCounts); + int64_t mWgSize = schedule.mSize * prod(schedule.mTileSizes) * + prod(schedule.mSubgroupCounts); + int64_t nWgSize = schedule.nSize * prod(schedule.nTileSizes) * + prod(schedule.nSubgroupCounts); + int64_t kWgSize = schedule.kSize * prod(schedule.kTileSizes); int64_t innerLhsDimSize = transposedLhs ? mWgSize : kWgSize; int64_t innerRhsDimSize = transposedRhs ? kWgSize : nWgSize; @@ -94,6 +152,10 @@ static bool isValidMMASchedule(const GPUMatmulShapeType &problem, return isAligned && isDistributableLhs && isDistributableRhs; } +/// Tries to fit the schedule into shared memory by decrementing the size of the +/// schedule dimensions from outermost to innermost until a valid schedule is +/// found. The schedule sizes are reduced in the order of mTileSizes, +/// nTileSizes, kTileSizes, mSubgroupCounts, nSubgroupCounts. static FailureOr fitScheduleInSharedMemory( GPUMatmulShapeType intrinsic, GPUMMASchedule schedule, llvm::function_ref isScheduleValid) { @@ -105,31 +167,35 @@ static FailureOr fitScheduleInSharedMemory( llvm::dbgs() << "Shrinking schedule...\n"; }); - auto decrementIfPossible = [](int64_t &c) -> LogicalResult { - if (c <= 1) { - return failure(); + auto decrementIfPossible = + [](SmallVector &sizes) -> LogicalResult { + for (int64_t &size : sizes) { + if (size <= 1) + continue; + --size; + return success(); } - --c; - return success(); + return failure(); }; // Attempt to shrink the schedule along one of the dimensions. // TODO: A better solution should probably factor problem.mSize / - // (mWarpCount * mTileCount * mSize) and then pop off the smallest factors - // one at a time, preferably trying to keep the tile "generally square." - if (succeeded(decrementIfPossible(schedule.mTileCount))) { + // (mSubgroupCount * mTileCount * mSize) and then pop off the smallest + // factors one at a time, preferably trying to keep the tile "generally + // square." 
+ if (succeeded(decrementIfPossible(schedule.mTileSizes))) { continue; } - if (succeeded(decrementIfPossible(schedule.nTileCount))) { + if (succeeded(decrementIfPossible(schedule.nTileSizes))) { continue; } - if (succeeded(decrementIfPossible(schedule.kTileCount))) { + if (succeeded(decrementIfPossible(schedule.kTileSizes))) { continue; } - if (succeeded(decrementIfPossible(schedule.mWarpCount))) { + if (succeeded(decrementIfPossible(schedule.mSubgroupCounts))) { continue; } - if (succeeded(decrementIfPossible(schedule.nWarpCount))) { + if (succeeded(decrementIfPossible(schedule.nSubgroupCounts))) { continue; } @@ -148,6 +214,9 @@ static FailureOr fitScheduleInSharedMemory( static LogicalResult canTargetIntrinsic(const GPUMatmulShapeType &problem, const GPUMatmulShapeType &intrinsic, bool canUpcastAcc, bool mustBeAligned) { + assert(intrinsic.mSizes.size() == 1 && intrinsic.nSizes.size() == 1 && + intrinsic.kSizes.size() == 1 && + "expected intrinsic to have a single M, N, and K dimension."); if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType) { return failure(); // Cannot use this intrinsic for mismatched types } @@ -161,17 +230,17 @@ static LogicalResult canTargetIntrinsic(const GPUMatmulShapeType &problem, } } - if (mustBeAligned && (problem.mSize % intrinsic.mSize != 0 || - problem.nSize % intrinsic.nSize != 0 || - problem.kSize % intrinsic.kSize != 0)) { + if (mustBeAligned && (problem.mSizes.back() % intrinsic.mSizes[0] != 0 || + problem.nSizes.back() % intrinsic.nSizes[0] != 0 || + problem.kSizes.back() % intrinsic.kSizes[0] != 0)) { return failure(); // Cannot use this intrinsic for misaligned cases. } // Cannot use the intrinsic when the tile size is greater than problem size. // Because tiling is a no-op, and we can't infer tiling sizes from IR. - if (!mustBeAligned && - (problem.mSize < intrinsic.mSize || problem.nSize < intrinsic.nSize || - problem.kSize < intrinsic.kSize)) { + if (!mustBeAligned && (problem.mSizes.back() < intrinsic.mSizes[0] || + problem.nSizes.back() < intrinsic.nSizes[0] || + problem.kSizes.back() < intrinsic.kSizes[0])) { return failure(); } @@ -185,77 +254,123 @@ static GPUMMASchedule getOptimalMMASchedule(const GPUMatmulShapeType &problem, const GPUMatmulShapeType &intrinsic, const GPUMMAHeuristicSeeds &seeds, uint64_t intrinsicIndex) { - int64_t mTotalTileCount = llvm::divideCeil(problem.mSize, intrinsic.mSize); - int64_t nTotalTileCount = llvm::divideCeil(problem.nSize, intrinsic.nSize); - - int64_t remainingWarps = seeds.bestSubgroupCountPerWorkgroup; + assert(intrinsic.mSizes.size() == 1 && intrinsic.nSizes.size() == 1 && + intrinsic.kSizes.size() == 1 && + "expected intrinsic to have a single M, N, and K dimension."); + // mTotalTileCounts and nTotalTileCounts represent the total number of + // intrinsics along the M or N dimensions needed to fill the problem size. 
+  // For example, if the problem is {M:[4, 16], N:[2, 32], K:[3, 128]} for a
+  // 16x16x16 intrinsic, then:
+  //  - mTotalTileCounts would be 4 * (16/16) = 4
+  //  - nTotalTileCounts would be 2 * (32/16) = 4
+  SmallVector<int64_t> mTotalTileCounts = problem.mSizes;
+  SmallVector<int64_t> nTotalTileCounts = problem.nSizes;
+  mTotalTileCounts.back() =
+      llvm::divideCeil(problem.mSizes.back(), intrinsic.mSizes[0]);
+  nTotalTileCounts.back() =
+      llvm::divideCeil(problem.nSizes.back(), intrinsic.nSizes[0]);
+
+  int64_t remainingSubgroups = seeds.bestSubgroupCountPerWorkgroup;
   int64_t remainingTiles = seeds.bestMNTileCountPerSubgroup;
-  // Assign more warps to the M dimension (used later) to balance thread
+  // Assign more subgroups to the M dimension (used later) to balance thread
   // counts along X and Y dimensions.
-  int64_t warpSqrt =
-      1ull << (llvm::divideCeil(llvm::Log2_64(remainingWarps), 2));
-  int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);
-
-  int64_t mWarpCount = 0, nWarpCount = 0;
-  int64_t mTileCount = 0, nTileCount = 0;
-
-  // See if the square root can divide mTotalTileCount. If so it means we can
-  // distribute to both dimensions evenly. Otherwise, try to distribute to N
-  // and then M.
-  if (mTotalTileCount > (warpSqrt * tileSqrt) &&
-      mTotalTileCount % (warpSqrt * tileSqrt) == 0) {
-    mWarpCount = warpSqrt;
-    mTileCount = tileSqrt;
-
-    remainingWarps /= warpSqrt;
-    remainingTiles /= tileSqrt;
-
-    APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
-                                       APInt(64, remainingWarps));
-    nWarpCount = nGCD.getSExtValue();
-    nTotalTileCount /= nWarpCount;
-    remainingWarps /= nWarpCount;
-
-    nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
-                                 APInt(64, remainingTiles));
-    nTileCount = nGCD.getSExtValue();
-  } else {
-    APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
-                                       APInt(64, remainingWarps));
-    nWarpCount = nGCD.getSExtValue();
-    nTotalTileCount /= nWarpCount;
-    remainingWarps /= nWarpCount;
-
-    nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
-                                 APInt(64, remainingTiles));
-    nTileCount = nGCD.getSExtValue();
-    remainingTiles /= nTileCount;
-
-    APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount),
-                                       APInt(64, remainingWarps));
-    mWarpCount = mGCD.getSExtValue();
-    mTotalTileCount /= mWarpCount;
-    remainingWarps /= mWarpCount;
-
-    mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount),
-                                 APInt(64, remainingTiles));
-    mTileCount = mGCD.getSExtValue();
+  int mDim = problem.mSizes.size() - 1;
+  int nDim = problem.nSizes.size() - 1;
+  SmallVector<int64_t> mTileSizes(problem.mSizes.size(), 0),
+      nTileSizes(problem.nSizes.size(), 0),
+      mSubgroupCounts(problem.mSizes.size(), 0),
+      nSubgroupCounts(problem.nSizes.size(), 0);
+  // Start at the innermost nDim and mDim, and try to distribute evenly to M and
+  // N for each pair of M and N dims. Otherwise, distribute to N and then M.
+  while (mDim >= 0 || nDim >= 0) {
+    int64_t subgroupSqrt =
+        1ull << (llvm::divideCeil(llvm::Log2_64(remainingSubgroups), 2));
+    int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);
+
+    // See if the square root can divide mTotalTileCount. If so it means we can
+    // distribute to both dimensions evenly to minimize the number of global
+    // loads. Otherwise, try to distribute to N and then M.
+ if (mDim >= 0 && nDim >= 0 && + mTotalTileCounts[mDim] > (subgroupSqrt * tileSqrt) && + mTotalTileCounts[mDim] % (subgroupSqrt * tileSqrt) == 0) { + mSubgroupCounts[mDim] = subgroupSqrt; + mTileSizes[mDim] = tileSqrt; + + remainingSubgroups /= subgroupSqrt; + remainingTiles /= tileSqrt; + + APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]), + APInt(64, remainingSubgroups)); + nSubgroupCounts[nDim] = nGCD.getSExtValue(); + nTotalTileCounts[nDim] /= nSubgroupCounts[nDim]; + remainingSubgroups /= nSubgroupCounts[nDim]; + + nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]), + APInt(64, remainingTiles)); + nTileSizes[nDim] = nGCD.getSExtValue(); + remainingTiles /= nTileSizes[nDim]; + } else { + if (nDim >= 0) { + APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]), + APInt(64, remainingSubgroups)); + nSubgroupCounts[nDim] = nGCD.getSExtValue(); + nTotalTileCounts[nDim] /= nSubgroupCounts[nDim]; + remainingSubgroups /= nSubgroupCounts[nDim]; + + nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCounts[nDim]), + APInt(64, remainingTiles)); + nTileSizes[nDim] = nGCD.getSExtValue(); + remainingTiles /= nTileSizes[nDim]; + } + + if (mDim >= 0) { + APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]), + APInt(64, remainingSubgroups)); + mSubgroupCounts[mDim] = mGCD.getSExtValue(); + mTotalTileCounts[mDim] /= mSubgroupCounts[mDim]; + remainingSubgroups /= mSubgroupCounts[mDim]; + + mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCounts[mDim]), + APInt(64, remainingTiles)); + mTileSizes[mDim] = mGCD.getSExtValue(); + remainingTiles /= mTileSizes[mDim]; + } + } + --mDim; + --nDim; } - const uint64_t kTotalTileCount = - llvm::divideCeil(problem.kSize, intrinsic.kSize); + // kTotalTileCounts is similar to m/nTotalTileCounts, representing the total + // number of intrinsics along the K dimensions needed to fill the problem. + // For the problem described above {M:[4, 16], N:[2, 32], K[3, 128]} with a + // 16x16x16 intrinsic, then: + // - kTotalTileCounts would be 3 * (128/16) = 24 + SmallVector kTotalTileCounts = problem.kSizes; + kTotalTileCounts.back() = + llvm::divideCeil(problem.kSizes.back(), intrinsic.kSizes[0]); + // Compute the ideal number of intrinsics along K per subgroup based on the + // seed. int64_t bestKTileCountPerSubgroup = seeds.bestKElementCountPerSubgroup ? llvm::divideCeil(seeds.bestKElementCountPerSubgroup, - intrinsic.kSize) + intrinsic.kSizes[0]) : seeds.bestKTileCountPerSubgroup; - APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCount), - APInt(64, bestKTileCountPerSubgroup)); - int64_t kTileCount = kGCD.getSExtValue(); + SmallVector kTileSizes(problem.kSizes.size(), 0); + // Start at the innermost K dim, and tile each dim to try to satisfy the ideal + // K intrinsic count per subgroup with the overall product of K tile counts. 
+ int kDim = problem.kSizes.size() - 1; + while (kDim >= 0) { + APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCounts[kDim]), + APInt(64, bestKTileCountPerSubgroup)); + kTileSizes[kDim] = kGCD.getSExtValue(); + bestKTileCountPerSubgroup /= kTileSizes[kDim]; + --kDim; + } - return GPUMMASchedule{intrinsicIndex, intrinsic.mSize, intrinsic.nSize, - intrinsic.kSize, mWarpCount, nWarpCount, - mTileCount, nTileCount, kTileCount}; + return GPUMMASchedule{ + intrinsicIndex, intrinsic.mSizes[0], intrinsic.nSizes[0], + intrinsic.kSizes[0], mSubgroupCounts, nSubgroupCounts, + mTileSizes, nTileSizes, kTileSizes}; } FailureOr deduceMMASchedule( @@ -297,7 +412,6 @@ FailureOr deduceMMASchedule( return isAligned && sharedMemoryUsed <= sharedMemLimitInBytes; }; - return fitScheduleInSharedMemory(intrinsic, schedule, isValidSchedule); } return failure(); @@ -309,7 +423,10 @@ FailureOr deduceAttentionSchedule( const GPUMMAHeuristicSeeds &pvMatmulSeeds, int64_t sharedMemLimitInBytes, int64_t subgroupSize, bool transposedQ, bool transposedK, bool transposedV, bool canUpcastAcc, bool mustBeAligned) { - + assert(pvMatmul.mSizes.size() == 1 && pvMatmul.nSizes.size() == 1 && + pvMatmul.kSizes.size() == 1 && qkMatmul.mSizes.size() == 1 && + qkMatmul.nSizes.size() == 1 && qkMatmul.kSizes.size() == 1 && + "unimplemented: multi M/N/K attention schedule"); for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) { if (failed(canTargetIntrinsic(qkMatmul, intrinsic, canUpcastAcc, mustBeAligned))) { @@ -329,7 +446,7 @@ FailureOr deduceAttentionSchedule( llvm::dbgs() << " " << schedule << "\n"; }); - int64_t intrinsicK = intrinsic.kSize; + int64_t intrinsicK = intrinsic.kSizes[0]; auto isValidSchedule = [&](const GPUMMASchedule &schedule) -> bool { // Create a mma schedule for qkMatmul in attention. // qkMatmul.M = pvMatmul.M @@ -339,11 +456,11 @@ FailureOr deduceAttentionSchedule( schedule.mSize, schedule.kSize, intrinsicK, - /*mWarpCount=*/schedule.mWarpCount, - /*nWarpCount=*/1, - schedule.mTileCount, - schedule.kTileCount, - qkMatmul.kSize / intrinsicK}; + /*mSubgroupCount=*/schedule.mSubgroupCounts[0], + /*nSubgroupCount=*/1, + schedule.mTileSizes[0], + schedule.kTileSizes[0], + qkMatmul.kSizes[0] / intrinsicK}; bool isQKAligned = isValidMMASchedule(qkMatmul, qkSchedule, mustBeAligned, subgroupSize, diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h index 8211443a2e12..13f6a56c1b6f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h @@ -10,15 +10,18 @@ namespace mlir::iree_compiler { /// Struct containing information about a matmul's shape and type. struct GPUMatmulShapeType { - int64_t mSize; - int64_t nSize; - int64_t kSize; + SmallVector mSizes; + SmallVector nSizes; + SmallVector kSizes; Type aType; Type bType; Type cType; GPUMatmulShapeType(int64_t m, int64_t n, int64_t k, Type a, Type b, Type c) - : mSize(m), nSize(n), kSize(k), aType(a), bType(b), cType(c) {} + : mSizes({m}), nSizes({n}), kSizes({k}), aType(a), bType(b), cType(c) {} + GPUMatmulShapeType(SmallVector m, SmallVector n, + SmallVector k, Type a, Type b, Type c) + : mSizes(m), nSizes(n), kSizes(k), aType(a), bType(b), cType(c) {} }; /// Struct containing seed tile sizes for GPU MMA heuristics deduction logic. 
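
For context, a minimal usage sketch of the two `GPUMatmulShapeType` constructors added in the hunk above. The concrete sizes mirror the `{M:[4, 16], N:[2, 32], K:[3, 128]}` example from the GPUHeuristics.cpp comments; the MLIR builder and the f16/f32 element types are assumed placeholders, not taken from the patch:

```cpp
#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;
using namespace mlir::iree_compiler;

void buildExampleProblems() {
  MLIRContext context;
  Builder builder(&context);
  Type f16 = builder.getF16Type();
  Type f32 = builder.getF32Type();

  // Single M/N/K dim problem; delegates to the {m}, {n}, {k} vector form.
  GPUMatmulShapeType flatProblem(/*m=*/1024, /*n=*/1024, /*k=*/512,
                                 /*a=*/f16, /*b=*/f16, /*c=*/f32);

  // Multi M/N/K dim problem. The innermost dimension of each list is the
  // last element, matching how the heuristic divides only the back() size
  // by the intrinsic shape.
  GPUMatmulShapeType multiProblem(/*m=*/{4, 16}, /*n=*/{2, 32},
                                  /*k=*/{3, 128},
                                  /*a=*/f16, /*b=*/f16, /*c=*/f32);
  (void)flatProblem;
  (void)multiProblem;
}
```
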
@@ -38,14 +41,42 @@ struct GPUMMAHeuristicSeeds { struct GPUMMASchedule { // Index of the chosen intrinsic into the list of given MMA intrinsics uint64_t index; - int64_t mSize; // Native MMA size along M dimension - int64_t nSize; // Native MMA size along N dimension - int64_t kSize; // Native MMA size along K dimension - int64_t mWarpCount; // Number of subgroups along M dimension - int64_t nWarpCount; // Number of subgroups along N dimension - int64_t mTileCount; // Number of tiles per subgroup along M dimension - int64_t nTileCount; // Number of tiles per subgroup along N dimension - int64_t kTileCount; // Number of tiles along K dimension + int64_t mSize; // Native MMA intrinsic size along M dimension for a subgroup. + int64_t nSize; // Native MMA intrinsic size along N dimension for a subgroup. + int64_t kSize; // Native MMA intrinsic size along K dimension for a subgroup. + + // Number of subgroups along each M and N dimension. + SmallVector mSubgroupCounts; + SmallVector nSubgroupCounts; + + // Tile sizes for each M, N, and K dimension. When there are multiple M, N, + // or K dimensions, the intrinsic sizes are targeted to the innermost + // dimension, and the outer dimensions can be thought of as unrolling factors + // along M, N, or K. + SmallVector mTileSizes; // M tile sizes per subgroup. + SmallVector nTileSizes; // N tile sizes per subgroup. + SmallVector kTileSizes; // K tile sizes. + + // Constructor for multi M, N, K dim schedules. + GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize, + int64_t kIntrinsicSize, SmallVector mSubgroupCounts, + SmallVector nSubgroupCounts, + SmallVector mTileSizes, + SmallVector nTileSizes, + SmallVector kTileSizes) + : index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize), + kSize(kIntrinsicSize), mSubgroupCounts(mSubgroupCounts), + nSubgroupCounts(nSubgroupCounts), mTileSizes(mTileSizes), + nTileSizes(nTileSizes), kTileSizes(kTileSizes) {} + + // Constructor for single M, N, K dim schedules. 
+ GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize, + int64_t kIntrinsicSize, int64_t mSubgroup, int64_t nSubgroup, + int64_t mTileSize, int64_t nTileSize, int64_t kTileSize) + : index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize), + kSize(kIntrinsicSize), mSubgroupCounts({mSubgroup}), + nSubgroupCounts({nSubgroup}), mTileSizes({mTileSize}), + nTileSizes({nTileSize}), kTileSizes({kTileSize}) {} }; /// Returns a schedule for using one of the given MMA |intrinsics| to target the @@ -69,4 +100,7 @@ FailureOr deduceAttentionSchedule( bool transposedV = false, bool canUpcastAcc = false, bool mustBeAligned = true); +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const GPUMMASchedule &schedule); + } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index ca23b0ca6e06..58bfdc0a028b 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -13,6 +13,7 @@ #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h" #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h" #include "iree/compiler/Codegen/Utils/Utils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" @@ -124,20 +125,37 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, return failure(); } - // For now we are not being smart and trying to reshape dimensions to allow - // for better usage of intrinsics, and instead are tiling all dimensions - // except the inner most m, n, and k dimensions to 1. - int64_t mDim = contractionDims.m.back(); - int64_t nDim = contractionDims.n.back(); - int64_t kDim = contractionDims.k.back(); - - // Dynamic dims are expected to be taken care of earlier in the pipeline. - if (ShapedType::isDynamic(bounds[mDim]) || - ShapedType::isDynamic(bounds[nDim]) || - ShapedType::isDynamic(bounds[kDim])) { + // TODO(Max191): add dynamic shape support for inner most dims. + if (ShapedType::isDynamic(bounds[contractionDims.m.back()]) || + ShapedType::isDynamic(bounds[contractionDims.n.back()]) || + ShapedType::isDynamic(bounds[contractionDims.k.back()])) { return failure(); } + // Gather all static M, N, and K dimensions to deduce the MMASchedule. Dynamic + // dimensions will be tiled to 1 in workgroup tiling, so they are ignored when + // computing an MMA schedule. 
+ SmallVector mDims, nDims, kDims; + for (auto mDim : contractionDims.m) { + if (!ShapedType::isDynamic(bounds[mDim])) { + mDims.push_back(mDim); + } + } + for (auto nDim : contractionDims.n) { + if (!ShapedType::isDynamic(bounds[nDim])) { + nDims.push_back(nDim); + } + } + for (auto kDim : contractionDims.k) { + if (!ShapedType::isDynamic(bounds[kDim])) { + kDims.push_back(kDim); + } + } + + auto getDimBounds = [&](SmallVector dims) -> SmallVector { + return llvm::map_to_vector(dims, [&](int64_t dim) { return bounds[dim]; }); + }; + Value lhs = linalgOp.getDpsInputOperand(0)->get(); Value rhs = linalgOp.getDpsInputOperand(1)->get(); Value init = linalgOp.getDpsInitOperand(0)->get(); @@ -146,8 +164,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, Type rhsElemType = getElementTypeOrSelf(rhs); Type initElemType = getElementTypeOrSelf(init); - GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim], - lhsElemType, rhsElemType, initElemType}; + GPUMatmulShapeType problem{getDimBounds(mDims), getDimBounds(nDims), + getDimBounds(kDims), lhsElemType, + rhsElemType, initElemType}; SmallVector intrinsics; for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) { @@ -166,7 +185,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, // Note that the following heuristic seeds are just placeholder values. // We need to clean it up and make it adjusting to different targets. // See https://github.com/iree-org/iree/issues/16341 for details. - if (problem.mSize * problem.nSize <= 512 * 512) { + int64_t mSize = ShapedType::getNumElements(problem.mSizes); + int64_t nSize = ShapedType::getNumElements(problem.nSizes); + if (mSize * nSize <= 512 * 512) { // For matmuls with small M*N size, we want to distribute M*N onto more // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup // and a larger bestKTileCountPerSubgroup. @@ -190,10 +211,10 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, // TODO: Drop this. This is only a consideration for other pipelines. SmallVector maps = linalgOp.getIndexingMapsArray(); bool transposedLhs = - kDim != + kDims.back() != llvm::cast(maps[0].getResults().back()).getPosition(); bool transposedRhs = - nDim != + nDims.back() != llvm::cast(maps[1].getResults().back()).getPosition(); // First try to find a schedule with an exactly matching intrinsic. @@ -213,16 +234,13 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, } LDBG("Target Subgroup size: " << targetSubgroupSize); - LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", " - << schedule->kSize << "]"); - LDBG("Schedule: tile counts [" << schedule->mTileCount << ", " - << schedule->nTileCount << ", " - << schedule->kTileCount << "]"); - LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", " - << schedule->nWarpCount << "]"); + LDBG("Schedule: " << schedule); - std::array workgroupSize{ - schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1}; + int64_t flatWorkgroupSize = + targetSubgroupSize * + ShapedType::getNumElements(schedule->nSubgroupCounts) * + ShapedType::getNumElements(schedule->mSubgroupCounts); + std::array workgroupSize{flatWorkgroupSize, 1, 1}; SmallVector workgroupTileSizes(linalgOp.getNumLoops(), 0); SmallVector reductionTileSizes(linalgOp.getNumLoops(), 0); @@ -244,16 +262,30 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target, reductionTileSizes[k] = 1; } - // Compute the M/N dimension tile size by multiplying subgroup information. 
- workgroupTileSizes[mDim] = schedule->mWarpCount * schedule->mTileCount; - workgroupTileSizes[nDim] = schedule->nWarpCount * schedule->nTileCount; - - // Specify the subgroup tile sizes from the mma schedule. This is applied - subgroupTileSizes[mDim] = schedule->mTileCount; - subgroupTileSizes[nDim] = schedule->nTileCount; + // Adjust the inner bound size for packing to intrinsic shapes, since tiling + // happens after packing. + assert(bounds[mDims.back()] % schedule->mSize == 0 && + bounds[nDims.back()] % schedule->nSize == 0 && + "expected inner bound to be evenly divisible by schedule sizes."); + bounds[mDims.back()] /= schedule->mSize; + bounds[nDims.back()] /= schedule->nSize; + + // Compute the M/N dimension tile sizes by multiplying subgroup information. + for (auto [i, mDim] : llvm::enumerate(mDims)) { + workgroupTileSizes[mDim] = + schedule->mSubgroupCounts[i] * schedule->mTileSizes[i]; + subgroupTileSizes[mDim] = schedule->mTileSizes[i]; + } + for (auto [i, nDim] : llvm::enumerate(nDims)) { + workgroupTileSizes[nDim] = + schedule->nSubgroupCounts[i] * schedule->nTileSizes[i]; + subgroupTileSizes[nDim] = schedule->nTileSizes[i]; + } // Similarly the reduction tile size is just the post-packing tile count. - reductionTileSizes[kDim] = schedule->kTileCount; + for (auto [i, kDim] : llvm::enumerate(kDims)) { + reductionTileSizes[kDim] = schedule->kTileSizes[i]; + } IREE::GPU::MmaInterfaceAttr mmaKind = target.getWgp().getMma()[schedule->index]; diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp index ff002ace5b0f..4b64cda3adc9 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp @@ -301,6 +301,11 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, Type rhsElemType = getElementTypeOrSelf(rhs); Type initElemType = getElementTypeOrSelf(init); + // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules + // once the pipeline is able to support it. After adding multiple dimensions, + // all instances of schedule->m/nSubgroupCounts[0] and + // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of + // just the first element. GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim], lhsElemType, rhsElemType, initElemType}; @@ -339,8 +344,9 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, return failure(); } - std::array workgroupSize{ - schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1}; + std::array workgroupSize{schedule->nSubgroupCounts[0] * + targetSubgroupSize, + schedule->mSubgroupCounts[0], 1}; SmallVector workgroupTileSizes(op.getNumLoops(), 0); SmallVector reductionTileSizes(op.getNumLoops(), 0); @@ -360,11 +366,11 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, } // Compute the M/N dimension tile size by multiply subgroup information. workgroupTileSizes[mDim] = - schedule->mWarpCount * schedule->mTileCount * schedule->mSize; + schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize; workgroupTileSizes[nDim] = - schedule->nWarpCount * schedule->nTileCount * schedule->nSize; + schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize; - reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize; + reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSize; // Tile all filter loop dimensions to 1. 
for (int64_t filterDim : convolutionDims->filterLoop) { @@ -386,8 +392,8 @@ setConvolutionVectorDistributionConfig(IREE::GPU::TargetAttr target, // for later access in the pipeline. SmallVector pipelineAttrs; auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get( - context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount, - schedule->nWarpCount); + context, target.getWgp().getMma()[schedule->index], + schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]); pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr); @@ -489,6 +495,11 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, rhsElemType = getElementTypeOrSelf(rhsOp.getDpsInputs()[0]); } + // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules + // once the pipeline is able to support it. After adding multiple dimensions, + // all instances of schedule->m/nSubgroupCounts[0] and + // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of + // just the first element. GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim], lhsElemType, rhsElemType, initElemType}; @@ -509,7 +520,7 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // Note that the following heuristic seeds are just placeholder values. // We need to clean it up and make it adjusting to different targets. // See https://github.com/iree-org/iree/issues/16341 for details. - if (problem.mSize * problem.nSize <= clGPUMatmulCThreshold) { + if (problem.mSizes[0] * problem.nSizes[0] <= clGPUMatmulCThreshold) { // For matmuls with small M*N size, we want to distribute M*N onto more // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup // and a larger bestKTileCountPerSubgroup. @@ -573,16 +584,11 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, } LDBG("Target Subgroup size: " << targetSubgroupSize); - LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", " - << schedule->kSize << "]"); - LDBG("Schedule: tile counts [" << schedule->mTileCount << ", " - << schedule->nTileCount << ", " - << schedule->kTileCount << "]"); - LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", " - << schedule->nWarpCount << "]"); + LDBG("Schedule: " << schedule); - std::array workgroupSize{ - schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1}; + std::array workgroupSize{schedule->nSubgroupCounts[0] * + targetSubgroupSize, + schedule->mSubgroupCounts[0], 1}; SmallVector workgroupTileSizes(op.getNumLoops(), 0); SmallVector reductionTileSizes(op.getNumLoops(), 0); @@ -605,11 +611,11 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // Compute the M/N dimension tile size by multiply subgroup information. workgroupTileSizes[mDim] = - schedule->mWarpCount * schedule->mTileCount * schedule->mSize; + schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize; workgroupTileSizes[nDim] = - schedule->nWarpCount * schedule->nTileCount * schedule->nSize; + schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize; - reductionTileSizes[kDim] = schedule->kTileCount * schedule->kSize; + reductionTileSizes[kDim] = schedule->kTileSizes[0] * schedule->kSize; LLVM_DEBUG(debugPrintContractionInfo("Workgroup tile sizes", op.getNumLoops(), *contractionDims, workgroupTileSizes)); @@ -631,8 +637,8 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target, // for later access in the pipeline. 
SmallVector pipelineAttrs; auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get( - context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount, - schedule->nWarpCount); + context, target.getWgp().getMma()[schedule->index], + schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]); pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr); @@ -772,22 +778,17 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, // TODO: Due to a bug in layout configuration, we cannot set warp count on // the N dimension. This is however ok, because we generally do not want to // distribute subgroups on N dimension anyway. - if (schedule->nWarpCount != 1) { - schedule->nTileCount *= schedule->nWarpCount; - schedule->nWarpCount = 1; + if (schedule->nSubgroupCounts[0] != 1) { + schedule->nTileSizes[0] *= schedule->nSubgroupCounts[0]; + schedule->nSubgroupCounts[0] = 1; } LDBG("Target Subgroup size: " << targetSubgroupSize); - LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", " - << schedule->kSize << "]"); - LDBG("Schedule: tile counts [" << schedule->mTileCount << ", " - << schedule->nTileCount << ", " - << schedule->kTileCount << "]"); - LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", " - << schedule->nWarpCount << "]"); + LDBG("Schedule: " << schedule); - std::array workgroupSize{ - schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1}; + std::array workgroupSize{schedule->nSubgroupCounts[0] * + targetSubgroupSize, + schedule->mSubgroupCounts[0], 1}; SmallVector workgroupTileSizes(opInfo.getDomainRank(), 0); SmallVector reductionTileSizes(op.getNumLoops(), 0); @@ -811,11 +812,11 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, // Compute the M/N dimension tile size by multiply subgroup information. workgroupTileSizes[mDim] = - schedule->mWarpCount * schedule->mTileCount * schedule->mSize; + schedule->mSubgroupCounts[0] * schedule->mTileSizes[0] * schedule->mSize; workgroupTileSizes[nDim] = - schedule->nWarpCount * schedule->nTileCount * schedule->nSize; + schedule->nSubgroupCounts[0] * schedule->nTileSizes[0] * schedule->nSize; - reductionTileSizes[k2Dim] = schedule->kTileCount * schedule->kSize; + reductionTileSizes[k2Dim] = schedule->kTileSizes[0] * schedule->kSize; MLIRContext *context = op.getContext(); SmallVector attrs; @@ -831,8 +832,8 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target, // for later access in the pipeline. 
SmallVector pipelineAttrs; auto scheduleAttr = IREE::GPU::MMAScheduleAttr::get( - context, target.getWgp().getMma()[schedule->index], schedule->mWarpCount, - schedule->nWarpCount); + context, target.getWgp().getMma()[schedule->index], + schedule->mSubgroupCounts[0], schedule->nSubgroupCounts[0]); pipelineAttrs.emplace_back(StringAttr::get(context, "mma_schedule"), scheduleAttr); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir index b98e85a79713..819b8826bb1d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir @@ -37,11 +37,79 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor // CHECK-SAME: mma_kind = #iree_gpu.mma_layout // CHECK-SAME: promote_operands = [0, 1] // CHECK-SAME: reduction = [0, 0, 0, 0, 4] -// CHECK-SAME: subgroup = [0, 0, 4, 1, 0] +// CHECK-SAME: subgroup = [1, 1, 4, 1, 0] // CHECK-SAME: workgroup = [1, 1, 4, 4, 0] // ----- +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d4, d5)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> +func.func @multi_dim_mma_schedule(%lhs: tensor<10x32x128x16xf16>, %rhs: tensor<4x32x128x16xf16>) -> tensor<10x4x32x32xf16> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %5 = tensor.empty() : tensor<10x4x32x32xf16> + %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<10x4x32x32xf16>) -> tensor<10x4x32x32xf16> + %7 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} + ins(%lhs, %rhs : tensor<10x32x128x16xf16>, tensor<4x32x128x16xf16>) outs(%6 : tensor<10x4x32x32xf16>) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %8 = arith.mulf %in, %in_0 : f16 + %9 = arith.addf %8, %out : f16 + linalg.yield %9 : f16 + } -> tensor<10x4x32x32xf16> + return %7 : tensor<10x4x32x32xf16> +} + +// CHECK-LABEL: func.func @multi_dim_mma_schedule +// CHECK-SAME: #iree_codegen.translation_info + +// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config +// CHECK-SAME: mma_kind = #iree_gpu.mma_layout +// CHECK-SAME: promote_operands = [0, 1] +// CHECK-SAME: reduction = [0, 0, 0, 0, 4, 1] +// CHECK-SAME: subgroup = [2, 2, 1, 1, 0, 0] +// CHECK-SAME: workgroup = [2, 2, 2, 2, 0, 0] + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d5, d6)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d5, d6)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)> +func.func @dynamic_multi_dim_mma_schedule(%lhs: tensor, %rhs: tensor) -> tensor { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f16 + %d0 = tensor.dim %lhs, %c0 : tensor + %d2 = tensor.dim %rhs, %c0 : tensor + %5 = tensor.empty(%d0, %d2) : tensor + %6 = linalg.fill ins(%cst : f16) outs(%5 : tensor) -> tensor + %7 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} + ins(%lhs, %rhs : tensor, tensor) outs(%6 : tensor) { + ^bb0(%in: f16, %in_0: f16, %out: f16): + %8 = arith.mulf %in, %in_0 : f16 + %9 = arith.addf %8, %out : f16 + linalg.yield %9 : f16 + } -> tensor + return %7 : tensor +} + +// CHECK-LABEL: func.func 
@dynamic_multi_dim_mma_schedule +// CHECK-SAME: #iree_codegen.translation_info + +// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config +// CHECK-SAME: mma_kind = #iree_gpu.mma_layout +// CHECK-SAME: promote_operands = [0, 1] +// CHECK-SAME: reduction = [0, 0, 0, 0, 0, 1, 1] +// CHECK-SAME: subgroup = [0, 1, 0, 1, 1, 0, 0] +// CHECK-SAME: workgroup = [1, 2, 1, 1, 2, 0, 0] + +// ----- + func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<1024x1024xf16>) -> tensor<1024x1024xf32> { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index @@ -52,7 +120,7 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor< } // CHECK-LABEL: func.func @mfma_matmul_1024x1024x1024 -// CHECK-SAME: #iree_codegen.translation_info // Verify that the fill does not have the lowering config propagated to it. diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir index 1fa2bae99a8e..9618281c699e 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/llvmgpu_convolution_to_igemm.mlir @@ -33,4 +33,4 @@ func.func public @set_lowering_config(%arg0: tensor<1x34x34x128xf32>, %arg1: ten // CHECK-SAME: lowering_config = #iree_gpu.lowering_config< // CHECK-SAME: {mma_kind = #iree_gpu.mma_layout, // CHECK-SAME: promote_operands = [0, 1], reduction = [0, 0, 0, 0, 8], -// CHECK-SAME: subgroup = [0, 0, 2, 2, 0], workgroup = [1, 1, 2, 8, 0]}> +// CHECK-SAME: subgroup = [1, 1, 2, 2, 0], workgroup = [1, 1, 2, 8, 0]}> diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp index 16a1acf4316f..bbdec5c83f6d 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp @@ -884,6 +884,11 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, Type lhsElem = getElementType(lhs); Type rhsElem = getElementType(rhs); Type initElem = getElementType(init); + // TODO(Max191): Support multiple M/N/K dimension problems for MMASchedules + // once the pipeline is able to support it. After adding multiple dimensions, + // all instances of schedule->m/nSubgroupCounts[0] and + // schedule->m/n/kTileSizes[0] need to use the full list of sizes instead of + // just the first element. 
GPUMatmulShapeType problem(dimM, dimN, dimK, lhsElem, rhsElem, initElem); SmallVector intrinsics; @@ -921,8 +926,9 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, auto pipeline = CodeGenPipeline::SPIRVCooperativeMatrixVectorize; - std::array workgroupSize{schedule->nWarpCount * subgroupSize, - schedule->mWarpCount, 1}; + std::array workgroupSize{schedule->nSubgroupCounts[0] * + subgroupSize, + schedule->mSubgroupCounts[0], 1}; SmallVector vectorSizes(kIndex + 1, 0); if (isBM) @@ -934,21 +940,23 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, SmallVector subgroupTileSizes(lastParallelDim + 1, 0); if (isBM) subgroupTileSizes[bIndex] = 1; - subgroupTileSizes[mIndex] = schedule->mTileCount * vectorSizes[mIndex]; - subgroupTileSizes[nIndex] = schedule->nTileCount * vectorSizes[nIndex]; + subgroupTileSizes[mIndex] = schedule->mTileSizes[0] * vectorSizes[mIndex]; + subgroupTileSizes[nIndex] = schedule->nTileSizes[0] * vectorSizes[nIndex]; SmallVector workgroupTileSizes(lastParallelDim + 1, 0); if (isBM) workgroupTileSizes[bIndex] = 1; - workgroupTileSizes[mIndex] = schedule->mWarpCount * subgroupTileSizes[mIndex]; - workgroupTileSizes[nIndex] = schedule->nWarpCount * subgroupTileSizes[nIndex]; + workgroupTileSizes[mIndex] = + schedule->mSubgroupCounts[0] * subgroupTileSizes[mIndex]; + workgroupTileSizes[nIndex] = + schedule->nSubgroupCounts[0] * subgroupTileSizes[nIndex]; // Also create one level for reduction. This is needed because of // SPIRVTileAndPromotePass requires it. // TODO(#10499): Consolidate tiling configuration across different pipelines. SmallVector reductionTileSizes; reductionTileSizes.append(kIndex, 0); - reductionTileSizes.push_back(schedule->kTileCount * schedule->kSize); + reductionTileSizes.push_back(schedule->kTileSizes[0] * schedule->kSize); TileSizesListType tileSizes = {workgroupTileSizes, subgroupTileSizes, reductionTileSizes, vectorSizes}; @@ -956,7 +964,7 @@ setCooperativeMatrixConfig(IREE::GPU::TargetAttr target, linalg::LinalgOp op, // Don't do multibuffering if the inner reduction loop is folded out. 
auto pipelineDepth = softwarePipelineDepth; auto storeStage = softwarePipelineStoreStage; - if (schedule->kTileCount <= 1) { + if (schedule->kTileSizes[0] <= 1) { pipelineDepth = 0; storeStage = 0; } diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp index ba415b3fb656..922e50882775 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp @@ -242,16 +242,16 @@ padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, return llvm::divideCeil(value, padTo) * padTo - value; }; - if (mSize % intrinsic.mSize != 0) { - mPadding = getPadding(mSize, intrinsic.mSize); + if (mSize % intrinsic.mSizes[0] != 0) { + mPadding = getPadding(mSize, intrinsic.mSizes[0]); } - if (nSize % intrinsic.nSize != 0) { - nPadding = getPadding(nSize, intrinsic.nSize); + if (nSize % intrinsic.nSizes[0] != 0) { + nPadding = getPadding(nSize, intrinsic.nSizes[0]); } - if (kSize % intrinsic.kSize != 0) { - kPadding = getPadding(kSize, intrinsic.kSize); + if (kSize % intrinsic.kSizes[0] != 0) { + kPadding = getPadding(kSize, intrinsic.kSizes[0]); } if (!mPadding && !nPadding && !kPadding) { @@ -381,7 +381,7 @@ static void padContractionLikeOp( for (GPUMatmulShapeType &intrinsic : intrinsics) { std::optional mPadding, nPadding, kPadding; SmallVector> dimsToExpandCandidate; - if (mSize % intrinsic.mSize != 0 || ShapedType::isDynamic(mSize)) { + if (mSize % intrinsic.mSizes[0] != 0 || ShapedType::isDynamic(mSize)) { OpFoldResult mSizeExpr = rewriter.getIndexAttr(mSize); if (ShapedType::isDynamic(mSize)) { auto mOperandDimPair = getSrcOperandAndDim(mDim); @@ -390,12 +390,12 @@ static void padContractionLikeOp( auto [mOperand, mOperandDim] = mOperandDimPair.value(); mSizeExpr = rewriter.create(loc, mOperand, mOperandDim) .getResult(); - dimsToExpandCandidate.emplace_back(mDim, intrinsic.mSize); + dimsToExpandCandidate.emplace_back(mDim, intrinsic.mSizes[0]); } - mPadding = getPadding(mSizeExpr, intrinsic.mSize); + mPadding = getPadding(mSizeExpr, intrinsic.mSizes[0]); } - if (nSize % intrinsic.nSize != 0 || ShapedType::isDynamic(nSize)) { + if (nSize % intrinsic.nSizes[0] != 0 || ShapedType::isDynamic(nSize)) { OpFoldResult nSizeExpr = rewriter.getIndexAttr(nSize); if (ShapedType::isDynamic(nSize)) { auto nOperandDimPair = getSrcOperandAndDim(nDim); @@ -404,12 +404,12 @@ static void padContractionLikeOp( auto [nOperand, nOperandDim] = nOperandDimPair.value(); nSizeExpr = rewriter.create(loc, nOperand, nOperandDim) .getResult(); - dimsToExpandCandidate.emplace_back(nDim, intrinsic.nSize); + dimsToExpandCandidate.emplace_back(nDim, intrinsic.nSizes[0]); } - nPadding = getPadding(nSizeExpr, intrinsic.nSize); + nPadding = getPadding(nSizeExpr, intrinsic.nSizes[0]); } - if (kSize % intrinsic.kSize != 0 || ShapedType::isDynamic(kSize)) { + if (kSize % intrinsic.kSizes[0] != 0 || ShapedType::isDynamic(kSize)) { OpFoldResult kSizeExpr = rewriter.getIndexAttr(kSize); if (ShapedType::isDynamic(kSize)) { auto kOperandDimPair = getSrcOperandAndDim(kDim); @@ -418,9 +418,9 @@ static void padContractionLikeOp( auto [kOperand, kOperandDim] = kOperandDimPair.value(); kSizeExpr = rewriter.create(loc, kOperand, kOperandDim) .getResult(); - dimsToExpandCandidate.emplace_back(kDim, intrinsic.kSize); + dimsToExpandCandidate.emplace_back(kDim, intrinsic.kSizes[0]); } - kPadding = getPadding(kSizeExpr, intrinsic.kSize); + kPadding = getPadding(kSizeExpr, 
intrinsic.kSizes[0]); } if (!mPadding && !nPadding && !kPadding) {
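
To summarize the distribution heuristic described in the commit message and implemented in `getOptimalMMASchedule` above, here is a small standalone sketch of the per-pair M/N walk. It is an illustration only: it uses `std::gcd` in place of the `llvm::APInt` GCD helpers, and it omits the even square-root split for the innermost pair, the K-dim partitioning, and the later shared-memory fitting:

```cpp
// Illustration of the pairwise M/N seed distribution (simplified). Walks the
// M and N intrinsic tile counts from the innermost dim outward; for each dim
// it first assigns a subgroup count and then a per-subgroup tile count, each
// time taking the GCD with the remaining seed budget.
#include <cstdint>
#include <numeric>
#include <vector>

struct DistributedDims {
  std::vector<int64_t> subgroupCounts;
  std::vector<int64_t> tileSizes;
};

static void distributeOne(std::vector<int64_t> &tileCounts, int64_t dim,
                          int64_t &remainingSubgroups, int64_t &remainingTiles,
                          DistributedDims &out) {
  // Give this dim as many subgroups as evenly divide its tile count.
  int64_t subgroups = std::gcd(tileCounts[dim], remainingSubgroups);
  out.subgroupCounts[dim] = subgroups;
  tileCounts[dim] /= subgroups;
  remainingSubgroups /= subgroups;
  // Then give it as many per-subgroup tiles as evenly divide the remainder.
  int64_t tiles = std::gcd(tileCounts[dim], remainingTiles);
  out.tileSizes[dim] = tiles;
  remainingTiles /= tiles;
}

// Distributes subgroup/tile seeds over multiple M and N dims, handling one
// (M, N) pair per step starting from the innermost dims, N before M.
static void distributeSeeds(std::vector<int64_t> mTileCounts,
                            std::vector<int64_t> nTileCounts,
                            int64_t remainingSubgroups, int64_t remainingTiles,
                            DistributedDims &mOut, DistributedDims &nOut) {
  mOut.subgroupCounts.assign(mTileCounts.size(), 1);
  mOut.tileSizes.assign(mTileCounts.size(), 1);
  nOut.subgroupCounts.assign(nTileCounts.size(), 1);
  nOut.tileSizes.assign(nTileCounts.size(), 1);

  int64_t mDim = static_cast<int64_t>(mTileCounts.size()) - 1;
  int64_t nDim = static_cast<int64_t>(nTileCounts.size()) - 1;
  while (mDim >= 0 || nDim >= 0) {
    if (nDim >= 0)
      distributeOne(nTileCounts, nDim, remainingSubgroups, remainingTiles,
                    nOut);
    if (mDim >= 0)
      distributeOne(mTileCounts, mDim, remainingSubgroups, remainingTiles,
                    mOut);
    --mDim;
    --nDim;
  }
}

// Example with this simplified walk: the {M:[4, 16], N:[2, 32]} problem on a
// 16x16x16 intrinsic gives tile counts M:[4, 1] and N:[2, 2]. With 4 seeded
// subgroups and 8 seeded tiles, the walk produces nSubgroupCounts = [2, 2],
// mSubgroupCounts = [1, 1], mTileSizes = [4, 1], and nTileSizes = [1, 1].
```
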