Skip to content

Commit

Permalink
[LLVMGPU][NFC] Unify hal.interface conversion for ROCM and CUDA (#7568)
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasRaoux authored Nov 9, 2021
1 parent 68aeb86 commit 96d9624
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 50 deletions.
58 changes: 13 additions & 45 deletions iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,36 +336,18 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern {
};

/// A pattern to convert hal.interface.workgroup.id/count/size into
/// corresponding NVVM/ROCDL ops.
template <typename InterfaceOpTy, typename XOp, typename YOp, typename ZOp>
/// corresponding GPU ops.
template <typename InterfaceOpTy, typename NewOpTy>
struct HALInterfaceWorkgroupOpsConverter final
: public OpConversionPattern<InterfaceOpTy> {
using OpConversionPattern<InterfaceOpTy>::OpConversionPattern;

LogicalResult matchAndRewrite(
InterfaceOpTy op, typename InterfaceOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type i32Type = rewriter.getI32Type();
Value newOp;
int32_t index = static_cast<int32_t>(op.dimension().getSExtValue());
switch (index) {
case 0:
newOp = rewriter.create<XOp>(loc, i32Type);
break;
case 1:
newOp = rewriter.create<YOp>(loc, i32Type);
break;
case 2:
newOp = rewriter.create<ZOp>(loc, i32Type);
break;
default:
return failure();
}

newOp =
rewriter.create<LLVM::SExtOp>(loc, rewriter.getIntegerType(64), newOp);
rewriter.replaceOp(op, {newOp});
std::array<const char *, 3> dimAttr{"x", "y", "z"};
rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), dimAttr[index]);
return success();
}
};
Expand All @@ -374,31 +356,9 @@ struct HALInterfaceWorkgroupOpsConverter final

void populateLLVMConversionPatterns(MLIRContext *context,
OwningRewritePatternList &patterns,
LLVMTypeConverter &converter,
bool useROCM) {
LLVMTypeConverter &converter) {
patterns.insert<ConvertFunc, ConvertIREEBindingOp, ConvertIREEConstantOp>(
context, converter);
if (useROCM) {
patterns.insert<HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupIDOp, ROCDL::BlockIdXOp,
ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>,
HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupCountOp, ROCDL::GridDimXOp,
ROCDL::GridDimYOp, ROCDL::GridDimZOp>,
HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupSizeOp, ROCDL::BlockDimXOp,
ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(context);
} else {
patterns.insert<HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupIDOp, NVVM::BlockIdXOp,
NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupCountOp, NVVM::GridDimXOp,
NVVM::GridDimYOp, NVVM::GridDimZOp>,
HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupSizeOp, NVVM::BlockDimXOp,
NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(context);
}
}

void populateScalarizeMathOps(RewritePatternSet &patterns) {
Expand All @@ -418,6 +378,14 @@ void populateConvertSharedMemoryAllocOps(RewritePatternSet &patterns) {
patterns.add<ConvertSharedMemAllocOp>(patterns.getContext());
}

void populateLowerHALInterfaceOp(RewritePatternSet &patterns) {
patterns.insert<HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupIDOp, gpu::BlockIdOp>,
HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupCountOp, gpu::GridDimOp>>(
patterns.getContext());
}

std::unique_ptr<OperationPass<ModuleOp>> createTestLLVMGPULegalizePass() {
return std::make_unique<TestLLVMGPULegalizeOpPass>();
}
Expand Down
5 changes: 4 additions & 1 deletion iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@ namespace iree_compiler {

void populateLLVMConversionPatterns(MLIRContext *context,
OwningRewritePatternList &patterns,
LLVMTypeConverter &converter, bool useROCM);
LLVMTypeConverter &converter);

void populateScalarizeMathOps(RewritePatternSet &patterns);

/// Lower hal.interface ops to the equivalent gpu ops.
void populateLowerHALInterfaceOp(RewritePatternSet &patterns);

/// Add patterns to convert AllocOp of shared memory to a global variable.
void populateConvertSharedMemoryAllocOps(RewritePatternSet &patterns);

Expand Down
4 changes: 2 additions & 2 deletions iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ struct ConvertToNVVMPass : public ConvertToNVVMBase<ConvertToNVVMPass> {
OwningRewritePatternList patterns(&getContext());
populateScalarizeMathOps(patterns);
populateConvertSharedMemoryAllocOps(patterns);
populateLowerHALInterfaceOp(patterns);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
vector::populateVectorBroadcastLoweringPatterns(patterns);
vector::populateVectorContractLoweringPatterns(
Expand All @@ -73,8 +74,7 @@ struct ConvertToNVVMPass : public ConvertToNVVMBase<ConvertToNVVMPass> {
}
{
OwningRewritePatternList llvmPatterns(&getContext());
populateLLVMConversionPatterns(&getContext(), llvmPatterns, converter,
false);
populateLLVMConversionPatterns(&getContext(), llvmPatterns, converter);
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
populateMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
Expand Down
4 changes: 2 additions & 2 deletions iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ struct ConvertToROCDLPass : public ConvertToROCDLBase<ConvertToROCDLPass> {
OwningRewritePatternList patterns(&getContext());
populateScalarizeMathOps(patterns);
populateConvertSharedMemoryAllocOps(patterns);
populateLowerHALInterfaceOp(patterns);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
vector::populateVectorBroadcastLoweringPatterns(patterns);
vector::populateVectorContractLoweringPatterns(
Expand All @@ -73,8 +74,7 @@ struct ConvertToROCDLPass : public ConvertToROCDLBase<ConvertToROCDLPass> {
}
{
OwningRewritePatternList llvmPatterns(&getContext());
populateLLVMConversionPatterns(&getContext(), llvmPatterns, converter,
true);
populateLLVMConversionPatterns(&getContext(), llvmPatterns, converter);
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
populateMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
Expand Down

0 comments on commit 96d9624

Please sign in to comment.