Skip to content

Commit

Permalink
[Codegen][CPU] Enable scalable transfer lowerings (#18170)
Browse files Browse the repository at this point in the history
This enables a general scalable lowering for `transfer_write(transpose)`
when ArmSME is _not_ available. The ArmSME dialect already had its own
(more specific) lowerings for cases like this, which is why these
lowerings are disabled when SME is available.

Depends on: llvm/llvm-project#101353

---------

Signed-off-by: Benjamin Maxwell <[email protected]>
  • Loading branch information
MacDue authored Aug 16, 2024
1 parent 551cd54 commit 8a1d78b
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void LLVMCPULowerExecutableTargetPass::runOnOperation() {
pipelineOpts.enableVectorMasking =
isX86(target) || isRISCV(target) ||
(isAArch64(target) && hasAnySVEFeature(target));
pipelineOpts.enableAArch64SSVE =
pipelineOpts.enableAArch64SME =
isAArch64(target) && hasAnySVEFeature(target) && hasSMEFeature(target);
pipelineOpts.enableAArch64I8mm = isAArch64(target) && hasI8mmFeature(target);
pipelineOpts.enablePeeling = isLoopPeelingEnabled(funcOp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ void LLVMCPUVectorTransferLoweringPass::runOnOperation() {
/*maxTransferRank=*/1);
auto vectorTransferToSCFOptions =
VectorTransferToSCFOptions().enableFullUnroll();
if (enableScalableLowerings) {
vectorTransferToSCFOptions.enableLowerScalable();
}

populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
Expand Down
18 changes: 16 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,15 @@ void buildLLVMCPUVectorLoweringPipeline(
// lower them and can't be optimized away anymore.
funcPassManager.addPass(createCanonicalizerPass());

funcPassManager.addPass(createLLVMCPUVectorTransferLoweringPass());
LLVMCPUVectorTransferLoweringPassOptions transferLoweringOptions{};
if (!options.enableArmSME) {
// The ArmSME dialect has its own (more specific) lowerings for scalable
// vectors that occur later in the pipeline, so only enable the general
// lowerings if SME is not available.
transferLoweringOptions.enableScalableLowerings = true;
}
funcPassManager.addPass(
createLLVMCPUVectorTransferLoweringPass(transferLoweringOptions));
funcPassManager.addPass(createLLVMCPUVectorTransposeLoweringPass(
LLVMCPUVectorTransposeLoweringPassOptions{
options.lowerVectorTransposeToAVX2}));
Expand Down Expand Up @@ -354,6 +362,7 @@ void addCPUBufferOpsTileAndVectorizePipeline(
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
options.enableArmSME = pipelineOpt.enableAArch64SME;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}
}
Expand Down Expand Up @@ -396,7 +405,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createLLVMCPUPeelPass());
}

if (pipelineOpt.enableAArch64SSVE) {
if (pipelineOpt.enableAArch64SME) {
funcPassManager.addPass(createLLVMCPU2DScalableTo1DScalablePass());
}

Expand Down Expand Up @@ -432,6 +441,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
options.enableArmSME = pipelineOpt.enableAArch64SME;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}
}
Expand Down Expand Up @@ -494,6 +504,7 @@ void addConvTileAndDecomposeExpertPassPipeline(
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "shuffle";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
options.enableArmSME = pipelineOpt.enableAArch64SME;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}
}
Expand Down Expand Up @@ -542,6 +553,7 @@ void addMmt4dTilingExpertPassPipeline(OpPassManager &funcPassManager,
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
options.enableArmSME = pipelineOpt.enableAArch64SME;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}

Expand Down Expand Up @@ -583,6 +595,7 @@ void addCPUDataTilingPipeline(OpPassManager &funcPassManager,
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
options.enableArmSME = pipelineOpt.enableAArch64SME;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}
}
Expand Down Expand Up @@ -623,6 +636,7 @@ void addCPULinalgExtTileAndVectorizePipeline(
options.lowerVectorTransposeToAVX2 = pipelineOpt.lowerToAVX2;
options.splitVectorTransfersTo = "linalg-copy";
options.enableArmI8mm = pipelineOpt.enableAArch64I8mm;
options.enableArmSME = pipelineOpt.enableAArch64SME;
buildLLVMCPUVectorLoweringPipeline(funcPassManager, options);
}
}
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct LLVMCPUVectorLoweringPassOptions {
std::string splitVectorTransfersTo = "";
bool lowerVectorTransposeToAVX2 = false;
bool enableArmI8mm = false;
bool enableArmSME = false;
};

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
Expand Down Expand Up @@ -72,7 +73,7 @@ struct LLVMCPUPipelineOptions {
bool useConfiguredVectorSizes = true;
bool enablePeeling = false;
bool enableVectorMasking = false;
bool enableAArch64SSVE = false;
bool enableAArch64SME = false;
bool enableAArch64I8mm = false;
bool lowerToAVX2 = false;
};
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ def LLVMCPUVirtualVectorLoweringPass :
// Function-interface pass that lowers vector transfer ops to simpler ops
// (e.g. `vector.load`/`vector.store`/`vector.broadcast` plus scf loops).
// The `enable-scalable-lowerings` option additionally enables transfer
// lowerings specific to scalable vectors; elsewhere in this change it is
// turned on only when ArmSME is unavailable, since the ArmSME dialect
// provides its own (more specific) lowerings for those cases.
def LLVMCPUVectorTransferLoweringPass :
InterfacePass<"iree-llvmcpu-vector-transfer-lowering", "mlir::FunctionOpInterface"> {
let summary = "Pass to lower transfer ops to simpler ops like `vector.load`, `vector.store`, `vector.broadcast`, and a set of scf ops.";
let options = [
Option<"enableScalableLowerings", "enable-scalable-lowerings", "bool",
/*default=*/"false",
// Defaults to off so existing pipelines (and the ArmSME path) are unaffected.
"Enables scalable vector specific transfer lowerings">,
];
}

def LLVMCPUVectorTransposeLoweringPass :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,19 @@ func.func @gather_strided_memref() {
// CHECK-LABEL: func.func @gather_strided_memref
// CHECK-NOT: memref.subview {{.*}} : memref<2592000xf32, strided<[3]>
// CHECK-NOT: vector.gather %subview[%c0] [%7], %cst_0, %cst : memref<2592000xf32, strided<[3]>

// -----

/// Regression test input: a transpose of a scalable vector (`vector<4x[4]xf32>`
/// -> `vector<[4]x4xf32>`) written out via an in-bounds `vector.transfer_write`.
/// The CHECK lines below verify the scalable transfer lowering eliminates the
/// `vector.transpose` and emits fixed-width `vector.store`s instead.
func.func @scalable_transpose_store(%vec: vector<4x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
%transpose = vector.transpose %vec, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
return
}

/// Note: The lowering for this is implemented/tested upstream (this just checks
/// it is enabled in IREE).

// CHECK-LABEL: func.func @scalable_transpose_store
// CHECK-NOT: vector.transpose
// CHECK: vector.store {{.*}} : memref<?x?xf32>, vector<4xf32>
// CHECK-NOT: vector.transpose

0 comments on commit 8a1d78b

Please sign in to comment.