From bd9bceaa8fdb9dfe0c8c21376f984333aec6dc36 Mon Sep 17 00:00:00 2001 From: "Levytskyy, Vyacheslav" Date: Fri, 11 Oct 2024 09:11:07 -0700 Subject: [PATCH] update type inference for function pointers and update test cases --- llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 107 ++++++++++++++---- .../fp-simple-hierarchy.ll | 49 ++++++-- .../SPV_INTEL_function_pointers/fp_const.ll | 37 +++--- .../CodeGen/SPIRV/instructions/select-phi.ll | 8 +- 4 files changed, 149 insertions(+), 52 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 4ac06cc19f03dc..8b7e9c48de6c75 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -69,6 +69,7 @@ class SPIRVEmitIntrinsics SPIRVGlobalRegistry *GR = nullptr; Function *F = nullptr; bool TrackConstants = true; + bool HaveFunPtrs = false; DenseMap AggrConsts; DenseMap AggrConstTypes; DenseSet AggrStores; @@ -714,6 +715,37 @@ static bool deduceOperandElementTypeCalledFunction( return true; } +// Try to deduce element type for a function pointer. +static void deduceOperandElementTypeFunctionPointer( + SPIRVGlobalRegistry *GR, Instruction *I, CallInst *CI, + SmallVector> &Ops, Type *&KnownElemTy) { + Value *Op = CI->getCalledOperand(); + if (!Op || !isPointerTy(Op->getType())) + return; + Ops.push_back(std::make_pair(Op, std::numeric_limits::max())); + FunctionType *FTy = CI->getFunctionType(); + bool IsNewFTy = false; + SmallVector ArgTys; + for (Value *Arg : CI->args()) { + Type *ArgTy = Arg->getType(); + if (ArgTy->isPointerTy()) + if (Type *ElemTy = GR->findDeducedElementType(Arg)) { + IsNewFTy = true; + ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy)); + } + ArgTys.push_back(ArgTy); + } + Type *RetTy = FTy->getReturnType(); + if (I->getType()->isPointerTy()) + if (Type *ElemTy = GR->findDeducedElementType(I)) { + IsNewFTy = true; + RetTy = + TypedPointerType::get(ElemTy, getPointerAddressSpace(I->getType())); + } + KnownElemTy = + IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy; +} + // If the Instruction has Pointer operands with unresolved types, this function // tries to deduce them. If the Instruction has Pointer operands with known // types which differ from expected, this function tries to insert a bitcast to @@ -820,17 +852,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I, Ops.push_back(std::make_pair(Op0, 0)); } } else if (CallInst *CI = dyn_cast(I)) { - if (!CI->isIndirectCall()) { + if (!CI->isIndirectCall()) deduceOperandElementTypeCalledFunction(GR, I, InstrSet, CI, Ops, KnownElemTy); - } else if (TM->getSubtarget(*F).canUseExtension( - SPIRV::Extension::SPV_INTEL_function_pointers)) { - Value *Op = CI->getCalledOperand(); - if (!Op || !isPointerTy(Op->getType())) - return; - Ops.push_back(std::make_pair(Op, std::numeric_limits::max())); - KnownElemTy = CI->getFunctionType(); - } + else if (HaveFunPtrs) + deduceOperandElementTypeFunctionPointer(GR, I, CI, Ops, KnownElemTy); } // There is no enough info to deduce types or all is valid. @@ -1710,23 +1736,53 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) { } } +static FunctionType *getFunctionPointerElemType(Function *F, + SPIRVGlobalRegistry *GR) { + FunctionType *FTy = F->getFunctionType(); + bool IsNewFTy = false; + SmallVector ArgTys; + for (Argument &Arg : F->args()) { + Type *ArgTy = Arg.getType(); + if (ArgTy->isPointerTy()) + if (Type *ElemTy = GR->findDeducedElementType(&Arg)) { + IsNewFTy = true; + ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy)); + } + ArgTys.push_back(ArgTy); + } + return IsNewFTy + ? FunctionType::get(FTy->getReturnType(), ArgTys, FTy->isVarArg()) + : FTy; +} + bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) { - bool IsExt = false; SmallVector Worklist; for (auto &F : M) { - if (!IsExt) { - if (!TM->getSubtarget(F).canUseExtension( - SPIRV::Extension::SPV_INTEL_function_pointers)) - return false; - IsExt = true; - } - if (!F.isDeclaration() || F.isIntrinsic()) + if (F.isIntrinsic()) continue; - for (User *U : F.users()) { - CallInst *CI = dyn_cast(U); - if (!CI || CI->getCalledFunction() != &F) { - Worklist.push_back(&F); - break; + if (F.isDeclaration()) { + for (User *U : F.users()) { + CallInst *CI = dyn_cast(U); + if (!CI || CI->getCalledFunction() != &F) { + Worklist.push_back(&F); + break; + } + } + } else { + if (F.user_empty()) + continue; + Type *FPElemTy = GR->findDeducedElementType(&F); + if (!FPElemTy) + FPElemTy = getFunctionPointerElemType(&F, GR); + for (User *U : F.users()) { + IntrinsicInst *II = dyn_cast(U); + if (!II || II->arg_size() != 3 || II->getOperand(0) != &F) + continue; + if (II->getIntrinsicID() == Intrinsic::spv_assign_ptr_type || + II->getIntrinsicID() == Intrinsic::spv_ptrcast) { + updateAssignType(II, &F, PoisonValue::get(FPElemTy)); + break; + } } } } @@ -1765,6 +1821,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { InstrSet = ST.isOpenCLEnv() ? SPIRV::InstructionSet::OpenCL_std : SPIRV::InstructionSet::GLSL_std_450; + if (!F) + HaveFunPtrs = + ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers); + F = &Func; IRBuilder<> B(Func.getContext()); AggrConsts.clear(); @@ -1910,7 +1970,8 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) { } Changed |= postprocessTypes(); - Changed |= processFunctionPointers(M); + if (HaveFunPtrs) + Changed |= processFunctionPointers(M); return Changed; } diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll index 0178e1192d7ea7..d5a8fb3e7baafa 100644 --- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp-simple-hierarchy.ll @@ -1,22 +1,47 @@ -; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s ; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} -; CHECK: OpFunction +; CHECK-DAG: OpName %[[I9:.*]] "_ZN13BaseIncrement9incrementEPi" +; CHECK-DAG: OpName %[[I29:.*]] "_ZN12IncrementBy29incrementEPi" +; CHECK-DAG: OpName %[[I49:.*]] "_ZN12IncrementBy49incrementEPi" +; CHECK-DAG: OpName %[[I89:.*]] "_ZN12IncrementBy89incrementEPi" -%classid = type { %arrayid } -%arrayid = type { [1 x i64] } -%struct.obj_storage_t = type { %storage } -%storage = type { [8 x i8] } +; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid +; CHECK-DAG: %[[TyArr:.*]] = OpTypeArray +; CHECK-DAG: %[[TyStruct1:.*]] = OpTypeStruct %[[TyArr]] +; CHECK-DAG: %[[TyStruct2:.*]] = OpTypeStruct %[[TyStruct1]] +; CHECK-DAG: %[[TyPtrStruct2:.*]] = OpTypePointer Generic %[[TyStruct2]] +; CHECK-DAG: %[[TyFun:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrStruct2]] %[[#]] +; CHECK-DAG: %[[TyPtrFun:.*]] = OpTypePointer Generic %[[TyFun]] +; CHECK-DAG: %[[TyPtrPtrFun:.*]] = OpTypePointer Generic %[[TyPtrFun]] + +; CHECK: %[[I9]] = OpFunction +; CHECK: %[[I29]] = OpFunction +; CHECK: %[[I49]] = OpFunction +; CHECK: %[[I89]] = OpFunction + +; CHECK: %[[Arg1:.*]] = OpPhi %[[TyPtrStruct2]] +; CHECK: %[[VTbl:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[#]] +; CHECK: %[[FP:.*]] = OpLoad %[[TyPtrFun]] %[[VTbl]] +; CHECK: %[[#]] = OpFunctionPointerCallINTEL %[[TyVoid]] %[[FP]] %[[Arg1]] %[[#]] + +%"cls::id" = type { %"cls::detail::array" } +%"cls::detail::array" = type { [1 x i64] } +%struct.obj_storage_t = type { %"struct.aligned_storage::type" } +%"struct.aligned_storage::type" = type { [8 x i8] } @_ZTV12IncrementBy8 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy89incrementEPi to ptr addrspace(4))] }, align 8 @_ZTV13BaseIncrement = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN13BaseIncrement9incrementEPi to ptr addrspace(4))] }, align 8 @_ZTV12IncrementBy4 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy49incrementEPi to ptr addrspace(4))] }, align 8 @_ZTV12IncrementBy2 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy29incrementEPi to ptr addrspace(4))] }, align 8 +@__spirv_BuiltInWorkgroupId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 +@__spirv_BuiltInGlobalLinearId = external dso_local local_unnamed_addr addrspace(1) constant i64, align 8 +@__spirv_BuiltInWorkgroupSize = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32 -define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef align 8 %_arg_StorageAcc, ptr noundef byval(%classid) align 8 %_arg_StorageAcc3, i32 noundef %_arg_TestCase, ptr addrspace(1) noundef align 4 %_arg_DataAcc) { +define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef align 8 %_arg_StorageAcc, ptr noundef byval(%"cls::id") align 8 %_arg_StorageAcc3, i32 noundef %_arg_TestCase, ptr addrspace(1) noundef align 4 %_arg_DataAcc) { entry: - %0 = load i64, ptr %_arg_StorageAcc3, align 8 - %add.ptr.i = getelementptr inbounds %struct.obj_storage_t, ptr addrspace(1) %_arg_StorageAcc, i64 %0 + %r0 = load i64, ptr %_arg_StorageAcc3, align 8 + %add.ptr.i = getelementptr inbounds %struct.obj_storage_t, ptr addrspace(1) %_arg_StorageAcc, i64 %r0 %arrayidx.ascast.i = addrspacecast ptr addrspace(1) %add.ptr.i to ptr addrspace(4) %cmp.i = icmp ugt i32 %_arg_TestCase, 3 br i1 %cmp.i, label %entry.critedge, label %if.end.1 @@ -51,9 +76,9 @@ if.end.2: ; preds = %if.end.1 exit: ; preds = %if.end.2, %if.end.3, %if.end.4, %if.end.5, %entry.critedge %vtable.i = phi ptr addrspace(4) [ %vtable.i.pre, %entry.critedge ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy8, i64 16) to i64) to ptr addrspace(4)), %if.end.5 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy4, i64 16) to i64) to ptr addrspace(4)), %if.end.4 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy2, i64 16) to i64) to ptr addrspace(4)), %if.end.3 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV13BaseIncrement, i64 16) to i64) to ptr addrspace(4)), %if.end.2 ] %retval.0.i = phi ptr addrspace(4) [ null, %entry.critedge ], [ %arrayidx.ascast.i, %if.end.5 ], [ %arrayidx.ascast.i, %if.end.4 ], [ %arrayidx.ascast.i, %if.end.3 ], [ %arrayidx.ascast.i, %if.end.2 ] - %1 = addrspacecast ptr addrspace(1) %_arg_DataAcc to ptr addrspace(4) - %2 = load ptr addrspace(4), ptr addrspace(4) %vtable.i, align 8 - tail call spir_func addrspace(4) void %2(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %retval.0.i, ptr addrspace(4) noundef %1) + %r1 = addrspacecast ptr addrspace(1) %_arg_DataAcc to ptr addrspace(4) + %r2 = load ptr addrspace(4), ptr addrspace(4) %vtable.i, align 8 + tail call spir_func addrspace(4) void %r2(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %retval.0.i, ptr addrspace(4) noundef %r1) ret void } diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll index 5f073e95cb68f2..b4faba9a4eb8e3 100644 --- a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_const.ll @@ -5,30 +5,39 @@ ; CHECK-DAG: OpCapability FunctionPointersINTEL ; CHECK-DAG: OpCapability Int64 ; CHECK: OpExtension "SPV_INTEL_function_pointers" -; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0 + ; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid ; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0 -; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyVoid]] %[[TyInt64]] -; CHECK-DAG: %[[ConstInt64:.*]] = OpConstant %[[TyInt64]] 42 -; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]] -; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFunFp]] %[[DefFunFp:.*]] -; CHECK: %[[FunPtr1:.*]] = OpBitcast %[[#]] %[[ConstFunFp]] -; CHECK: %[[FunPtr2:.*]] = OpLoad %[[#]] %[[FunPtr1]] -; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FunPtr2]] %[[ConstInt64]] -; CHECK: OpReturn +; CHECK-DAG: %[[TyFun:.*]] = OpTypeFunction %[[TyInt64]] %[[TyInt64]] +; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0 +; CHECK-DAG: %[[TyPtrFun:.*]] = OpTypePointer Function %[[TyFun]] +; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFun]] %[[DefFunFp:.*]] +; CHECK-DAG: %[[TyPtrPtrFun:.*]] = OpTypePointer Function %[[TyPtrFun]] +; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]] +; CHECK-DAG: %[[TyPtrPtrInt8:.*]] = OpTypePointer Function %[[TyPtrInt8]] +; CHECK: OpFunction +; CHECK: %[[Var:.*]] = OpVariable %[[TyPtrPtrInt8]] Function +; CHECK: %[[SAddr:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[Var]] +; CHECK: OpStore %[[SAddr]] %[[ConstFunFp]] +; CHECK: %[[LAddr:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[Var]] +; CHECK: %[[FP:.*]] = OpLoad %[[TyPtrFun]] %[[LAddr]] +; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FP]] %[[#]] ; CHECK: OpFunctionEnd -; CHECK: %[[DefFunFp]] = OpFunction %[[TyVoid]] None %[[TyFunFp]] + +; CHECK: %[[DefFunFp]] = OpFunction %[[TyInt64]] None %[[TyFun]] target triple = "spir64-unknown-unknown" define spir_kernel void @test() { entry: - %0 = load ptr, ptr @foo - %1 = call i64 %0(i64 42) + %fp = alloca ptr + store ptr @foo, ptr %fp + %tocall = load ptr, ptr %fp + %res = call i64 %tocall(i64 42) ret void } -define void @foo(i64 %a) { +define i64 @foo(i64 %a) { entry: - ret void + ret i64 %a } diff --git a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll index 3828fe89e60aec..16be7cd3b8db62 100644 --- a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll +++ b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll @@ -1,3 +1,6 @@ +; This test case checks how phi-nodes with different operand types select +; a result type. Majority of operands makes it i8* in this case. + ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s @@ -15,14 +18,13 @@ ; CHECK: %[[Branch1:.*]] = OpLabel ; CHECK: %[[Res1:.*]] = OpVariable %[[StructPtr]] Function +; CHECK: %[[Res1Casted:.*]] = OpBitcast %[[CharPtr]] %[[Res1]] ; CHECK: OpBranchConditional %[[#]] %[[#]] %[[Branch2:.*]] ; CHECK: %[[Res2:.*]] = OpInBoundsPtrAccessChain %[[CharPtr]] %[[#]] %[[#]] -; CHECK: %[[Res2Casted:.*]] = OpBitcast %[[StructPtr]] %[[Res2]] ; CHECK: OpBranchConditional %[[#]] %[[#]] %[[BranchSelect:.*]] ; CHECK: %[[SelectRes:.*]] = OpSelect %[[CharPtr]] %[[#]] %[[#]] %[[#]] -; CHECK: %[[SelectResCasted:.*]] = OpBitcast %[[StructPtr]] %[[SelectRes]] ; CHECK: OpLabel -; CHECK: OpPhi %[[StructPtr]] %[[Res1]] %[[Branch1]] %[[Res2Casted]] %[[Branch2]] %[[SelectResCasted]] %[[BranchSelect]] +; CHECK: OpPhi %[[CharPtr]] %[[Res1Casted]] %[[Branch1]] %[[Res2]] %[[Branch2]] %[[SelectRes]] %[[BranchSelect]] %struct = type { %array } %array = type { [1 x i64] }