Skip to content

Commit

Permalink
[SimplifyCFG] Fold switch over ucmp/scmp to icmp and br (llvm#105636)
Browse files Browse the repository at this point in the history
If we switch over ucmp/scmp and have two switch cases going to the same
destination, we can convert into icmp+br.

Fixes llvm#105632.
  • Loading branch information
nikic authored Aug 22, 2024
1 parent 58ac764 commit 4d85285
Show file tree
Hide file tree
Showing 2 changed files with 486 additions and 46 deletions.
116 changes: 116 additions & 0 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7131,6 +7131,119 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
return true;
}

/// Fold switch over ucmp/scmp intrinsic to br if two of the switch arms have
/// the same destination.
static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI, IRBuilderBase &Builder,
DomTreeUpdater *DTU) {
auto *Cmp = dyn_cast<CmpIntrinsic>(SI->getCondition());
if (!Cmp || !Cmp->hasOneUse())
return false;

SmallVector<uint32_t, 4> Weights;
bool HasWeights = extractBranchWeights(getBranchWeightMDNode(*SI), Weights);
if (!HasWeights)
Weights.resize(4); // Avoid checking HasWeights everywhere.

// Normalize to [us]cmp == Res ? Succ : OtherSucc.
int64_t Res;
BasicBlock *Succ, *OtherSucc;
uint32_t SuccWeight = 0, OtherSuccWeight = 0;
BasicBlock *Unreachable = nullptr;

if (SI->getNumCases() == 2) {
// Find which of 1, 0 or -1 is missing (handled by default dest).
SmallSet<int64_t, 3> Missing;
Missing.insert(1);
Missing.insert(0);
Missing.insert(-1);

Succ = SI->getDefaultDest();
SuccWeight = Weights[0];
OtherSucc = nullptr;
for (auto &Case : SI->cases()) {
std::optional<int64_t> Val =
Case.getCaseValue()->getValue().trySExtValue();
if (!Val)
return false;
if (!Missing.erase(*Val))
return false;
if (OtherSucc && OtherSucc != Case.getCaseSuccessor())
return false;
OtherSucc = Case.getCaseSuccessor();
OtherSuccWeight += Weights[Case.getSuccessorIndex()];
}

assert(Missing.size() == 1 && "Should have one case left");
Res = *Missing.begin();
} else if (SI->getNumCases() == 3 && SI->defaultDestUndefined()) {
// Normalize so that Succ is taken once and OtherSucc twice.
Unreachable = SI->getDefaultDest();
Succ = OtherSucc = nullptr;
for (auto &Case : SI->cases()) {
BasicBlock *NewSucc = Case.getCaseSuccessor();
uint32_t Weight = Weights[Case.getSuccessorIndex()];
if (!OtherSucc || OtherSucc == NewSucc) {
OtherSucc = NewSucc;
OtherSuccWeight += Weight;
} else if (!Succ) {
Succ = NewSucc;
SuccWeight = Weight;
} else if (Succ == NewSucc) {
std::swap(Succ, OtherSucc);
std::swap(SuccWeight, OtherSuccWeight);
} else
return false;
}
for (auto &Case : SI->cases()) {
std::optional<int64_t> Val =
Case.getCaseValue()->getValue().trySExtValue();
if (!Val || (Val != 1 && Val != 0 && Val != -1))
return false;
if (Case.getCaseSuccessor() == Succ) {
Res = *Val;
break;
}
}
} else {
return false;
}

// Determine predicate for the missing case.
ICmpInst::Predicate Pred;
switch (Res) {
case 1:
Pred = ICmpInst::ICMP_UGT;
break;
case 0:
Pred = ICmpInst::ICMP_EQ;
break;
case -1:
Pred = ICmpInst::ICMP_ULT;
break;
}
if (Cmp->isSigned())
Pred = ICmpInst::getSignedPredicate(Pred);

MDNode *NewWeights = nullptr;
if (HasWeights)
NewWeights = MDBuilder(SI->getContext())
.createBranchWeights(SuccWeight, OtherSuccWeight);

BasicBlock *BB = SI->getParent();
Builder.SetInsertPoint(SI->getIterator());
Value *ICmp = Builder.CreateICmp(Pred, Cmp->getLHS(), Cmp->getRHS());
Builder.CreateCondBr(ICmp, Succ, OtherSucc, NewWeights,
SI->getMetadata(LLVMContext::MD_unpredictable));
OtherSucc->removePredecessor(BB);
if (Unreachable)
Unreachable->removePredecessor(BB);
SI->eraseFromParent();
Cmp->eraseFromParent();
if (DTU && Unreachable)
DTU->applyUpdates({{DominatorTree::Delete, BB, Unreachable}});
return true;
}

bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
BasicBlock *BB = SI->getParent();

Expand Down Expand Up @@ -7163,6 +7276,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
if (eliminateDeadSwitchCases(SI, DTU, Options.AC, DL))
return requestResimplify();

if (simplifySwitchOfCmpIntrinsic(SI, Builder, DTU))
return requestResimplify();

if (trySwitchToSelect(SI, Builder, DTU, DL, TTI))
return requestResimplify();

Expand Down
Loading

0 comments on commit 4d85285

Please sign in to comment.