From 0a8f58f04e27c26de658f3faa69d03f5ad7d2910 Mon Sep 17 00:00:00 2001 From: Dimple Prajapati Date: Mon, 30 Sep 2024 12:39:13 -0700 Subject: [PATCH] [mlir][spirv] Add gpu printf op lowering to spirv.CL.printf op (#78510) This change contains the following: - Adds lowering of the gpu printf op to the spirv.CL.printf op in the GPUToSPIRV pass. - Fixes Constant decoration parsing for spirv GlobalVariable. - Makes a minor modification to the spirv.CL.printf op assembly format. --------- Co-authored-by: Jakub Kuderski --- .../mlir/Dialect/SPIRV/IR/SPIRVCLOps.td | 4 +- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 130 +++++++++++++++++- .../SPIRV/Deserialization/Deserializer.cpp | 1 + .../Target/SPIRV/Serialization/Serializer.cpp | 1 + mlir/test/Conversion/GPUToSPIRV/printf.mlir | 71 ++++++++++ mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir | 6 +- 6 files changed, 207 insertions(+), 6 deletions(-) create mode 100644 mlir/test/Conversion/GPUToSPIRV/printf.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td index c7c2fe8bc742c12..5d086325fa5b1c7 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td @@ -875,7 +875,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> { #### Example: ```mlir - %0 = spirv.CL.printf %0 %1 %2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32 + %0 = spirv.CL.printf %fmt %1, %2 : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32 ``` }]; @@ -889,7 +889,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> { ); let assemblyFormat = [{ - $format `,` $arguments attr-dict `:` `(` type($format) `,` `(` type($arguments) `)` `)` `->` type($result) + $format ( $arguments^ )? attr-dict `:` type($format) ( `,` type($arguments)^ )? 
`->` type($result) }]; let hasVerifier = 0; diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index a8ff9247e796ab9..53b4c720ae56d20 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -121,6 +121,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> { ConversionPatternRewriter &rewriter) const override; }; +class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -597,6 +606,124 @@ class GPUSubgroupReduceConversion final } }; +// Formulate a unique variable/constant name after +// searching in the module for existing variable/constant names. +// This is to avoid name collision with existing variables. +// Example: printfMsg0, printfMsg1, printfMsg2, ... +static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) { + std::string name; + unsigned number = 0; + + do { + name.clear(); + name = (prefix + llvm::Twine(number++)).str(); + } while (moduleOp.lookupSymbol(name)); + + return name; +} + +/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op. + +LogicalResult GPUPrintfConversion::matchAndRewrite( + gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + Location loc = gpuPrintfOp.getLoc(); + + auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>(); + if (!moduleOp) + return failure(); + + // A SPIR-V global variable is used to hold the printf + // format string value. If there are multiple printf messages, + // each global var needs to be created with a unique name. 
+ std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg")); + spirv::GlobalVariableOp globalVar; + + IntegerType i8Type = rewriter.getI8Type(); + IntegerType i32Type = rewriter.getI32Type(); + + // Each character of the printf format string is + // stored as a spec constant. We need to create a + // unique name for each spec constant, like + // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module + // for existing spec constant names. + auto createSpecConstant = [&](unsigned value) { + auto attr = rewriter.getI8IntegerAttr(value); + std::string specCstName = + makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc"); + + return rewriter.create<spirv::SpecConstantOp>( + loc, rewriter.getStringAttr(specCstName), attr); + }; + { + Operation *parent = + SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp()); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + + Block &entryBlock = *parent->getRegion(0).begin(); + rewriter.setInsertionPointToStart( + &entryBlock); // insertion point at module level + + // Create the constituents by scanning the format string. + // Each character of the format string is stored as a spec constant + // and then these spec constants are used to create a + // SpecConstantCompositeOp. + llvm::SmallString<20> formatString(adaptor.getFormat()); + formatString.push_back('\0'); // Null terminate for C. + SmallVector<Attribute> constituents; + for (char c : formatString) { + spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c); + constituents.push_back(SymbolRefAttr::get(cSpecConstantOp)); + } + + // Create SpecConstantCompositeOp to initialize the global variable + size_t contentSize = constituents.size(); + auto globalType = spirv::ArrayType::get(i8Type, contentSize); + spirv::SpecConstantCompositeOp specCstComposite; + // There will be one SpecConstantCompositeOp per printf message/global var, + // so there is no need to do a lookup for existing ones. 
+ std::string specCstCompositeName = + (llvm::Twine(globalVarName) + "_scc").str(); + + specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>( + loc, TypeAttr::get(globalType), + rewriter.getStringAttr(specCstCompositeName), + rewriter.getArrayAttr(constituents)); + + auto ptrType = spirv::PointerType::get( + globalType, spirv::StorageClass::UniformConstant); + + // Define a GlobalVarOp initialized using specialization constants + // that is used to specify the printf format string + // to be passed to the SPIR-V CLPrintfOp. + globalVar = rewriter.create<spirv::GlobalVariableOp>( + loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite)); + + globalVar->setAttr("Constant", rewriter.getUnitAttr()); + } + // Get the SSA value of the global variable and create a pointer to i8 that + // points to the format string. + Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar); + Value fmtStr = rewriter.create<spirv::BitcastOp>( + loc, + spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant), + globalPtr); + + // Get printf arguments. + auto printfArgs = llvm::to_vector_of<Value>(adaptor.getArgs()); + + rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs); + + // Erase the gpu.printf op instead of replacing it: gpu.printf has no + // result, whereas spirv::CLPrintfOp returns an i32, so a direct + // replacement is not possible. + rewriter.eraseOp(gpuPrintfOp); + + return success(); +} + //===----------------------------------------------------------------------===// // GPU To SPIRV Patterns. 
//===----------------------------------------------------------------------===// @@ -620,5 +747,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter, SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp, spirv::BuiltIn::SubgroupSize>, WorkGroupSizeConversion, GPUAllReduceConversion, - GPUSubgroupReduceConversion>(typeConverter, patterns.getContext()); + GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter, + patterns.getContext()); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 38293f7106a05a5..6c7fe41069824fc 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -319,6 +319,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { case spirv::Decoration::Restrict: case spirv::Decoration::RestrictPointer: case spirv::Decoration::NoContraction: + case spirv::Decoration::Constant: if (words.size() != 2) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single target <id>"; diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 7719eb68b2c2e01..f355982e9ed8841 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -286,6 +286,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, case spirv::Decoration::Restrict: case spirv::Decoration::RestrictPointer: case spirv::Decoration::NoContraction: + case spirv::Decoration::Constant: // For unit attributes and decoration attributes, the args list // has no values so we do nothing. 
if (isa<UnitAttr, DecorationAttr>(attr)) diff --git a/mlir/test/Conversion/GPUToSPIRV/printf.mlir b/mlir/test/Conversion/GPUToSPIRV/printf.mlir new file mode 100644 index 000000000000000..bc091124ea4c6fc --- /dev/null +++ b/mlir/test/Conversion/GPUToSPIRV/printf.mlir @@ -0,0 +1,71 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s | FileCheck %s + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Addresses], []>, #spirv.resource_limits<>> +} { + func.func @main() { + %c1 = arith.constant 1 : index + + gpu.launch_func @kernels::@printf + blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) + args() + return + } + + gpu.module @kernels { + // CHECK: spirv.module @{{.*}} Physical32 OpenCL + // CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8 + // CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8> + // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> + gpu.func @printf() kernel + attributes + {spirv.entry_point_abi = #spirv.entry_point_abi<>} { + // CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> + // CHECK-NEXT: [[FMTSTR_PTR:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant> + // CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR]] : !spirv.ptr<i8, UniformConstant> -> i32 + gpu.printf "\nHello\n" + // CHECK: spirv.Return + gpu.return + } + } +} + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Addresses], []>, #spirv.resource_limits<>> +} { + func.func @main() { + %c1 = arith.constant 1 : index + %c100 = arith.constant 100: i32 + %cst_f32 = arith.constant 314.4: f32 + + gpu.launch_func @kernels1::@printf_args + blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) + args(%c100: i32, %cst_f32: f32) + return + } + + gpu.module @kernels1 { + // CHECK: 
spirv.module @{{.*}} Physical32 OpenCL { + // CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8 + // CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8> + // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> + gpu.func @printf_args(%arg0: i32, %arg1: f32) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %0 = gpu.block_id x + %1 = gpu.block_id y + %2 = gpu.thread_id x + + // CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> + // CHECK-NEXT: [[FMTSTR_PTR1:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant> + // CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]] {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i8, UniformConstant>, i32, f32, i32 -> i32 + gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index + + // CHECK: spirv.Return + gpu.return + } + } +} diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir index 81ba471d3f51e3d..8f021ed3d663d34 100644 --- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir @@ -274,9 +274,9 @@ func.func @rintvec(%arg0 : vector<3xf16>) -> () { // spirv.CL.printf //===----------------------------------------------------------------------===// // CHECK-LABEL: func.func @printf( -func.func @printf(%arg0 : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 { - // CHECK: spirv.CL.printf {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32 - %0 = spirv.CL.printf %arg0, %arg1, %arg2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32 +func.func @printf(%fmt : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 { + // CHECK: spirv.CL.printf {{%.*}} {{%.*}}, {{%.*}} : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32 + %0 = spirv.CL.printf %fmt %arg1, %arg2 : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32 return %0 : i32 }