diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 68b860725752d0..179a2c38d9d3c2 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1125,15 +1125,10 @@ class ScalarEvolution { // Not taken either exactly ConstantMaxNotTaken or zero times bool MaxOrZero = false; - /// A set of predicate guards for this ExitLimit. The result is only valid - /// if all of the predicates in \c Predicates evaluate to 'true' at + /// A vector of predicate guards for this ExitLimit. The result is only + /// valid if all of the predicates in \c Predicates evaluate to 'true' at /// run-time. - SmallPtrSet Predicates; - - void addPredicate(const SCEVPredicate *P) { - assert(!isa(P) && "Only add leaf predicates here!"); - Predicates.insert(P); - } + SmallVector Predicates; /// Construct either an exact exit limit from a constant, or an unknown /// one from a SCEVCouldNotCompute. No other types of SCEVs are allowed @@ -1142,12 +1137,11 @@ class ScalarEvolution { ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken, const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, - ArrayRef *> - PredSetList = {}); + ArrayRef> PredLists = {}); ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken, const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, - const SmallPtrSetImpl &PredSet); + ArrayRef PredList); /// Test whether this ExitLimit contains any computed information, or /// whether it's all SCEVCouldNotCompute values. @@ -1297,7 +1291,7 @@ class ScalarEvolution { /// adding additional predicates to \p Preds as required. const SCEVAddRecExpr *convertSCEVToAddRecWithPredicates( const SCEV *S, const Loop *L, - SmallPtrSetImpl &Preds); + SmallVectorImpl &Preds); /// Compute \p LHS - \p RHS and returns the result as an APInt if it is a /// constant, and std::nullopt if it isn't. @@ -1489,12 +1483,13 @@ class ScalarEvolution { const SCEV *ExactNotTaken; const SCEV *ConstantMaxNotTaken; const SCEV *SymbolicMaxNotTaken; - SmallPtrSet Predicates; + SmallVector Predicates; - explicit ExitNotTakenInfo( - PoisoningVH ExitingBlock, const SCEV *ExactNotTaken, - const SCEV *ConstantMaxNotTaken, const SCEV *SymbolicMaxNotTaken, - const SmallPtrSet &Predicates) + explicit ExitNotTakenInfo(PoisoningVH ExitingBlock, + const SCEV *ExactNotTaken, + const SCEV *ConstantMaxNotTaken, + const SCEV *SymbolicMaxNotTaken, + ArrayRef Predicates) : ExitingBlock(ExitingBlock), ExactNotTaken(ExactNotTaken), ConstantMaxNotTaken(ConstantMaxNotTaken), SymbolicMaxNotTaken(SymbolicMaxNotTaken), Predicates(Predicates) {} diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index d5db3263294a6d..c939270ed39a65 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -8693,12 +8693,12 @@ bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero( } ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) - : ExitLimit(E, E, E, false, {}) {} + : ExitLimit(E, E, E, false) {} ScalarEvolution::ExitLimit::ExitLimit( const SCEV *E, const SCEV *ConstantMaxNotTaken, const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, - ArrayRef *> PredSetList) + ArrayRef> PredLists) : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken), SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) { // If we prove the max count is zero, so is the symbolic bound. This happens @@ -8721,9 +8721,15 @@ ScalarEvolution::ExitLimit::ExitLimit( assert((isa(ConstantMaxNotTaken) || isa(ConstantMaxNotTaken)) && "No point in having a non-constant max backedge taken count!"); - for (const auto *PredSet : PredSetList) - for (const auto *P : *PredSet) - addPredicate(P); + SmallPtrSet SeenPreds; + for (const auto PredList : PredLists) + for (const auto *P : PredList) { + if (SeenPreds.contains(P)) + continue; + assert(!isa(P) && "Only add leaf predicates here!"); + SeenPreds.insert(P); + Predicates.push_back(P); + } assert((isa(E) || !E->getType()->isPointerTy()) && "Backedge count should be int"); assert((isa(ConstantMaxNotTaken) || @@ -8731,12 +8737,13 @@ ScalarEvolution::ExitLimit::ExitLimit( "Max backedge count should be int"); } -ScalarEvolution::ExitLimit::ExitLimit( - const SCEV *E, const SCEV *ConstantMaxNotTaken, - const SCEV *SymbolicMaxNotTaken, bool MaxOrZero, - const SmallPtrSetImpl &PredSet) +ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, + const SCEV *ConstantMaxNotTaken, + const SCEV *SymbolicMaxNotTaken, + bool MaxOrZero, + ArrayRef PredList) : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero, - { &PredSet }) {} + ArrayRef({PredList})) {} /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each /// computable exit into a persistent ExitNotTakenInfo array. @@ -9098,7 +9105,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( SymbolicMaxBECount = isa(BECount) ? ConstantMaxBECount : BECount; return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false, - { &EL0.Predicates, &EL1.Predicates }); + {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)}); } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( @@ -10131,7 +10138,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const { /// If the equation does not have a solution, SCEVCouldNotCompute is returned. static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, - SmallPtrSetImpl *Predicates, + SmallVectorImpl *Predicates, ScalarEvolution &SE) { uint32_t BW = A.getBitWidth(); @@ -10162,7 +10169,7 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, // Avoid adding a predicate that is known to be false. if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero)) return SE.getCouldNotCompute(); - Predicates->insert(SE.getEqualPredicate(URem, Zero)); + Predicates->push_back(SE.getEqualPredicate(URem, Zero)); } } @@ -10466,7 +10473,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, // effectively V != 0. We know and take advantage of the fact that this // expression only being used in a comparison by zero context. - SmallPtrSet Predicates; + SmallVector Predicates; // If the value is a constant if (const SCEVConstant *C = dyn_cast(V)) { // If the value is already zero, the branch will execute zero times. @@ -12885,7 +12892,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, bool ControlsOnlyExit, bool AllowPredicates) { - SmallPtrSet Predicates; + SmallVector Predicates; const SCEVAddRecExpr *IV = dyn_cast(LHS); bool PredicatedIV = false; @@ -13325,7 +13332,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans( const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, bool ControlsOnlyExit, bool AllowPredicates) { - SmallPtrSet Predicates; + SmallVector Predicates; // We handle only IV > Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); @@ -13695,7 +13702,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, PrintSCEVWithTypeHint(OS, EC); if (isa(EC)) { // Retry with predicates. - SmallVector Predicates; + SmallVector Predicates; EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates); if (!isa(EC)) { OS << "\n predicated exit count for " << ExitingBlock->getName() @@ -13747,7 +13754,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, PrintSCEVWithTypeHint(OS, ExitBTC); if (isa(ExitBTC)) { // Retry with predicates. - SmallVector Predicates; + SmallVector Predicates; ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates, ScalarEvolution::SymbolicMaximum); if (!isa(ExitBTC)) { @@ -14727,7 +14734,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { /// If \p NewPreds is non-null, rewrite is free to add further predicates to /// \p NewPreds such that the result will be an AddRecExpr. static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, - SmallPtrSetImpl *NewPreds, + SmallVectorImpl *NewPreds, const SCEVPredicate *Pred) { SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred); return Rewriter.visit(S); @@ -14783,9 +14790,10 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { } private: - explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, - SmallPtrSetImpl *NewPreds, - const SCEVPredicate *Pred) + explicit SCEVPredicateRewriter( + const Loop *L, ScalarEvolution &SE, + SmallVectorImpl *NewPreds, + const SCEVPredicate *Pred) : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {} bool addOverflowAssumption(const SCEVPredicate *P) { @@ -14793,7 +14801,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { // Check if we've already made this assumption. return Pred && Pred->implies(P); } - NewPreds->insert(P); + NewPreds->push_back(P); return true; } @@ -14829,7 +14837,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { return PredicatedRewrite->first; } - SmallPtrSetImpl *NewPreds; + SmallVectorImpl *NewPreds; const SCEVPredicate *Pred; const Loop *L; }; @@ -14844,8 +14852,8 @@ ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates( const SCEV *S, const Loop *L, - SmallPtrSetImpl &Preds) { - SmallPtrSet TransformPreds; + SmallVectorImpl &Preds) { + SmallVector TransformPreds; S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr); auto *AddRec = dyn_cast(S); @@ -14854,7 +14862,7 @@ const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates( // Since the transformation was successful, we can now transfer the SCEV // predicates. - Preds.insert(TransformPreds.begin(), TransformPreds.end()); + Preds.append(TransformPreds.begin(), TransformPreds.end()); return AddRec; } @@ -15101,7 +15109,7 @@ bool PredicatedScalarEvolution::hasNoOverflow( const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { const SCEV *Expr = this->getSCEV(V); - SmallPtrSet NewPreds; + SmallVector NewPreds; auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds); if (!New)