Skip to content

Commit

Permalink
[SCEV] Store predicates for EL/ENT in SmallVector.
Browse files Browse the repository at this point in the history
Store predicates in ExitLimit and ExitNotTaken in a SmallVector instead
of a SmallPtrSet. This guarantees the predicates can be iterated on in a
predictable manner. This ensures the predicates can be printed and
generated in a predictable order.

This shifts de-duplication of predicates to construction time for
ExitLimit. ExitNotTaken just takes predicates from ExitLimit, so they
should also be free of duplicates.

This was exposed by 2f7ccaf
(#108777).

Should fix https://lab.llvm.org/buildbot/#/builders/110/builds/1494.
  • Loading branch information
fhahn committed Sep 28, 2024
1 parent d3ca484 commit 6022a3a
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 45 deletions.
29 changes: 12 additions & 17 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -1125,15 +1125,10 @@ class ScalarEvolution {
// Not taken either exactly ConstantMaxNotTaken or zero times
bool MaxOrZero = false;

/// A set of predicate guards for this ExitLimit. The result is only valid
/// if all of the predicates in \c Predicates evaluate to 'true' at
/// A vector of predicate guards for this ExitLimit. The result is only
/// valid if all of the predicates in \c Predicates evaluate to 'true' at
/// run-time.
SmallPtrSet<const SCEVPredicate *, 4> Predicates;

void addPredicate(const SCEVPredicate *P) {
assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
Predicates.insert(P);
}
SmallVector<const SCEVPredicate *, 4> Predicates;

/// Construct either an exact exit limit from a constant, or an unknown
/// one from a SCEVCouldNotCompute. No other types of SCEVs are allowed
Expand All @@ -1142,12 +1137,11 @@ class ScalarEvolution {

ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken,
const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *>
PredSetList = {});
ArrayRef<ArrayRef<const SCEVPredicate *>> PredLists = {});

ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken,
const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
const SmallPtrSetImpl<const SCEVPredicate *> &PredSet);
ArrayRef<const SCEVPredicate *> PredList);

/// Test whether this ExitLimit contains any computed information, or
/// whether it's all SCEVCouldNotCompute values.
Expand Down Expand Up @@ -1297,7 +1291,7 @@ class ScalarEvolution {
/// adding additional predicates to \p Preds as required.
const SCEVAddRecExpr *convertSCEVToAddRecWithPredicates(
const SCEV *S, const Loop *L,
SmallPtrSetImpl<const SCEVPredicate *> &Preds);
SmallVectorImpl<const SCEVPredicate *> &Preds);

/// Compute \p LHS - \p RHS and returns the result as an APInt if it is a
/// constant, and std::nullopt if it isn't.
Expand Down Expand Up @@ -1489,12 +1483,13 @@ class ScalarEvolution {
const SCEV *ExactNotTaken;
const SCEV *ConstantMaxNotTaken;
const SCEV *SymbolicMaxNotTaken;
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
SmallVector<const SCEVPredicate *, 4> Predicates;

explicit ExitNotTakenInfo(
PoisoningVH<BasicBlock> ExitingBlock, const SCEV *ExactNotTaken,
const SCEV *ConstantMaxNotTaken, const SCEV *SymbolicMaxNotTaken,
const SmallPtrSet<const SCEVPredicate *, 4> &Predicates)
explicit ExitNotTakenInfo(PoisoningVH<BasicBlock> ExitingBlock,
const SCEV *ExactNotTaken,
const SCEV *ConstantMaxNotTaken,
const SCEV *SymbolicMaxNotTaken,
ArrayRef<const SCEVPredicate *> Predicates)
: ExitingBlock(ExitingBlock), ExactNotTaken(ExactNotTaken),
ConstantMaxNotTaken(ConstantMaxNotTaken),
SymbolicMaxNotTaken(SymbolicMaxNotTaken), Predicates(Predicates) {}
Expand Down
64 changes: 36 additions & 28 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8693,12 +8693,12 @@ bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
}

ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
: ExitLimit(E, E, E, false, {}) {}
: ExitLimit(E, E, E, false) {}

ScalarEvolution::ExitLimit::ExitLimit(
const SCEV *E, const SCEV *ConstantMaxNotTaken,
const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
ArrayRef<ArrayRef<const SCEVPredicate *>> PredLists)
: ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
// If we prove the max count is zero, so is the symbolic bound. This happens
Expand All @@ -8721,22 +8721,29 @@ ScalarEvolution::ExitLimit::ExitLimit(
assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
isa<SCEVConstant>(ConstantMaxNotTaken)) &&
"No point in having a non-constant max backedge taken count!");
for (const auto *PredSet : PredSetList)
for (const auto *P : *PredSet)
addPredicate(P);
SmallPtrSet<const SCEVPredicate *, 4> SeenPreds;
for (const auto PredList : PredLists)
for (const auto *P : PredList) {
if (SeenPreds.contains(P))
continue;
assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
SeenPreds.insert(P);
Predicates.push_back(P);
}
assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
"Backedge count should be int");
assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
!ConstantMaxNotTaken->getType()->isPointerTy()) &&
"Max backedge count should be int");
}

ScalarEvolution::ExitLimit::ExitLimit(
const SCEV *E, const SCEV *ConstantMaxNotTaken,
const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E,
const SCEV *ConstantMaxNotTaken,
const SCEV *SymbolicMaxNotTaken,
bool MaxOrZero,
ArrayRef<const SCEVPredicate *> PredList)
: ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
{ &PredSet }) {}
ArrayRef({PredList})) {}

/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
/// computable exit into a persistent ExitNotTakenInfo array.
Expand Down Expand Up @@ -9098,7 +9105,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
SymbolicMaxBECount =
isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
{ &EL0.Predicates, &EL1.Predicates });
{ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
}

ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
Expand Down Expand Up @@ -10131,7 +10138,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
static const SCEV *
SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
SmallPtrSetImpl<const SCEVPredicate *> *Predicates,
SmallVectorImpl<const SCEVPredicate *> *Predicates,

ScalarEvolution &SE) {
uint32_t BW = A.getBitWidth();
Expand Down Expand Up @@ -10162,7 +10169,7 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
// Avoid adding a predicate that is known to be false.
if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
return SE.getCouldNotCompute();
Predicates->insert(SE.getEqualPredicate(URem, Zero));
Predicates->push_back(SE.getEqualPredicate(URem, Zero));
}
}

Expand Down Expand Up @@ -10466,7 +10473,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
// effectively V != 0. We know and take advantage of the fact that this
// expression only being used in a comparison by zero context.

SmallPtrSet<const SCEVPredicate *, 4> Predicates;
SmallVector<const SCEVPredicate *> Predicates;
// If the value is a constant
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
// If the value is already zero, the branch will execute zero times.
Expand Down Expand Up @@ -12885,7 +12892,7 @@ ScalarEvolution::ExitLimit
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned,
bool ControlsOnlyExit, bool AllowPredicates) {
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
SmallVector<const SCEVPredicate *> Predicates;

const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
bool PredicatedIV = false;
Expand Down Expand Up @@ -13325,7 +13332,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
bool ControlsOnlyExit, bool AllowPredicates) {
SmallPtrSet<const SCEVPredicate *, 4> Predicates;
SmallVector<const SCEVPredicate *> Predicates;
// We handle only IV > Invariant
if (!isLoopInvariant(RHS, L))
return getCouldNotCompute();
Expand Down Expand Up @@ -13695,7 +13702,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
PrintSCEVWithTypeHint(OS, EC);
if (isa<SCEVCouldNotCompute>(EC)) {
// Retry with predicates.
SmallVector<const SCEVPredicate *, 4> Predicates;
SmallVector<const SCEVPredicate *> Predicates;
EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
if (!isa<SCEVCouldNotCompute>(EC)) {
OS << "\n predicated exit count for " << ExitingBlock->getName()
Expand Down Expand Up @@ -13747,7 +13754,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
PrintSCEVWithTypeHint(OS, ExitBTC);
if (isa<SCEVCouldNotCompute>(ExitBTC)) {
// Retry with predicates.
SmallVector<const SCEVPredicate *, 4> Predicates;
SmallVector<const SCEVPredicate *> Predicates;
ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
ScalarEvolution::SymbolicMaximum);
if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
Expand Down Expand Up @@ -14727,7 +14734,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
/// If \p NewPreds is non-null, rewrite is free to add further predicates to
/// \p NewPreds such that the result will be an AddRecExpr.
static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
SmallVectorImpl<const SCEVPredicate *> *NewPreds,
const SCEVPredicate *Pred) {
SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
return Rewriter.visit(S);
Expand Down Expand Up @@ -14783,17 +14790,18 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
}

private:
explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
const SCEVPredicate *Pred)
explicit SCEVPredicateRewriter(
const Loop *L, ScalarEvolution &SE,
SmallVectorImpl<const SCEVPredicate *> *NewPreds,
const SCEVPredicate *Pred)
: SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}

bool addOverflowAssumption(const SCEVPredicate *P) {
if (!NewPreds) {
// Check if we've already made this assumption.
return Pred && Pred->implies(P);
}
NewPreds->insert(P);
NewPreds->push_back(P);
return true;
}

Expand Down Expand Up @@ -14829,7 +14837,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
return PredicatedRewrite->first;
}

SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
SmallVectorImpl<const SCEVPredicate *> *NewPreds;
const SCEVPredicate *Pred;
const Loop *L;
};
Expand All @@ -14844,8 +14852,8 @@ ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,

const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
const SCEV *S, const Loop *L,
SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
SmallVectorImpl<const SCEVPredicate *> &Preds) {
SmallVector<const SCEVPredicate *> TransformPreds;
S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);

Expand All @@ -14854,7 +14862,7 @@ const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(

// Since the transformation was successful, we can now transfer the SCEV
// predicates.
Preds.insert(TransformPreds.begin(), TransformPreds.end());
Preds.append(TransformPreds.begin(), TransformPreds.end());

return AddRec;
}
Expand Down Expand Up @@ -15101,7 +15109,7 @@ bool PredicatedScalarEvolution::hasNoOverflow(

const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
const SCEV *Expr = this->getSCEV(V);
SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
SmallVector<const SCEVPredicate *, 4> NewPreds;
auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);

if (!New)
Expand Down

0 comments on commit 6022a3a

Please sign in to comment.