[Transform] Introduce microkernel dialect optimization passes (#296)
Haixin Huang authored Sep 5, 2024
1 parent 23269d7 commit bc0014b
Showing 16 changed files with 2,279 additions and 62 deletions.
31 changes: 31 additions & 0 deletions include/gc/Transforms/Microkernel/MicrokernelPasses.td
@@ -76,4 +76,35 @@ def ConvertMicrokernelToDnnlFunc: Pass<"convert-microkernel-to-dnnl-func", "::ml
"microkernel::MicrokernelDialect"];
}

def EarlyDispatchMicrokernel: Pass<"early-dispatch-microkernel", "::mlir::ModuleOp"> {
let summary = "Early dispatch microkernel during compile time";
let description = [{
Early dispatch microkernel during compile time.
}];
let dependentDialects = ["func::FuncDialect",
"memref::MemRefDialect",
"LLVM::LLVMDialect",
"microkernel::MicrokernelDialect"];
}
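
To make the intent concrete, here is a minimal C++ analogy of what early dispatch buys. The dispatch_brgemm/execute_brgemm names are hypothetical stand-ins for the runtime's entry points; the pass itself rewrites MLIR, not C++:

#include <cstdint>

// Hypothetical stand-ins for the microkernel runtime's entry points.
int64_t dispatch_brgemm() { return 42; /* kernel handle */ }
void execute_brgemm(int64_t handle) { (void)handle; /* run the kernel */ }

// Before the pass: the kernel is re-dispatched on every invocation.
void run_before() {
  int64_t handle = dispatch_brgemm(); // paid on each call
  execute_brgemm(handle);
}

// After the pass: dispatch is resolved once, up front, and the cached
// handle is reused by every subsequent execution.
static const int64_t g_handle = dispatch_brgemm(); // paid once
void run_after() { execute_brgemm(g_handle); }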

def MergeBranchMicrokernelContext: Pass<"merge-branch-microkernel-context", "::mlir::ModuleOp"> {
let summary = "Find and merge identical microkernel context operations in branches into one";
  let description = [{
    Find microkernel context operations that are duplicated across the
    branches of a conditional and merge the identical copies into a single
    operation shared by all branches.
  }];
let dependentDialects = ["func::FuncDialect",
"memref::MemRefDialect"];
}
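
As an illustration of the rewrite, in the same hedged C++ analogy (make_context is a hypothetical context-creation helper; the pass operates on MLIR branch regions):

// Hypothetical context-creation and consumer functions.
int make_context() { return 1; }
void run_a(int ctx) { (void)ctx; }
void run_b(int ctx) { (void)ctx; }

// Before: each branch builds its own, identical context.
void before_merge(bool cond) {
  if (cond) { int ctx = make_context(); run_a(ctx); }
  else      { int ctx = make_context(); run_b(ctx); }
}

// After: the identical context operations are merged into a single one
// that dominates both branches.
void after_merge(bool cond) {
  int ctx = make_context();
  if (cond) run_a(ctx); else run_b(ctx);
}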

def MicrokernelInvariantCodeMotion: Pass<"microkernel-invariant-code-motion", "::mlir::ModuleOp"> {
let summary = "Hoist invariant microkernel code to avoid redundant execution";
  let description = [{
    Hoist microkernel operations that are invariant across loop iterations
    out of the enclosing loop, so they are not re-executed redundantly.
  }];
let dependentDialects = ["func::FuncDialect",
"memref::MemRefDialect",
"LLVM::LLVMDialect",
"microkernel::MicrokernelDialect"];
}
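
And the loop counterpart, in the same hypothetical C++ terms:

#include <cstdint>

// Hypothetical, loop-invariant dispatch and per-iteration execute.
int64_t dispatch_brgemm2() { return 42; }
void execute_brgemm2(int64_t handle, int i) { (void)handle; (void)i; }

// Before: the loop body recomputes a value that never changes.
void before_hoist(int n) {
  for (int i = 0; i < n; ++i) {
    int64_t handle = dispatch_brgemm2(); // invariant across iterations
    execute_brgemm2(handle, i);
  }
}

// After: the invariant computation is hoisted out of the loop.
void after_hoist(int n) {
  int64_t handle = dispatch_brgemm2();
  for (int i = 0; i < n; ++i)
    execute_brgemm2(handle, i);
}

In a pipeline these three passes would typically run back to back; assuming the TableGen-generated constructors, something like pm.addPass(microkernel::createEarlyDispatchMicrokernel()) followed by the other two.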

#endif // GC_DIALECT_MICROKERNELPASSES
95 changes: 53 additions & 42 deletions lib/gc/Dialect/Microkernel/MicrokernelOps.cpp
@@ -213,6 +213,42 @@ static LogicalResult verifyBrgemmFlags(ArrayAttr flags, Operation *op,
return success();
}

static bool isTypeSupported(Type outType, Type operandAType,
Type operandBType) {
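  // Supported combinations: f32 accumulation with f32*f32 or bf16*bf16
  // operands, and i32 accumulation with s8/u8 operands on both sides.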
if (!outType.isF32() && !outType.isSignedInteger(32))
return false;

if (outType.isF32()) {
if (!(operandAType.isF32() && operandBType.isF32()) &&
!(operandAType.isBF16() && operandBType.isBF16()))
return false;
}
  if (outType.isSignedInteger(32)) {
    if (!(operandAType.isSignedInteger(8) ||
          operandAType.isUnsignedInteger(8)) ||
        !(operandBType.isSignedInteger(8) ||
          operandBType.isUnsignedInteger(8)))
      return false;
  }
return true;
}

// TODO(haixin): could use compiler-wide VNNI utils?
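// In VNNI layout the innermost dimension packs the reduction axis: 2
// elements per block for bf16 and 4 for s8/u8, so e.g. a bf16 B of logical
// shape KxN is stored as K/2 x N x 2.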
static bool isInVnniLayout(ShapedType type) {
if (!type.getElementType().isBF16() &&
!type.getElementType().isSignedInteger(8) &&
!type.getElementType().isUnsignedInteger(8))
return false;

auto blockingFactor = 0;
if (type.getElementType().isBF16())
blockingFactor = 2;
else if (type.getElementType().isSignedInteger(8) ||
type.getElementType().isUnsignedInteger(8))
blockingFactor = 4;

return type.getShape().back() == blockingFactor;
}

/////////////////////////////////////////////////////
// Start of BrgemmOp

@@ -308,9 +344,8 @@ static inline ArrayRef<int64_t> getShapedValueShape(Value val) {
assert((llvm::isa<TensorType>(val.getType()) ||
llvm::isa<MemRefType>(val.getType())) &&
"Expecting shaped value");
if (auto tensorTy = dyn_cast_or_null<TensorType>(val.getType())) {
if (auto tensorTy = dyn_cast_or_null<TensorType>(val.getType()))
return tensorTy.getShape();
}
auto memrefTy = dyn_cast_or_null<MemRefType>(val.getType());
return memrefTy.getShape();
}
@@ -331,15 +366,27 @@ LogicalResult BrgemmOp::verify() {
return op.emitOpError()
<< "expect inputs and its related info to be size 2\n";

auto elemTypeA = getElementTypeOrSelf(ins[0]);
auto elemTypeB = getElementTypeOrSelf(ins[1]);
auto elemTypeC = getElementTypeOrSelf(out);
if (!isTypeSupported(elemTypeC, elemTypeA, elemTypeB))
return op.emitOpError() << "unsupported input matrix types\n";

ArrayRef<int64_t> dimA = getShapedValueShape(ins[0]);
ArrayRef<int64_t> dimB = getShapedValueShape(ins[1]);
ArrayRef<int64_t> dimC = getShapedValueShape(out);
if (dimA.size() != 3)
return op.emitOpError() << "expect input A to be 3D\n";
if (dimB.size() != 3 && dimB.size() != 4)
return op.emitOpError() << "expect input B to be 3D or 4D\n";
if (dimB.size() == 4 && (dimB[3] != 2 && dimB[3] != 4))
return op.emitOpError() << "expect input B vnni step to be 2 or 4\n";
if (!elemTypeB.isF32()) {
if (dimB.size() != 4 ||
!isInVnniLayout(dyn_cast<ShapedType>(ins[1].getType())))
return op.emitOpError()
<< "expect a 4d VNNI input B for non-F32 operand: " << ins[1];
} else {
if (dimB.size() != 3)
return op.emitOpError()
<< "expect a 3d input B for F32 operand: " << ins[1];
}
if (dimC.size() != 2)
return op.emitOpError() << "expect input C to be 2D\n";
for (auto dim : batchDims)
@@ -558,42 +605,6 @@ LogicalResult BrgemmDispatchOp::verify() {
/////////////////////////////////////////////////////
// Start of BrgemmExecuteOp

// TODO(haixin): could use compiler-wide VNNI utils?
static bool isInVnniLayout(MemRefType memref) {
if (!memref.getElementType().isBF16() &&
!memref.getElementType().isSignedInteger(8) &&
!memref.getElementType().isUnsignedInteger(8))
return false;

auto blockingFactor = 0;
if (memref.getElementType().isBF16())
blockingFactor = 2;
else if (memref.getElementType().isSignedInteger(8) ||
memref.getElementType().isUnsignedInteger(8))
blockingFactor = 4;

return memref.getShape().back() == blockingFactor;
}

static bool isTypeSupported(Type outType, Type operandAType,
Type operandBType) {
if (!outType.isF32() && !outType.isSignedInteger(32))
return false;

if (outType.isF32()) {
if (!(operandAType.isF32() && operandBType.isF32()) &&
!(operandAType.isBF16() && operandBType.isBF16()))
return false;
}
  if (outType.isSignedInteger(32)) {
    if (!(operandAType.isSignedInteger(8) ||
          operandAType.isUnsignedInteger(8)) ||
        !(operandBType.isSignedInteger(8) ||
          operandBType.isUnsignedInteger(8)))
      return false;
  }
return true;
}

LogicalResult BrgemmExecuteOp::verify() {
BrgemmExecuteOp &brgemmOp = *this;

1 change: 1 addition & 0 deletions lib/gc/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(Utils)
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
MLIRIR
MLIRSupport
MLIRMicrokernelTransforms
MLIRBufferizationToMemRef
MLIRBufferizationPipelines)

5 changes: 4 additions & 1 deletion lib/gc/Transforms/Microkernel/CMakeLists.txt
@@ -1,11 +1,14 @@
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR)
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR MLIRMicrokernel)

include(onednn)

gc_add_mlir_dialect_library(MLIRMicrokernelTransforms
ConvertLinalgToMicrokernel.cpp
ExpandMicrokernel.cpp
ConvertMicrokernelToDnnlFunc.cpp
EarlyDispatchMicrokernel.cpp
MicrokernelInvariantCodeMotion.cpp
MergeBranchMicrokernelContext.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/
20 changes: 17 additions & 3 deletions lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp
@@ -157,6 +157,23 @@ static FailureOr<BrgemmDims> inferBrgemmDims(linalg::LinalgOp linalgOp) {
else
return failure();

OpOperand *operandA = linalgOp.getDpsInputOperands()[0];
OpOperand *operandB = linalgOp.getDpsInputOperands()[1];
Type operandBElemType = getElementTypeOrSelf(operandB->get());
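  // An f32 input B carries a single reduction (K) dimension, while a
  // bf16/int8 B is VNNI-packed and splits K into two affine dims
  // (K / vnniStep and vnniStep).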
if (operandBElemType.isF32()) {
if (kAffinePos.size() == 2) {
LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions for input "
"B, should be non-VNNI\n");
return failure();
}
} else {
if (kAffinePos.size() == 1) {
      LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] Wrong dimensions for input "
                                 "B, should be VNNI\n");
return failure();
}
}

LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] Candidate dims: "
<< "\n");
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] m pos in affine: " << mAffinePos
@@ -169,9 +186,6 @@ static FailureOr<BrgemmDims> inferBrgemmDims(linalg::LinalgOp linalgOp) {
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] batch pos in affine: "
<< batchAffinePos << "\n");

OpOperand *operandA = linalgOp.getDpsInputOperands()[0];
OpOperand *operandB = linalgOp.getDpsInputOperands()[1];

BrgemmDims brgemmDims;

#define CHECK_GET_POS_IN_DOMAIN(dim, dimPos, operand) \
(11 more changed files not shown)
