[GPU] Support multiple contraction dims in MmaSchedules (#18720)
This adds support for multiple M, N, and K dims in problems when
deducing a GPUMMASchedule. The new heuristic is similar to the old one,
but works on pairs of M and N dims. For example:
```
tensor<M1xM0xK1xK0> * tensor<N1xN0xK1xK0> -> tensor<M1xN1xM0xN0>
```
This will try to distribute the seeded tile counts to `M0` and `N0`
(first attempting to distribute evenly, and then distributing to N
followed by M), and then distribute the residual counts to `M1` and
`N1`. The K tile counts will be partitioned to `K0` first, and then the
residual tile counts will be partitioned to `K1`.
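
To make the distribution order concrete, here is a minimal sketch of the
greedy idea described above. It is illustrative only and not the code added
in this PR; the helper names `largestFittingDivisor` and
`distributeTileBudget` are hypothetical.
```cpp
// Greedily assign a seeded tile-count budget to dimension sizes, innermost
// first, carrying any unused residual to the next-outer dimension.
#include <algorithm>
#include <cstdint>
#include <vector>

// Largest divisor of `dimSize` that does not exceed `budget`.
static int64_t largestFittingDivisor(int64_t dimSize, int64_t budget) {
  for (int64_t d = std::min(dimSize, budget); d > 1; --d)
    if (dimSize % d == 0)
      return d;
  return 1;
}

// `dimSizes` is ordered innermost to outermost; returns a tile count per dim.
std::vector<int64_t> distributeTileBudget(const std::vector<int64_t> &dimSizes,
                                          int64_t budget) {
  std::vector<int64_t> counts;
  for (int64_t size : dimSizes) {
    int64_t used = largestFittingDivisor(size, budget);
    counts.push_back(used);
    budget /= used; // The residual budget goes to the next-outer dimension.
  }
  return counts;
}
```
For example, with K sizes `{K0, K1} = {32, 4}` (innermost first) and a seed
budget of 8, this assigns 8 to `K0` and the residual 1 to `K1`, mirroring the
"`K0` first, then `K1`" partitioning described above.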

This PR also updates the config selection logic for the TileAndFuse
pipeline to make use of the multiple contraction dimensions in mma
schedules.

---------

Signed-off-by: Max Dawkins <[email protected]>
Max191 authored Oct 25, 2024
1 parent 0c2c627 commit 03c744e
Showing 8 changed files with 492 additions and 232 deletions.
361 changes: 239 additions & 122 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp

Large diffs are not rendered by default.

58 changes: 46 additions & 12 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
@@ -10,15 +10,18 @@ namespace mlir::iree_compiler {

/// Struct containing information about a matmul's shape and type.
struct GPUMatmulShapeType {
int64_t mSize;
int64_t nSize;
int64_t kSize;
SmallVector<int64_t> mSizes;
SmallVector<int64_t> nSizes;
SmallVector<int64_t> kSizes;
Type aType;
Type bType;
Type cType;

GPUMatmulShapeType(int64_t m, int64_t n, int64_t k, Type a, Type b, Type c)
: mSize(m), nSize(n), kSize(k), aType(a), bType(b), cType(c) {}
: mSizes({m}), nSizes({n}), kSizes({k}), aType(a), bType(b), cType(c) {}
GPUMatmulShapeType(SmallVector<int64_t> m, SmallVector<int64_t> n,
SmallVector<int64_t> k, Type a, Type b, Type c)
: mSizes(m), nSizes(n), kSizes(k), aType(a), bType(b), cType(c) {}
};
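
As a usage sketch (not part of the diff; the dimension sizes and element
types below are invented, and `ctx` is assumed to be an in-scope
`MLIRContext *`), both constructors can be exercised as follows, with each
size list ordered so the innermost dimension comes last:
```cpp
Type f16 = Float16Type::get(ctx);
Type f32 = Float32Type::get(ctx);

// Single M, N, and K dim problem.
GPUMatmulShapeType single(/*m=*/128, /*n=*/256, /*k=*/64, f16, f16, f32);

// Multi-dim problem matching the M1xM0 * N1xN0 over K1xK0 example from the
// commit message: M = M1xM0 = 4x32, N = N1xN0 = 2x64, K = K1xK0 = 8x16.
GPUMatmulShapeType multi(/*m=*/{4, 32}, /*n=*/{2, 64}, /*k=*/{8, 16}, f16,
                         f16, f32);
```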

/// Struct containing seed tile sizes for GPU MMA heuristics deduction logic.
@@ -38,14 +41,42 @@ struct GPUMMAHeuristicSeeds {
struct GPUMMASchedule {
// Index of the chosen intrinsic into the list of given MMA intrinsics
uint64_t index;
int64_t mSize; // Native MMA size along M dimension
int64_t nSize; // Native MMA size along N dimension
int64_t kSize; // Native MMA size along K dimension
int64_t mWarpCount; // Number of subgroups along M dimension
int64_t nWarpCount; // Number of subgroups along N dimension
int64_t mTileCount; // Number of tiles per subgroup along M dimension
int64_t nTileCount; // Number of tiles per subgroup along N dimension
int64_t kTileCount; // Number of tiles along K dimension
int64_t mSize; // Native MMA intrinsic size along M dimension for a subgroup.
int64_t nSize; // Native MMA intrinsic size along N dimension for a subgroup.
int64_t kSize; // Native MMA intrinsic size along K dimension for a subgroup.

// Number of subgroups along each M and N dimension.
SmallVector<int64_t> mSubgroupCounts;
SmallVector<int64_t> nSubgroupCounts;

// Tile sizes for each M, N, and K dimension. When there are multiple M, N,
// or K dimensions, the intrinsic sizes are targeted to the innermost
// dimension, and the outer dimensions can be thought of as unrolling factors
// along M, N, or K.
SmallVector<int64_t> mTileSizes; // M tile sizes per subgroup.
SmallVector<int64_t> nTileSizes; // N tile sizes per subgroup.
SmallVector<int64_t> kTileSizes; // K tile sizes.

// Constructor for multi M, N, K dim schedules.
GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize,
int64_t kIntrinsicSize, SmallVector<int64_t> mSubgroupCounts,
SmallVector<int64_t> nSubgroupCounts,
SmallVector<int64_t> mTileSizes,
SmallVector<int64_t> nTileSizes,
SmallVector<int64_t> kTileSizes)
: index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize),
kSize(kIntrinsicSize), mSubgroupCounts(mSubgroupCounts),
nSubgroupCounts(nSubgroupCounts), mTileSizes(mTileSizes),
nTileSizes(nTileSizes), kTileSizes(kTileSizes) {}

// Constructor for single M, N, K dim schedules.
GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize,
int64_t kIntrinsicSize, int64_t mSubgroup, int64_t nSubgroup,
int64_t mTileSize, int64_t nTileSize, int64_t kTileSize)
: index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize),
kSize(kIntrinsicSize), mSubgroupCounts({mSubgroup}),
nSubgroupCounts({nSubgroup}), mTileSizes({mTileSize}),
nTileSizes({nTileSize}), kTileSizes({kTileSize}) {}
};
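
For illustration (the values below are invented, not taken from the diff), a
multi-dim schedule for a 16x16x16 intrinsic over two M and two N dims could
be built with the first constructor as:
```cpp
// Intrinsic 0 from the MMA list; 16x16x16 native intrinsic shape; subgroup
// counts of {1, 2} along {M1, M0} and {2, 1} along {N1, N0}; per-subgroup
// tile sizes of {2, 4} along M and {1, 4} along N; K tile sizes of {1, 8}
// along {K1, K0}.
GPUMMASchedule schedule(/*i=*/0, /*mIntrinsicSize=*/16, /*nIntrinsicSize=*/16,
                        /*kIntrinsicSize=*/16,
                        /*mSubgroupCounts=*/{1, 2}, /*nSubgroupCounts=*/{2, 1},
                        /*mTileSizes=*/{2, 4}, /*nTileSizes=*/{1, 4},
                        /*kTileSizes=*/{1, 8});
```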

/// Returns a schedule for using one of the given MMA |intrinsics| to target the
Expand All @@ -69,4 +100,7 @@ FailureOr<GPUMMASchedule> deduceAttentionSchedule(
bool transposedV = false, bool canUpcastAcc = false,
bool mustBeAligned = true);

llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const GPUMMASchedule &schedule);

} // namespace mlir::iree_compiler
@@ -13,6 +13,7 @@
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -124,20 +125,37 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
return failure();
}

// For now we are not being smart and trying to reshape dimensions to allow
// for better usage of intrinsics, and instead are tiling all dimensions
// except the inner most m, n, and k dimensions to 1.
int64_t mDim = contractionDims.m.back();
int64_t nDim = contractionDims.n.back();
int64_t kDim = contractionDims.k.back();

// Dynamic dims are expected to be taken care of earlier in the pipeline.
if (ShapedType::isDynamic(bounds[mDim]) ||
ShapedType::isDynamic(bounds[nDim]) ||
ShapedType::isDynamic(bounds[kDim])) {
// TODO(Max191): add dynamic shape support for inner most dims.
if (ShapedType::isDynamic(bounds[contractionDims.m.back()]) ||
ShapedType::isDynamic(bounds[contractionDims.n.back()]) ||
ShapedType::isDynamic(bounds[contractionDims.k.back()])) {
return failure();
}

// Gather all static M, N, and K dimensions to deduce the MMASchedule. Dynamic
// dimensions will be tiled to 1 in workgroup tiling, so they are ignored when
// computing an MMA schedule.
SmallVector<int64_t> mDims, nDims, kDims;
for (auto mDim : contractionDims.m) {
if (!ShapedType::isDynamic(bounds[mDim])) {
mDims.push_back(mDim);
}
}
for (auto nDim : contractionDims.n) {
if (!ShapedType::isDynamic(bounds[nDim])) {
nDims.push_back(nDim);
}
}
for (auto kDim : contractionDims.k) {
if (!ShapedType::isDynamic(bounds[kDim])) {
kDims.push_back(kDim);
}
}

auto getDimBounds = [&](SmallVector<int64_t> dims) -> SmallVector<int64_t> {
return llvm::map_to_vector(dims, [&](int64_t dim) { return bounds[dim]; });
};

Value lhs = linalgOp.getDpsInputOperand(0)->get();
Value rhs = linalgOp.getDpsInputOperand(1)->get();
Value init = linalgOp.getDpsInitOperand(0)->get();
@@ -146,8 +164,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
Type rhsElemType = getElementTypeOrSelf(rhs);
Type initElemType = getElementTypeOrSelf(init);

GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
lhsElemType, rhsElemType, initElemType};
GPUMatmulShapeType problem{getDimBounds(mDims), getDimBounds(nDims),
getDimBounds(kDims), lhsElemType,
rhsElemType, initElemType};

SmallVector<GPUMatmulShapeType> intrinsics;
for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
@@ -166,7 +185,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
// Note that the following heuristic seeds are just placeholder values.
// We need to clean it up and make it adjusting to different targets.
// See https://github.com/iree-org/iree/issues/16341 for details.
if (problem.mSize * problem.nSize <= 512 * 512) {
int64_t mSize = ShapedType::getNumElements(problem.mSizes);
int64_t nSize = ShapedType::getNumElements(problem.nSizes);
if (mSize * nSize <= 512 * 512) {
// For matmuls with small M*N size, we want to distribute M*N onto more
// workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
// and a larger bestKTileCountPerSubgroup.
Expand All @@ -190,10 +211,10 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
// TODO: Drop this. This is only a consideration for other pipelines.
SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
bool transposedLhs =
kDim !=
kDims.back() !=
llvm::cast<AffineDimExpr>(maps[0].getResults().back()).getPosition();
bool transposedRhs =
nDim !=
nDims.back() !=
llvm::cast<AffineDimExpr>(maps[1].getResults().back()).getPosition();

// First try to find a schedule with an exactly matching intrinsic.
@@ -213,16 +234,13 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
}

LDBG("Target Subgroup size: " << targetSubgroupSize);
LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", "
<< schedule->kSize << "]");
LDBG("Schedule: tile counts [" << schedule->mTileCount << ", "
<< schedule->nTileCount << ", "
<< schedule->kTileCount << "]");
LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", "
<< schedule->nWarpCount << "]");
LDBG("Schedule: " << schedule);

std::array<int64_t, 3> workgroupSize{
schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
int64_t flatWorkgroupSize =
targetSubgroupSize *
ShapedType::getNumElements(schedule->nSubgroupCounts) *
ShapedType::getNumElements(schedule->mSubgroupCounts);
std::array<int64_t, 3> workgroupSize{flatWorkgroupSize, 1, 1};
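// Illustrative example (not from the diff): with a target subgroup size of 64,
// mSubgroupCounts = {1, 2}, and nSubgroupCounts = {2, 2}, the flat workgroup
// size is 64 * (2 * 2) * (1 * 2) = 512, giving a workgroup size of {512, 1, 1}.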

SmallVector<int64_t> workgroupTileSizes(linalgOp.getNumLoops(), 0);
SmallVector<int64_t> reductionTileSizes(linalgOp.getNumLoops(), 0);
@@ -244,16 +262,30 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
reductionTileSizes[k] = 1;
}

// Compute the M/N dimension tile size by multiplying subgroup information.
workgroupTileSizes[mDim] = schedule->mWarpCount * schedule->mTileCount;
workgroupTileSizes[nDim] = schedule->nWarpCount * schedule->nTileCount;

// Specify the subgroup tile sizes from the mma schedule. This is applied
subgroupTileSizes[mDim] = schedule->mTileCount;
subgroupTileSizes[nDim] = schedule->nTileCount;
// Adjust the inner bound size for packing to intrinsic shapes, since tiling
// happens after packing.
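// For example (illustrative numbers): if bounds[mDims.back()] is 128 and the
// schedule's intrinsic mSize is 16, the adjusted bound becomes 128 / 16 = 8
// intrinsic tiles along that dimension.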
assert(bounds[mDims.back()] % schedule->mSize == 0 &&
bounds[nDims.back()] % schedule->nSize == 0 &&
"expected inner bound to be evenly divisible by schedule sizes.");
bounds[mDims.back()] /= schedule->mSize;
bounds[nDims.back()] /= schedule->nSize;

// Compute the M/N dimension tile sizes by multiplying subgroup information.
for (auto [i, mDim] : llvm::enumerate(mDims)) {
workgroupTileSizes[mDim] =
schedule->mSubgroupCounts[i] * schedule->mTileSizes[i];
subgroupTileSizes[mDim] = schedule->mTileSizes[i];
}
for (auto [i, nDim] : llvm::enumerate(nDims)) {
workgroupTileSizes[nDim] =
schedule->nSubgroupCounts[i] * schedule->nTileSizes[i];
subgroupTileSizes[nDim] = schedule->nTileSizes[i];
}

// Similarly the reduction tile size is just the post-packing tile count.
reductionTileSizes[kDim] = schedule->kTileCount;
for (auto [i, kDim] : llvm::enumerate(kDims)) {
reductionTileSizes[kDim] = schedule->kTileSizes[i];
}

IREE::GPU::MmaInterfaceAttr mmaKind =
target.getWgp().getMma()[schedule->index];