Skip to content

Commit

Permalink
improve type deduction for phi and call base
Browse files Browse the repository at this point in the history
  • Loading branch information
VyacheslavLevytskyy committed Oct 11, 2024
1 parent 91b9add commit 7f79653
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 67 deletions.
16 changes: 11 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class SPIRVAsmPrinter : public AsmPrinter {
void outputExecutionMode(const Module &M);
void outputAnnotations(const Module &M);
void outputModuleSections();
bool isHidden() {
return MF->getFunction()
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
.isValid();
}

void emitInstruction(const MachineInstr *MI) override;
void emitFunctionEntryLabel() override {}
Expand Down Expand Up @@ -131,7 +136,7 @@ void SPIRVAsmPrinter::emitFunctionHeader() {
TII = ST->getInstrInfo();
const Function &F = MF->getFunction();

if (isVerbose()) {
if (isVerbose() && !isHidden()) {
OutStreamer->getCommentOS()
<< "-- Begin function "
<< GlobalValue::dropLLVMManglingEscape(F.getName()) << '\n';
Expand All @@ -150,16 +155,17 @@ void SPIRVAsmPrinter::outputOpFunctionEnd() {
// Emit OpFunctionEnd at the end of MF and clear BBNumToRegMap.
void SPIRVAsmPrinter::emitFunctionBodyEnd() {
// Do not emit anything if it's an internal service function.
if (MF->getFunction()
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
.isValid())
if (isHidden())
return;

outputOpFunctionEnd();
MAI->BBNumToRegMap.clear();
}

void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
// Do not emit anything if it's an internal service function.
if (isHidden())
return;

MCInst LabelInst;
LabelInst.setOpcode(SPIRV::OpLabel);
LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB)));
Expand Down
20 changes: 10 additions & 10 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,16 +500,6 @@ void SPIRVCallLowering::produceIndirectPtrTypes(

bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
CallLoweringInfo &Info) const {
// Ignore if called from the internal service function
if (MIRBuilder.getMF()
.getFunction()
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
.isValid()) {
// insert a no-op
MIRBuilder.buildTrap();
return true;
}

// Currently call returns should have single vregs.
// TODO: handle the case of multiple registers.
if (Info.OrigRet.Regs.size() > 1)
Expand Down Expand Up @@ -597,6 +587,16 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
}

// Ignore the call if it's called from the internal service function
if (MIRBuilder.getMF()
.getFunction()
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
.isValid()) {
// insert a no-op
MIRBuilder.buildTrap();
return true;
}

unsigned CallOp;
if (Info.CB->isIndirectCall()) {
if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
Expand Down
141 changes: 89 additions & 52 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,8 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(
// Traverse User instructions to deduce an element pointer type of the operand.
Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep(
Value *Op, std::unordered_set<Value *> &Visited, bool UnknownElemTypeI8) {
if (!Op || !isPointerTy(Op->getType()))
if (!Op || !isPointerTy(Op->getType()) || isa<ConstantPointerNull>(Op) ||
isa<UndefValue>(Op))
return nullptr;

if (auto ElemTy = getPointeeType(Op->getType()))
Expand Down Expand Up @@ -483,12 +484,25 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
if (isPointerTy(Op->getType()))
Ty = deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8);
} else if (auto *Ref = dyn_cast<PHINode>(I)) {
for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
Type *BestTy = nullptr;
unsigned MaxN = 1;
DenseMap<Type *, unsigned> PhiTys;
for (int i = Ref->getNumIncomingValues() - 1; i >= 0; --i) {
Ty = deduceElementTypeByUsersDeep(Ref->getIncomingValue(i), Visited,
UnknownElemTypeI8);
if (Ty)
break;
if (!Ty)
continue;
auto It = PhiTys.try_emplace(Ty, 1);
if (!It.second) {
++It.first->second;
if (It.first->second > MaxN) {
MaxN = It.first->second;
BestTy = Ty;
}
}
}
if (BestTy)
Ty = BestTy;
} else if (auto *Ref = dyn_cast<SelectInst>(I)) {
for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) {
Ty = deduceElementTypeByUsersDeep(Op, Visited, UnknownElemTypeI8);
Expand Down Expand Up @@ -644,6 +658,62 @@ static inline Type *getAtomicElemTy(SPIRVGlobalRegistry *GR, Instruction *I,
return nullptr;
}

// Try to deduce element type for a call base. Returns false if this is an
// indirect function invocation, and true otherwise.
static bool deduceOperandElementTypeCalledFunction(
SPIRVGlobalRegistry *GR, Instruction *I,
SPIRV::InstructionSet::InstructionSet InstrSet, CallInst *CI,
SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
Function *CalledF = CI->getCalledFunction();
if (!CalledF)
return false;
std::string DemangledName =
getOclOrSpirvBuiltinDemangledName(CalledF->getName());
if (DemangledName.length() > 0 &&
!StringRef(DemangledName).starts_with("llvm.")) {
auto [Grp, Opcode, ExtNo] =
SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
if (Opcode == SPIRV::OpGroupAsyncCopy) {
for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2; ++i) {
Value *Op = CI->getArgOperand(i);
if (!isPointerTy(Op->getType()))
continue;
++PtrCnt;
if (Type *ElemTy = GR->findDeducedElementType(Op))
KnownElemTy = ElemTy; // src will rewrite dest if both are defined
Ops.push_back(std::make_pair(Op, i));
}
} else if (Grp == SPIRV::Atomic || Grp == SPIRV::AtomicFloating) {
if (CI->arg_size() < 2)
return true;
Value *Op = CI->getArgOperand(0);
if (!isPointerTy(Op->getType()))
return true;
switch (Opcode) {
case SPIRV::OpAtomicLoad:
case SPIRV::OpAtomicCompareExchangeWeak:
case SPIRV::OpAtomicCompareExchange:
case SPIRV::OpAtomicExchange:
case SPIRV::OpAtomicIAdd:
case SPIRV::OpAtomicISub:
case SPIRV::OpAtomicOr:
case SPIRV::OpAtomicXor:
case SPIRV::OpAtomicAnd:
case SPIRV::OpAtomicUMin:
case SPIRV::OpAtomicUMax:
case SPIRV::OpAtomicSMin:
case SPIRV::OpAtomicSMax: {
KnownElemTy = getAtomicElemTy(GR, I, Op);
if (!KnownElemTy)
return true;
Ops.push_back(std::make_pair(Op, 0));
} break;
}
}
}
return true;
}

// If the Instruction has Pointer operands with unresolved types, this function
// tries to deduce them. If the Instruction has Pointer operands with known
// types which differ from expected, this function tries to insert a bitcast to
Expand Down Expand Up @@ -749,53 +819,17 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
KnownElemTy = ElemTy1;
Ops.push_back(std::make_pair(Op0, 0));
}
} else if (auto *CI = dyn_cast<CallInst>(I)) {
if (Function *CalledF = CI->getCalledFunction()) {
std::string DemangledName =
getOclOrSpirvBuiltinDemangledName(CalledF->getName());
if (DemangledName.length() > 0 &&
!StringRef(DemangledName).starts_with("llvm.")) {
auto [Grp, Opcode, ExtNo] =
SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
if (Opcode == SPIRV::OpGroupAsyncCopy) {
for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
++i) {
Value *Op = CI->getArgOperand(i);
if (!isPointerTy(Op->getType()))
continue;
++PtrCnt;
if (Type *ElemTy = GR->findDeducedElementType(Op))
KnownElemTy = ElemTy; // src will rewrite dest if both are defined
Ops.push_back(std::make_pair(Op, i));
}
} else if (Grp == SPIRV::Atomic || Grp == SPIRV::AtomicFloating) {
if (CI->arg_size() < 2)
return;
Value *Op = CI->getArgOperand(0);
if (!isPointerTy(Op->getType()))
return;
switch (Opcode) {
case SPIRV::OpAtomicLoad:
case SPIRV::OpAtomicCompareExchangeWeak:
case SPIRV::OpAtomicCompareExchange:
case SPIRV::OpAtomicExchange:
case SPIRV::OpAtomicIAdd:
case SPIRV::OpAtomicISub:
case SPIRV::OpAtomicOr:
case SPIRV::OpAtomicXor:
case SPIRV::OpAtomicAnd:
case SPIRV::OpAtomicUMin:
case SPIRV::OpAtomicUMax:
case SPIRV::OpAtomicSMin:
case SPIRV::OpAtomicSMax: {
KnownElemTy = getAtomicElemTy(GR, I, Op);
if (!KnownElemTy)
return;
Ops.push_back(std::make_pair(Op, 0));
} break;
}
}
}
} else if (CallInst *CI = dyn_cast<CallInst>(I)) {
if (!CI->isIndirectCall()) {
deduceOperandElementTypeCalledFunction(GR, I, InstrSet, CI, Ops,
KnownElemTy);
} else if (TM->getSubtarget<SPIRVSubtarget>(*F).canUseExtension(
SPIRV::Extension::SPV_INTEL_function_pointers)) {
Value *Op = CI->getCalledOperand();
if (!Op || !isPointerTy(Op->getType()))
return;
Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
KnownElemTy = CI->getFunctionType();
}
}

Expand Down Expand Up @@ -846,7 +880,10 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
B.getInt32(getPointerAddressSpace(OpTy))};
CallInst *PtrCastI =
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
I->setOperand(OpIt.second, PtrCastI);
if (OpIt.second == std::numeric_limits<unsigned>::max())
dyn_cast<CallInst>(I)->setCalledOperand(PtrCastI);
else
I->setOperand(OpIt.second, PtrCastI);
buildAssignPtr(B, KnownElemTy, PtrCastI);
}
}
Expand Down

0 comments on commit 7f79653

Please sign in to comment.