Skip to content

Commit

Permalink
[GlobalISel] Fold G_ICMP if possible (#86357)
Browse files Browse the repository at this point in the history
This patch constant-folds `G_ICMP` when both operands are known constants (scalars, or build_vectors of constants for vector compares).
  • Loading branch information
shiltian authored Mar 29, 2024
1 parent 360f7f5 commit 3a106e5
Show file tree
Hide file tree
Showing 24 changed files with 694 additions and 393 deletions.
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,10 @@ std::optional<SmallVector<unsigned>>
ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
std::function<unsigned(APInt)> CB);

/// Constant-fold a G_ICMP with predicate \p Pred comparing \p Op1 and \p Op2.
/// On success returns the boolean results as 1-bit APInts — a single element
/// for a scalar compare, or one element per lane when both operands are
/// G_BUILD_VECTORs of constants. Returns std::nullopt when either operand is
/// not constant or the predicate is not an integer predicate.
std::optional<SmallVector<APInt>>
ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
const MachineRegisterInfo &MRI);

/// Test if the given value is known to have exactly one bit set. This differs
/// from computeKnownBits in that it doesn't necessarily determine which bit is
/// set.
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,20 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
switch (Opc) {
default:
break;
case TargetOpcode::G_ICMP: {
assert(SrcOps.size() == 3 && "Invalid sources");
assert(DstOps.size() == 1 && "Invalid dsts");
LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());

if (std::optional<SmallVector<APInt>> Cst =
ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
SrcOps[2].getReg(), *getMRI())) {
if (SrcTy.isVector())
return buildBuildVectorConstant(DstOps[0], *Cst);
return buildConstant(DstOps[0], Cst->front());
}
break;
}
case TargetOpcode::G_ADD:
case TargetOpcode::G_PTR_ADD:
case TargetOpcode::G_AND:
Expand Down
34 changes: 25 additions & 9 deletions llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3768,9 +3768,11 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
}
case TargetOpcode::G_ATOMIC_CMPXCHG_WITH_SUCCESS: {
auto [OldValRes, SuccessRes, Addr, CmpVal, NewVal] = MI.getFirst5Regs();
MIRBuilder.buildAtomicCmpXchg(OldValRes, Addr, CmpVal, NewVal,
Register NewOldValRes = MRI.cloneVirtualRegister(OldValRes);
MIRBuilder.buildAtomicCmpXchg(NewOldValRes, Addr, CmpVal, NewVal,
**MI.memoperands_begin());
MIRBuilder.buildICmp(CmpInst::ICMP_EQ, SuccessRes, OldValRes, CmpVal);
MIRBuilder.buildICmp(CmpInst::ICMP_EQ, SuccessRes, NewOldValRes, CmpVal);
MIRBuilder.buildCopy(OldValRes, NewOldValRes);
MI.eraseFromParent();
return Legalized;
}
Expand All @@ -3789,8 +3791,12 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
case G_UADDO: {
auto [Res, CarryOut, LHS, RHS] = MI.getFirst4Regs();

MIRBuilder.buildAdd(Res, LHS, RHS);
MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CarryOut, Res, RHS);
Register NewRes = MRI.cloneVirtualRegister(Res);

MIRBuilder.buildAdd(NewRes, LHS, RHS);
MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CarryOut, NewRes, RHS);

MIRBuilder.buildCopy(Res, NewRes);

MI.eraseFromParent();
return Legalized;
Expand All @@ -3800,6 +3806,8 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
const LLT CondTy = MRI.getType(CarryOut);
const LLT Ty = MRI.getType(Res);

Register NewRes = MRI.cloneVirtualRegister(Res);

// Initial add of the two operands.
auto TmpRes = MIRBuilder.buildAdd(Ty, LHS, RHS);

Expand All @@ -3808,15 +3816,18 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {

// Add the sum and the carry.
auto ZExtCarryIn = MIRBuilder.buildZExt(Ty, CarryIn);
MIRBuilder.buildAdd(Res, TmpRes, ZExtCarryIn);
MIRBuilder.buildAdd(NewRes, TmpRes, ZExtCarryIn);

// Second check for carry. We can only carry if the initial sum is all 1s
// and the carry is set, resulting in a new sum of 0.
auto Zero = MIRBuilder.buildConstant(Ty, 0);
auto ResEqZero = MIRBuilder.buildICmp(CmpInst::ICMP_EQ, CondTy, Res, Zero);
auto ResEqZero =
MIRBuilder.buildICmp(CmpInst::ICMP_EQ, CondTy, NewRes, Zero);
auto Carry2 = MIRBuilder.buildAnd(CondTy, ResEqZero, CarryIn);
MIRBuilder.buildOr(CarryOut, Carry, Carry2);

MIRBuilder.buildCopy(Res, NewRes);

MI.eraseFromParent();
return Legalized;
}
Expand Down Expand Up @@ -7671,10 +7682,12 @@ LegalizerHelper::lowerSADDO_SSUBO(MachineInstr &MI) {
LLT Ty = Dst0Ty;
LLT BoolTy = Dst1Ty;

Register NewDst0 = MRI.cloneVirtualRegister(Dst0);

if (IsAdd)
MIRBuilder.buildAdd(Dst0, LHS, RHS);
MIRBuilder.buildAdd(NewDst0, LHS, RHS);
else
MIRBuilder.buildSub(Dst0, LHS, RHS);
MIRBuilder.buildSub(NewDst0, LHS, RHS);

// TODO: If SADDSAT/SSUBSAT is legal, compare results to detect overflow.

Expand All @@ -7687,12 +7700,15 @@ LegalizerHelper::lowerSADDO_SSUBO(MachineInstr &MI) {
// (LHS) if and only if the other operand (RHS) is (non-zero) positive,
// otherwise there will be overflow.
auto ResultLowerThanLHS =
MIRBuilder.buildICmp(CmpInst::ICMP_SLT, BoolTy, Dst0, LHS);
MIRBuilder.buildICmp(CmpInst::ICMP_SLT, BoolTy, NewDst0, LHS);
auto ConditionRHS = MIRBuilder.buildICmp(
IsAdd ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGT, BoolTy, RHS, Zero);

MIRBuilder.buildXor(Dst1, ConditionRHS, ResultLowerThanLHS);

MIRBuilder.buildCopy(Dst0, NewDst0);
MI.eraseFromParent();

return Legalized;
}

Expand Down
68 changes: 68 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,74 @@ llvm::ConstantFoldCountZeros(Register Src, const MachineRegisterInfo &MRI,
return std::nullopt;
}

std::optional<SmallVector<APInt>>
llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
                       const MachineRegisterInfo &MRI) {
  // Both operands must have the same (scalar or vector) type.
  LLT Ty = MRI.getType(Op1);
  if (Ty != MRI.getType(Op2))
    return std::nullopt;

  // Fold a single scalar compare. Yields a 1-bit APInt holding the result, or
  // std::nullopt if either side is not a G_CONSTANT or the predicate is not an
  // integer predicate.
  auto FoldScalar = [&MRI, Pred](Register LHS,
                                 Register RHS) -> std::optional<APInt> {
    auto LHSCst = getIConstantVRegVal(LHS, MRI);
    auto RHSCst = getIConstantVRegVal(RHS, MRI);
    if (!LHSCst || !RHSCst)
      return std::nullopt;

    bool Result;
    switch (Pred) {
    case CmpInst::Predicate::ICMP_EQ:
      Result = LHSCst->eq(*RHSCst);
      break;
    case CmpInst::Predicate::ICMP_NE:
      Result = LHSCst->ne(*RHSCst);
      break;
    case CmpInst::Predicate::ICMP_UGT:
      Result = LHSCst->ugt(*RHSCst);
      break;
    case CmpInst::Predicate::ICMP_UGE:
      Result = LHSCst->uge(*RHSCst);
      break;
    case CmpInst::Predicate::ICMP_ULT:
      Result = LHSCst->ult(*RHSCst);
      break;
    case CmpInst::Predicate::ICMP_ULE:
      Result = LHSCst->ule(*RHSCst);
      break;
    case CmpInst::Predicate::ICMP_SGT:
      Result = LHSCst->sgt(*RHSCst);
      break;
    case CmpInst::Predicate::ICMP_SGE:
      Result = LHSCst->sge(*RHSCst);
      break;
    case CmpInst::Predicate::ICMP_SLT:
      Result = LHSCst->slt(*RHSCst);
      break;
    case CmpInst::Predicate::ICMP_SLE:
      Result = LHSCst->sle(*RHSCst);
      break;
    default:
      // Not an integer predicate (e.g. a floating-point one) — cannot fold.
      return std::nullopt;
    }
    return APInt(/*numBits=*/1, Result);
  };

  SmallVector<APInt> Results;

  // Scalar compare: fold the single pair of operands.
  if (!Ty.isVector()) {
    std::optional<APInt> Cst = FoldScalar(Op1, Op2);
    if (!Cst)
      return std::nullopt;
    Results.push_back(*Cst);
    return Results;
  }

  // Vector compare: both operands must be build_vectors of constants, folded
  // lane by lane. Bail out entirely if any lane fails to fold.
  auto *BV1 = getOpcodeDef<GBuildVector>(Op1, MRI);
  auto *BV2 = getOpcodeDef<GBuildVector>(Op2, MRI);
  if (!BV1 || !BV2)
    return std::nullopt;
  assert(BV1->getNumSources() == BV2->getNumSources() && "Invalid vectors");
  for (unsigned I = 0, E = BV1->getNumSources(); I != E; ++I) {
    std::optional<APInt> Lane =
        FoldScalar(BV1->getSourceReg(I), BV2->getSourceReg(I));
    if (!Lane)
      return std::nullopt;
    Results.push_back(*Lane);
  }
  return Results;
}

bool llvm::isKnownToBeAPowerOfTwo(Register Reg, const MachineRegisterInfo &MRI,
GISelKnownBits *KB) {
std::optional<DefinitionAndSourceRegister> DefSrcReg =
Expand Down
Loading

0 comments on commit 3a106e5

Please sign in to comment.