From 96d96242ad6face5c603840c59a7d13e279ac2f9 Mon Sep 17 00:00:00 2001 From: Thomas Date: Tue, 9 Nov 2021 08:19:33 -0800 Subject: [PATCH] [LLVMGPU][NFC] Unify hal.interface conversion for ROCM and CUDA (#7568) --- .../Codegen/LLVMGPU/ConvertToLLVM.cpp | 58 +++++-------------- iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h | 5 +- .../Codegen/LLVMGPU/ConvertToNVVM.cpp | 4 +- .../Codegen/LLVMGPU/ConvertToROCDL.cpp | 4 +- 4 files changed, 21 insertions(+), 50 deletions(-) diff --git a/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp index dd54ba92baf2..116680f11625 100644 --- a/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp +++ b/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp @@ -336,8 +336,8 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern { }; /// A pattern to convert hal.interface.workgroup.id/count/size into -/// corresponding NVVM/ROCDL ops. -template +/// corresponding GPU ops. +template struct HALInterfaceWorkgroupOpsConverter final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -345,27 +345,9 @@ struct HALInterfaceWorkgroupOpsConverter final LogicalResult matchAndRewrite( InterfaceOpTy op, typename InterfaceOpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Type i32Type = rewriter.getI32Type(); - Value newOp; int32_t index = static_cast(op.dimension().getSExtValue()); - switch (index) { - case 0: - newOp = rewriter.create(loc, i32Type); - break; - case 1: - newOp = rewriter.create(loc, i32Type); - break; - case 2: - newOp = rewriter.create(loc, i32Type); - break; - default: - return failure(); - } - - newOp = - rewriter.create(loc, rewriter.getIntegerType(64), newOp); - rewriter.replaceOp(op, {newOp}); + std::array dimAttr{"x", "y", "z"}; + rewriter.replaceOpWithNewOp(op, op.getType(), dimAttr[index]); return success(); } }; @@ -374,31 +356,9 @@ struct HALInterfaceWorkgroupOpsConverter final void populateLLVMConversionPatterns(MLIRContext *context, OwningRewritePatternList &patterns, - LLVMTypeConverter &converter, - bool useROCM) { + LLVMTypeConverter &converter) { patterns.insert( context, converter); - if (useROCM) { - patterns.insert, - HALInterfaceWorkgroupOpsConverter< - IREE::HAL::InterfaceWorkgroupCountOp, ROCDL::GridDimXOp, - ROCDL::GridDimYOp, ROCDL::GridDimZOp>, - HALInterfaceWorkgroupOpsConverter< - IREE::HAL::InterfaceWorkgroupSizeOp, ROCDL::BlockDimXOp, - ROCDL::BlockDimYOp, ROCDL::BlockDimZOp>>(context); - } else { - patterns.insert, - HALInterfaceWorkgroupOpsConverter< - IREE::HAL::InterfaceWorkgroupCountOp, NVVM::GridDimXOp, - NVVM::GridDimYOp, NVVM::GridDimZOp>, - HALInterfaceWorkgroupOpsConverter< - IREE::HAL::InterfaceWorkgroupSizeOp, NVVM::BlockDimXOp, - NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(context); - } } void populateScalarizeMathOps(RewritePatternSet &patterns) { @@ -418,6 +378,14 @@ void populateConvertSharedMemoryAllocOps(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } +void populateLowerHALInterfaceOp(RewritePatternSet &patterns) { + patterns.insert, + HALInterfaceWorkgroupOpsConverter< + IREE::HAL::InterfaceWorkgroupCountOp, gpu::GridDimOp>>( + patterns.getContext()); +} + std::unique_ptr> createTestLLVMGPULegalizePass() { return std::make_unique(); } diff --git a/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h b/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h index 62ed63288de2..ba99be75bd6f 100644 --- a/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h +++ b/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h @@ -14,10 +14,13 @@ namespace iree_compiler { void populateLLVMConversionPatterns(MLIRContext *context, OwningRewritePatternList &patterns, - LLVMTypeConverter &converter, bool useROCM); + LLVMTypeConverter &converter); void populateScalarizeMathOps(RewritePatternSet &patterns); +/// Lower hal.interface ops to the equivalent gpu ops. +void populateLowerHALInterfaceOp(RewritePatternSet &patterns); + /// Add patterns to convert AllocOp of shared memory to a global variable. void populateConvertSharedMemoryAllocOps(RewritePatternSet &patterns); diff --git a/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp b/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp index 5cf1c24434b7..0d4a4398346e 100644 --- a/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp +++ b/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp @@ -54,6 +54,7 @@ struct ConvertToNVVMPass : public ConvertToNVVMBase { OwningRewritePatternList patterns(&getContext()); populateScalarizeMathOps(patterns); populateConvertSharedMemoryAllocOps(patterns); + populateLowerHALInterfaceOp(patterns); vector::populateVectorToVectorCanonicalizationPatterns(patterns); vector::populateVectorBroadcastLoweringPatterns(patterns); vector::populateVectorContractLoweringPatterns( @@ -73,8 +74,7 @@ struct ConvertToNVVMPass : public ConvertToNVVMBase { } { OwningRewritePatternList llvmPatterns(&getContext()); - populateLLVMConversionPatterns(&getContext(), llvmPatterns, converter, - false); + populateLLVMConversionPatterns(&getContext(), llvmPatterns, converter); populateMathToLLVMConversionPatterns(converter, llvmPatterns); populateMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateStdToLLVMConversionPatterns(converter, llvmPatterns); diff --git a/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp b/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp index 465db82a02ef..009e23f5e4e1 100644 --- a/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp +++ b/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp @@ -54,6 +54,7 @@ struct ConvertToROCDLPass : public ConvertToROCDLBase { OwningRewritePatternList patterns(&getContext()); populateScalarizeMathOps(patterns); populateConvertSharedMemoryAllocOps(patterns); + populateLowerHALInterfaceOp(patterns); vector::populateVectorToVectorCanonicalizationPatterns(patterns); vector::populateVectorBroadcastLoweringPatterns(patterns); vector::populateVectorContractLoweringPatterns( @@ -73,8 +74,7 @@ struct ConvertToROCDLPass : public ConvertToROCDLBase { } { OwningRewritePatternList llvmPatterns(&getContext()); - populateLLVMConversionPatterns(&getContext(), llvmPatterns, converter, - true); + populateLLVMConversionPatterns(&getContext(), llvmPatterns, converter); populateMathToLLVMConversionPatterns(converter, llvmPatterns); populateMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateStdToLLVMConversionPatterns(converter, llvmPatterns);