diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 0d35bfb921dc79..e8653498d32a12 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1527,6 +1527,10 @@ class LoopVectorizationCostModel { getReductionPatternCost(Instruction *I, ElementCount VF, Type *VectorTy, TTI::TargetCostKind CostKind) const; + /// Returns true if \p Op should be considered invariant and if it is + /// trivially hoistable. + bool shouldConsiderInvariant(Value *Op); + private: unsigned NumPredStores = 0; @@ -6382,6 +6386,17 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) { } } +bool LoopVectorizationCostModel::shouldConsiderInvariant(Value *Op) { + if (!Legal->isInvariant(Op)) + return false; + // Consider Op invariant, if it or its operands aren't predicated + // instruction in the loop. In that case, it is not trivially hoistable. + return !isa(Op) || !TheLoop->contains(cast(Op)) || + (!isPredicatedInst(cast(Op)) && + all_of(cast(Op)->operands(), + [this](Value *Op) { return shouldConsiderInvariant(Op); })); +} + InstructionCost LoopVectorizationCostModel::getInstructionCost(Instruction *I, ElementCount VF) { @@ -6621,19 +6636,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, Op2 = cast(PSE.getSCEV(Op2))->getValue(); } auto Op2Info = TTI.getOperandInfo(Op2); - std::function IsInvariant = - [this, &IsInvariant](Value *Op) -> bool { - if (!Legal->isInvariant(Op)) - return false; - // Consider Op2invariant, if it or its operands aren't predicated - // instruction in the loop. In that case, it is not trivially hoistable. - return !isa(Op) || - !TheLoop->contains(cast(Op)) || - (!isPredicatedInst(cast(Op)) && - all_of(cast(Op)->operands(), - [&IsInvariant](Value *Op) { return IsInvariant(Op); })); - }; - if (Op2Info.Kind == TargetTransformInfo::OK_AnyValue && IsInvariant(Op2)) + if (Op2Info.Kind == TargetTransformInfo::OK_AnyValue && + shouldConsiderInvariant(Op2)) Op2Info.Kind = TargetTransformInfo::OK_UniformValue; SmallVector Operands(I->operand_values());