Skip to content

Commit

Permalink
update type inference for function pointers and update test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
VyacheslavLevytskyy committed Oct 11, 2024
1 parent 7f79653 commit bd9bcea
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 52 deletions.
107 changes: 84 additions & 23 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class SPIRVEmitIntrinsics
SPIRVGlobalRegistry *GR = nullptr;
Function *F = nullptr;
bool TrackConstants = true;
bool HaveFunPtrs = false;
DenseMap<Instruction *, Constant *> AggrConsts;
DenseMap<Instruction *, Type *> AggrConstTypes;
DenseSet<Instruction *> AggrStores;
Expand Down Expand Up @@ -714,6 +715,37 @@ static bool deduceOperandElementTypeCalledFunction(
return true;
}

// Try to deduce element type for a function pointer.
static void deduceOperandElementTypeFunctionPointer(
SPIRVGlobalRegistry *GR, Instruction *I, CallInst *CI,
SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
Value *Op = CI->getCalledOperand();
if (!Op || !isPointerTy(Op->getType()))
return;
Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
FunctionType *FTy = CI->getFunctionType();
bool IsNewFTy = false;
SmallVector<Type *, 4> ArgTys;
for (Value *Arg : CI->args()) {
Type *ArgTy = Arg->getType();
if (ArgTy->isPointerTy())
if (Type *ElemTy = GR->findDeducedElementType(Arg)) {
IsNewFTy = true;
ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy));
}
ArgTys.push_back(ArgTy);
}
Type *RetTy = FTy->getReturnType();
if (I->getType()->isPointerTy())
if (Type *ElemTy = GR->findDeducedElementType(I)) {
IsNewFTy = true;
RetTy =
TypedPointerType::get(ElemTy, getPointerAddressSpace(I->getType()));
}
KnownElemTy =
IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy;
}

// If the Instruction has Pointer operands with unresolved types, this function
// tries to deduce them. If the Instruction has Pointer operands with known
// types which differ from expected, this function tries to insert a bitcast to
Expand Down Expand Up @@ -820,17 +852,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
Ops.push_back(std::make_pair(Op0, 0));
}
} else if (CallInst *CI = dyn_cast<CallInst>(I)) {
if (!CI->isIndirectCall()) {
if (!CI->isIndirectCall())
deduceOperandElementTypeCalledFunction(GR, I, InstrSet, CI, Ops,
KnownElemTy);
} else if (TM->getSubtarget<SPIRVSubtarget>(*F).canUseExtension(
SPIRV::Extension::SPV_INTEL_function_pointers)) {
Value *Op = CI->getCalledOperand();
if (!Op || !isPointerTy(Op->getType()))
return;
Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
KnownElemTy = CI->getFunctionType();
}
else if (HaveFunPtrs)
deduceOperandElementTypeFunctionPointer(GR, I, CI, Ops, KnownElemTy);
}

// There is no enough info to deduce types or all is valid.
Expand Down Expand Up @@ -1710,23 +1736,53 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
}
}

static FunctionType *getFunctionPointerElemType(Function *F,
SPIRVGlobalRegistry *GR) {
FunctionType *FTy = F->getFunctionType();
bool IsNewFTy = false;
SmallVector<Type *, 4> ArgTys;
for (Argument &Arg : F->args()) {
Type *ArgTy = Arg.getType();
if (ArgTy->isPointerTy())
if (Type *ElemTy = GR->findDeducedElementType(&Arg)) {
IsNewFTy = true;
ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy));
}
ArgTys.push_back(ArgTy);
}
return IsNewFTy
? FunctionType::get(FTy->getReturnType(), ArgTys, FTy->isVarArg())
: FTy;
}

bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
bool IsExt = false;
SmallVector<Function *> Worklist;
for (auto &F : M) {
if (!IsExt) {
if (!TM->getSubtarget<SPIRVSubtarget>(F).canUseExtension(
SPIRV::Extension::SPV_INTEL_function_pointers))
return false;
IsExt = true;
}
if (!F.isDeclaration() || F.isIntrinsic())
if (F.isIntrinsic())
continue;
for (User *U : F.users()) {
CallInst *CI = dyn_cast<CallInst>(U);
if (!CI || CI->getCalledFunction() != &F) {
Worklist.push_back(&F);
break;
if (F.isDeclaration()) {
for (User *U : F.users()) {
CallInst *CI = dyn_cast<CallInst>(U);
if (!CI || CI->getCalledFunction() != &F) {
Worklist.push_back(&F);
break;
}
}
} else {
if (F.user_empty())
continue;
Type *FPElemTy = GR->findDeducedElementType(&F);
if (!FPElemTy)
FPElemTy = getFunctionPointerElemType(&F, GR);
for (User *U : F.users()) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
if (!II || II->arg_size() != 3 || II->getOperand(0) != &F)
continue;
if (II->getIntrinsicID() == Intrinsic::spv_assign_ptr_type ||
II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
updateAssignType(II, &F, PoisonValue::get(FPElemTy));
break;
}
}
}
}
Expand Down Expand Up @@ -1765,6 +1821,10 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
InstrSet = ST.isOpenCLEnv() ? SPIRV::InstructionSet::OpenCL_std
: SPIRV::InstructionSet::GLSL_std_450;

if (!F)
HaveFunPtrs =
ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);

F = &Func;
IRBuilder<> B(Func.getContext());
AggrConsts.clear();
Expand Down Expand Up @@ -1910,7 +1970,8 @@ bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
}

Changed |= postprocessTypes();
Changed |= processFunctionPointers(M);
if (HaveFunPtrs)
Changed |= processFunctionPointers(M);

return Changed;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,47 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_function_pointers %s -o - | FileCheck %s
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK: OpFunction
; CHECK-DAG: OpName %[[I9:.*]] "_ZN13BaseIncrement9incrementEPi"
; CHECK-DAG: OpName %[[I29:.*]] "_ZN12IncrementBy29incrementEPi"
; CHECK-DAG: OpName %[[I49:.*]] "_ZN12IncrementBy49incrementEPi"
; CHECK-DAG: OpName %[[I89:.*]] "_ZN12IncrementBy89incrementEPi"

%classid = type { %arrayid }
%arrayid = type { [1 x i64] }
%struct.obj_storage_t = type { %storage }
%storage = type { [8 x i8] }
; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
; CHECK-DAG: %[[TyArr:.*]] = OpTypeArray
; CHECK-DAG: %[[TyStruct1:.*]] = OpTypeStruct %[[TyArr]]
; CHECK-DAG: %[[TyStruct2:.*]] = OpTypeStruct %[[TyStruct1]]
; CHECK-DAG: %[[TyPtrStruct2:.*]] = OpTypePointer Generic %[[TyStruct2]]
; CHECK-DAG: %[[TyFun:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrStruct2]] %[[#]]
; CHECK-DAG: %[[TyPtrFun:.*]] = OpTypePointer Generic %[[TyFun]]
; CHECK-DAG: %[[TyPtrPtrFun:.*]] = OpTypePointer Generic %[[TyPtrFun]]

; CHECK: %[[I9]] = OpFunction
; CHECK: %[[I29]] = OpFunction
; CHECK: %[[I49]] = OpFunction
; CHECK: %[[I89]] = OpFunction

; CHECK: %[[Arg1:.*]] = OpPhi %[[TyPtrStruct2]]
; CHECK: %[[VTbl:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[#]]
; CHECK: %[[FP:.*]] = OpLoad %[[TyPtrFun]] %[[VTbl]]
; CHECK: %[[#]] = OpFunctionPointerCallINTEL %[[TyVoid]] %[[FP]] %[[Arg1]] %[[#]]

%"cls::id" = type { %"cls::detail::array" }
%"cls::detail::array" = type { [1 x i64] }
%struct.obj_storage_t = type { %"struct.aligned_storage<BaseIncrement, IncrementBy2, IncrementBy4, IncrementBy8>::type" }
%"struct.aligned_storage<BaseIncrement, IncrementBy2, IncrementBy4, IncrementBy8>::type" = type { [8 x i8] }

@_ZTV12IncrementBy8 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy89incrementEPi to ptr addrspace(4))] }, align 8
@_ZTV13BaseIncrement = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN13BaseIncrement9incrementEPi to ptr addrspace(4))] }, align 8
@_ZTV12IncrementBy4 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy49incrementEPi to ptr addrspace(4))] }, align 8
@_ZTV12IncrementBy2 = linkonce_odr dso_local unnamed_addr addrspace(1) constant { [3 x ptr addrspace(4)] } { [3 x ptr addrspace(4)] [ptr addrspace(4) null, ptr addrspace(4) null, ptr addrspace(4) addrspacecast (ptr @_ZN12IncrementBy29incrementEPi to ptr addrspace(4))] }, align 8
@__spirv_BuiltInWorkgroupId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
@__spirv_BuiltInGlobalLinearId = external dso_local local_unnamed_addr addrspace(1) constant i64, align 8
@__spirv_BuiltInWorkgroupSize = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef align 8 %_arg_StorageAcc, ptr noundef byval(%classid) align 8 %_arg_StorageAcc3, i32 noundef %_arg_TestCase, ptr addrspace(1) noundef align 4 %_arg_DataAcc) {
define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef align 8 %_arg_StorageAcc, ptr noundef byval(%"cls::id") align 8 %_arg_StorageAcc3, i32 noundef %_arg_TestCase, ptr addrspace(1) noundef align 4 %_arg_DataAcc) {
entry:
%0 = load i64, ptr %_arg_StorageAcc3, align 8
%add.ptr.i = getelementptr inbounds %struct.obj_storage_t, ptr addrspace(1) %_arg_StorageAcc, i64 %0
%r0 = load i64, ptr %_arg_StorageAcc3, align 8
%add.ptr.i = getelementptr inbounds %struct.obj_storage_t, ptr addrspace(1) %_arg_StorageAcc, i64 %r0
%arrayidx.ascast.i = addrspacecast ptr addrspace(1) %add.ptr.i to ptr addrspace(4)
%cmp.i = icmp ugt i32 %_arg_TestCase, 3
br i1 %cmp.i, label %entry.critedge, label %if.end.1
Expand Down Expand Up @@ -51,9 +76,9 @@ if.end.2: ; preds = %if.end.1
exit: ; preds = %if.end.2, %if.end.3, %if.end.4, %if.end.5, %entry.critedge
%vtable.i = phi ptr addrspace(4) [ %vtable.i.pre, %entry.critedge ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy8, i64 16) to i64) to ptr addrspace(4)), %if.end.5 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy4, i64 16) to i64) to ptr addrspace(4)), %if.end.4 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV12IncrementBy2, i64 16) to i64) to ptr addrspace(4)), %if.end.3 ], [ inttoptr (i64 ptrtoint (ptr addrspace(1) getelementptr inbounds inrange(-16, 8) (i8, ptr addrspace(1) @_ZTV13BaseIncrement, i64 16) to i64) to ptr addrspace(4)), %if.end.2 ]
%retval.0.i = phi ptr addrspace(4) [ null, %entry.critedge ], [ %arrayidx.ascast.i, %if.end.5 ], [ %arrayidx.ascast.i, %if.end.4 ], [ %arrayidx.ascast.i, %if.end.3 ], [ %arrayidx.ascast.i, %if.end.2 ]
%1 = addrspacecast ptr addrspace(1) %_arg_DataAcc to ptr addrspace(4)
%2 = load ptr addrspace(4), ptr addrspace(4) %vtable.i, align 8
tail call spir_func addrspace(4) void %2(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %retval.0.i, ptr addrspace(4) noundef %1)
%r1 = addrspacecast ptr addrspace(1) %_arg_DataAcc to ptr addrspace(4)
%r2 = load ptr addrspace(4), ptr addrspace(4) %vtable.i, align 8
tail call spir_func addrspace(4) void %r2(ptr addrspace(4) noundef align 8 dereferenceable_or_null(8) %retval.0.i, ptr addrspace(4) noundef %r1)
ret void
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,39 @@
; CHECK-DAG: OpCapability FunctionPointersINTEL
; CHECK-DAG: OpCapability Int64
; CHECK: OpExtension "SPV_INTEL_function_pointers"
; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0

; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
; CHECK-DAG: %[[TyInt64:.*]] = OpTypeInt 64 0
; CHECK-DAG: %[[TyFunFp:.*]] = OpTypeFunction %[[TyVoid]] %[[TyInt64]]
; CHECK-DAG: %[[ConstInt64:.*]] = OpConstant %[[TyInt64]] 42
; CHECK-DAG: %[[TyPtrFunFp:.*]] = OpTypePointer Function %[[TyFunFp]]
; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFunFp]] %[[DefFunFp:.*]]
; CHECK: %[[FunPtr1:.*]] = OpBitcast %[[#]] %[[ConstFunFp]]
; CHECK: %[[FunPtr2:.*]] = OpLoad %[[#]] %[[FunPtr1]]
; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FunPtr2]] %[[ConstInt64]]
; CHECK: OpReturn
; CHECK-DAG: %[[TyFun:.*]] = OpTypeFunction %[[TyInt64]] %[[TyInt64]]
; CHECK-DAG: %[[TyInt8:.*]] = OpTypeInt 8 0
; CHECK-DAG: %[[TyPtrFun:.*]] = OpTypePointer Function %[[TyFun]]
; CHECK-DAG: %[[ConstFunFp:.*]] = OpConstantFunctionPointerINTEL %[[TyPtrFun]] %[[DefFunFp:.*]]
; CHECK-DAG: %[[TyPtrPtrFun:.*]] = OpTypePointer Function %[[TyPtrFun]]
; CHECK-DAG: %[[TyPtrInt8:.*]] = OpTypePointer Function %[[TyInt8]]
; CHECK-DAG: %[[TyPtrPtrInt8:.*]] = OpTypePointer Function %[[TyPtrInt8]]
; CHECK: OpFunction
; CHECK: %[[Var:.*]] = OpVariable %[[TyPtrPtrInt8]] Function
; CHECK: %[[SAddr:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[Var]]
; CHECK: OpStore %[[SAddr]] %[[ConstFunFp]]
; CHECK: %[[LAddr:.*]] = OpBitcast %[[TyPtrPtrFun]] %[[Var]]
; CHECK: %[[FP:.*]] = OpLoad %[[TyPtrFun]] %[[LAddr]]
; CHECK: OpFunctionPointerCallINTEL %[[TyInt64]] %[[FP]] %[[#]]
; CHECK: OpFunctionEnd
; CHECK: %[[DefFunFp]] = OpFunction %[[TyVoid]] None %[[TyFunFp]]

; CHECK: %[[DefFunFp]] = OpFunction %[[TyInt64]] None %[[TyFun]]

target triple = "spir64-unknown-unknown"

define spir_kernel void @test() {
entry:
%0 = load ptr, ptr @foo
%1 = call i64 %0(i64 42)
%fp = alloca ptr
store ptr @foo, ptr %fp
%tocall = load ptr, ptr %fp
%res = call i64 %tocall(i64 42)
ret void
}

define void @foo(i64 %a) {
define i64 @foo(i64 %a) {
entry:
ret void
ret i64 %a
}
8 changes: 5 additions & 3 deletions llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
; This test case checks how phi-nodes with different operand types select
; a result type. Majority of operands makes it i8* in this case.

; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s

Expand All @@ -15,14 +18,13 @@

; CHECK: %[[Branch1:.*]] = OpLabel
; CHECK: %[[Res1:.*]] = OpVariable %[[StructPtr]] Function
; CHECK: %[[Res1Casted:.*]] = OpBitcast %[[CharPtr]] %[[Res1]]
; CHECK: OpBranchConditional %[[#]] %[[#]] %[[Branch2:.*]]
; CHECK: %[[Res2:.*]] = OpInBoundsPtrAccessChain %[[CharPtr]] %[[#]] %[[#]]
; CHECK: %[[Res2Casted:.*]] = OpBitcast %[[StructPtr]] %[[Res2]]
; CHECK: OpBranchConditional %[[#]] %[[#]] %[[BranchSelect:.*]]
; CHECK: %[[SelectRes:.*]] = OpSelect %[[CharPtr]] %[[#]] %[[#]] %[[#]]
; CHECK: %[[SelectResCasted:.*]] = OpBitcast %[[StructPtr]] %[[SelectRes]]
; CHECK: OpLabel
; CHECK: OpPhi %[[StructPtr]] %[[Res1]] %[[Branch1]] %[[Res2Casted]] %[[Branch2]] %[[SelectResCasted]] %[[BranchSelect]]
; CHECK: OpPhi %[[CharPtr]] %[[Res1Casted]] %[[Branch1]] %[[Res2]] %[[Branch2]] %[[SelectRes]] %[[BranchSelect]]

%struct = type { %array }
%array = type { [1 x i64] }
Expand Down

0 comments on commit bd9bcea

Please sign in to comment.