Skip to content

Commit

Permalink
[mlir][spirv] Add gpu printf op lowering to spirv.CL.printf op (#78510)
Browse files Browse the repository at this point in the history
This change contains the following:
	- Adds lowering of the gpu.printf op to the spirv.CL.printf op in the GPUToSPIRV pass.
	- Fixes Constant decoration parsing for spirv GlobalVariable.
	- Makes a minor modification to the spirv.CL.printf op assembly format.

---------

Co-authored-by: Jakub Kuderski <[email protected]>
  • Loading branch information
drprajap and kuhar authored Sep 30, 2024
1 parent 4dfed69 commit f8ba021
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 6 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
#### Example:

```mlir
%0 = spirv.CL.printf %0 %1 %2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
%0 = spirv.CL.printf %fmt %1, %2 : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
```
}];

Expand All @@ -889,7 +889,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
);

let assemblyFormat = [{
$format `,` $arguments attr-dict `:` `(` type($format) `,` `(` type($arguments) `)` `)` `->` type($result)
$format ( $arguments^ )? attr-dict `:` type($format) ( `,` type($arguments)^ )? `->` type($result)
}];

let hasVerifier = 0;
Expand Down
130 changes: 129 additions & 1 deletion mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) const override;
};

/// Conversion pattern matching gpu::PrintfOp. The rewrite itself is defined
/// out of line in GPUPrintfConversion::matchAndRewrite, which emits a
/// spirv::CLPrintfOp in place of the gpu.printf op.
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
// Reuse the base-class constructors; this pattern adds no state of its own.
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -597,6 +606,124 @@ class GPUSubgroupReduceConversion final
}
};

// Formulate a unique variable/constant name after
// searching in the module for existing variable/constant names.
// This is to avoid name collision with existing variables.
// Example: printfMsg0, printfMsg1, printfMsg2, ...
static std::string makeVarName(spirv::ModuleOp moduleOp, llvm::Twine prefix) {
std::string name;
unsigned number = 0;

do {
name.clear();
name = (prefix + llvm::Twine(number++)).str();
} while (moduleOp.lookupSymbol(name));

return name;
}

/// Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
///
/// The format string (plus a trailing NUL) is materialized at the start of the
/// enclosing spirv.module as:
///   * one i8 spirv.SpecConstant per character,
///   * one spirv.SpecConstantComposite (array of i8) grouping them, and
///   * one spirv.GlobalVariable in the UniformConstant storage class that is
///     initialized with the composite and tagged with the `Constant` attribute.
/// At the original op's position, the global's address is bitcast to a pointer
/// to i8 and passed, together with the converted printf arguments, to
/// spirv.CL.printf. The gpu.printf op is then erased.
///
/// Fails to match when the op is not nested inside a spirv.module.
LogicalResult GPUPrintfConversion::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp.getLoc();

  auto moduleOp = gpuPrintfOp->getParentOfType<spirv::ModuleOp>();
  if (!moduleOp)
    return rewriter.notifyMatchFailure(
        gpuPrintfOp, "expected to be nested inside a spirv.module");

  // SPIR-V global variable is used to initialize printf
  // format string value, if there are multiple printf messages,
  // each global var needs to be created with a unique name.
  std::string globalVarName = makeVarName(moduleOp, llvm::Twine("printfMsg"));
  spirv::GlobalVariableOp globalVar;

  IntegerType i8Type = rewriter.getI8Type();
  IntegerType i32Type = rewriter.getI32Type();

  // Each character of printf format string is
  // stored as a spec constant. We need to create
  // unique name for this spec constant like
  // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
  // for existing spec constant names.
  auto createSpecConstant = [&](unsigned value) {
    auto attr = rewriter.getI8IntegerAttr(value);
    std::string specCstName =
        makeVarName(moduleOp, llvm::Twine(globalVarName) + "_sc");

    return rewriter.create<spirv::SpecConstantOp>(
        loc, rewriter.getStringAttr(specCstName), attr);
  };

  {
    // Module-level ops are created at the start of the same spirv.module whose
    // symbols were consulted for name uniqueness. The guard restores the
    // insertion point (at the gpu.printf op) when this scope ends.
    ConversionPatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(moduleOp.getBody());

    // Create Constituents with SpecConstant by scanning format string
    // Each character of format string is stored as a spec constant
    // and then these spec constants are used to create a
    // SpecConstantCompositeOp.
    llvm::SmallString<20> formatString(adaptor.getFormat());
    formatString.push_back('\0'); // Null terminate for C.
    SmallVector<Attribute, 4> constituents;
    constituents.reserve(formatString.size());
    for (char c : formatString) {
      spirv::SpecConstantOp cSpecConstantOp = createSpecConstant(c);
      constituents.push_back(SymbolRefAttr::get(cSpecConstantOp));
    }

    // Create SpecConstantCompositeOp to initialize the global variable.
    size_t contentSize = constituents.size();
    auto globalType = spirv::ArrayType::get(i8Type, contentSize);
    // There will be one SpecConstantCompositeOp per printf message/global var,
    // so no need do lookup for existing ones.
    std::string specCstCompositeName =
        (llvm::Twine(globalVarName) + "_scc").str();

    auto specCstComposite = rewriter.create<spirv::SpecConstantCompositeOp>(
        loc, TypeAttr::get(globalType),
        rewriter.getStringAttr(specCstCompositeName),
        rewriter.getArrayAttr(constituents));

    auto ptrType = spirv::PointerType::get(
        globalType, spirv::StorageClass::UniformConstant);

    // Define a GlobalVarOp initialized using specialized constants
    // that is used to specify the printf format string
    // to be passed to the SPIRV CLPrintfOp.
    globalVar = rewriter.create<spirv::GlobalVariableOp>(
        loc, ptrType, globalVarName, FlatSymbolRefAttr::get(specCstComposite));

    globalVar->setAttr("Constant", rewriter.getUnitAttr());
  }

  // Get SSA value of Global variable and create pointer to i8 to point to
  // the format string.
  Value globalPtr = rewriter.create<spirv::AddressOfOp>(loc, globalVar);
  Value fmtStr = rewriter.create<spirv::BitcastOp>(
      loc,
      spirv::PointerType::get(i8Type, spirv::StorageClass::UniformConstant),
      globalPtr);

  // Get printf arguments.
  auto printfArgs = llvm::to_vector_of<Value, 4>(adaptor.getArgs());

  rewriter.create<spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);

  // Need to erase the gpu.printf op as gpu.printf does not use result vs
  // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
  // printf op.
  rewriter.eraseOp(gpuPrintfOp);

  return success();
}

//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//
Expand All @@ -620,5 +747,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
spirv::BuiltIn::SubgroupSize>,
WorkGroupSizeConversion, GPUAllReduceConversion,
GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
patterns.getContext());
}
1 change: 1 addition & 0 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::Restrict:
case spirv::Decoration::RestrictPointer:
case spirv::Decoration::NoContraction:
case spirv::Decoration::Constant:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::Restrict:
case spirv::Decoration::RestrictPointer:
case spirv::Decoration::NoContraction:
case spirv::Decoration::Constant:
// For unit attributes and decoration attributes, the args list
// has no values so we do nothing.
if (isa<UnitAttr, DecorationAttr>(attr))
Expand Down
71 changes: 71 additions & 0 deletions mlir/test/Conversion/GPUToSPIRV/printf.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s | FileCheck %s

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
} {
func.func @main() {
%c1 = arith.constant 1 : index

gpu.launch_func @kernels::@printf
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
args()
return
}

gpu.module @kernels {
// CHECK: spirv.module @{{.*}} Physical32 OpenCL
// CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
// CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
// CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
gpu.func @printf() kernel
attributes
{spirv.entry_point_abi = #spirv.entry_point_abi<>} {
// CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
// CHECK-NEXT: [[FMTSTR_PTR:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
// CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR]] : !spirv.ptr<i8, UniformConstant> -> i32
gpu.printf "\nHello\n"
// CHECK: spirv.Return
gpu.return
}
}
}

// -----

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
} {
func.func @main() {
%c1 = arith.constant 1 : index
%c100 = arith.constant 100: i32
%cst_f32 = arith.constant 314.4: f32

gpu.launch_func @kernels1::@printf_args
blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
args(%c100: i32, %cst_f32: f32)
return
}

gpu.module @kernels1 {
// CHECK: spirv.module @{{.*}} Physical32 OpenCL {
// CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
// CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
// CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
gpu.func @printf_args(%arg0: i32, %arg1: f32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%0 = gpu.block_id x
%1 = gpu.block_id y
%2 = gpu.thread_id x

// CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
// CHECK-NEXT: [[FMTSTR_PTR1:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
// CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]] {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<i8, UniformConstant>, i32, f32, i32 -> i32
gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index

// CHECK: spirv.Return
gpu.return
}
}
}
6 changes: 3 additions & 3 deletions mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,9 @@ func.func @rintvec(%arg0 : vector<3xf16>) -> () {
// spirv.CL.printf
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @printf(
func.func @printf(%arg0 : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
// CHECK: spirv.CL.printf {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
%0 = spirv.CL.printf %arg0, %arg1, %arg2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
func.func @printf(%fmt : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
// CHECK: spirv.CL.printf {{%.*}} {{%.*}}, {{%.*}} : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
%0 = spirv.CL.printf %fmt %arg1, %arg2 : !spirv.ptr<i8, UniformConstant>, i32, i32 -> i32
return %0 : i32
}

Expand Down

0 comments on commit f8ba021

Please sign in to comment.