diff --git a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp index 63f14208bf556b..c7a8700e145314 100644 --- a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp @@ -78,6 +78,13 @@ class LoopIdiomVectorize { const TargetTransformInfo *TTI; const DataLayout *DL; + // Blocks that will be used for inserting vectorized code. + BasicBlock *EndBlock = nullptr; + BasicBlock *VectorLoopPreheaderBlock = nullptr; + BasicBlock *VectorLoopStartBlock = nullptr; + BasicBlock *VectorLoopMismatchBlock = nullptr; + BasicBlock *VectorLoopIncBlock = nullptr; + public: explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI, const TargetTransformInfo *TTI, @@ -95,9 +102,16 @@ class LoopIdiomVectorize { SmallVectorImpl &ExitBlocks); bool recognizeByteCompare(); + Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen); + + Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU, + GetElementPtrInst *GEPA, + GetElementPtrInst *GEPB, Value *ExtStart, + Value *ExtEnd); + void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, PHINode *IndPhi, Value *MaxLen, Instruction *Index, Value *Start, bool IncIdx, BasicBlock *FoundBB, @@ -331,6 +345,115 @@ bool LoopIdiomVectorize::recognizeByteCompare() { return true; } +Value *LoopIdiomVectorize::createMaskedFindMismatch( + IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, + GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { + Type *I64Type = Builder.getInt64Ty(); + Type *ResType = Builder.getInt32Ty(); + Type *LoadType = Builder.getInt8Ty(); + Value *PtrA = GEPA->getPointerOperand(); + Value *PtrB = GEPB->getPointerOperand(); + + // At this point we know two things must be true: + // 1. Start <= End + // 2. ExtMaxLen <= MinPageSize due to the page checks. + // Therefore, we know that we can use a 64-bit induction variable that + // starts from 0 -> ExtMaxLen and it will not overflow. + ScalableVectorType *PredVTy = + ScalableVectorType::get(Builder.getInt1Ty(), 16); + + Value *InitialPred = Builder.CreateIntrinsic( + Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd}); + + Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {}); + VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "", + /*HasNUW=*/true, /*HasNSW=*/true); + + Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(), + Builder.getInt1(false)); + + BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock); + Builder.Insert(JumpToVectorLoop); + + DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock, + VectorLoopStartBlock}}); + + // Set up the first vector loop block by creating the PHIs, doing the vector + // loads and comparing the vectors. + Builder.SetInsertPoint(VectorLoopStartBlock); + PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred"); + LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock); + PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index"); + VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock); + Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16); + Value *Passthru = ConstantInt::getNullValue(VectorLoadType); + + Value *VectorLhsGep = + Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds()); + Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep, + Align(1), LoopPred, Passthru); + + Value *VectorRhsGep = + Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds()); + Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep, + Align(1), LoopPred, Passthru); + + Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad); + VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse); + Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp); + BranchInst *VectorEarlyExit = BranchInst::Create( + VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes); + Builder.Insert(VectorEarlyExit); + + DTU.applyUpdates( + {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, + {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); + + // Increment the index counter and calculate the predicate for the next + // iteration of the loop. We branch back to the start of the loop if there + // is at least one active lane. + Builder.SetInsertPoint(VectorLoopIncBlock); + Value *NewVectorIndexPhi = + Builder.CreateAdd(VectorIndexPhi, VecLen, "", + /*HasNUW=*/true, /*HasNSW=*/true); + VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock); + Value *NewPred = + Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, + {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd}); + LoopPred->addIncoming(NewPred, VectorLoopIncBlock); + + Value *PredHasActiveLanes = + Builder.CreateExtractElement(NewPred, uint64_t(0)); + BranchInst *VectorLoopBranchBack = + BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes); + Builder.Insert(VectorLoopBranchBack); + + DTU.applyUpdates( + {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, + {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); + + // If we found a mismatch then we need to calculate which lane in the vector + // had a mismatch and add that on to the current loop index. + Builder.SetInsertPoint(VectorLoopMismatchBlock); + PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred"); + FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock); + PHINode *LastLoopPred = + Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred"); + LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock); + PHINode *VectorFoundIndex = + Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index"); + VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock); + + Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred); + Value *Ctz = Builder.CreateIntrinsic( + Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()}, + {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)}); + Ctz = Builder.CreateZExt(Ctz, I64Type); + Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "", + /*HasNUW=*/true, /*HasNSW=*/true); + return Builder.CreateTrunc(VectorLoopRes64, ResType); +} + Value *LoopIdiomVectorize::expandFindMismatch( IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) { @@ -345,8 +468,7 @@ Value *LoopIdiomVectorize::expandFindMismatch( Type *ResType = Builder.getInt32Ty(); // Split block in the original loop preheader. - BasicBlock *EndBlock = - SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end"); + EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end"); // Create the blocks that we're going to need: // 1. A block for checking the zero-extended length exceeds 0 @@ -370,17 +492,17 @@ Value *LoopIdiomVectorize::expandFindMismatch( BasicBlock *MemCheckBlock = BasicBlock::Create( Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock); - BasicBlock *VectorLoopPreheaderBlock = BasicBlock::Create( + VectorLoopPreheaderBlock = BasicBlock::Create( Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock); - BasicBlock *VectorLoopStartBlock = BasicBlock::Create( - Ctx, "mismatch_vec_loop", EndBlock->getParent(), EndBlock); + VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop", + EndBlock->getParent(), EndBlock); - BasicBlock *VectorLoopIncBlock = BasicBlock::Create( - Ctx, "mismatch_vec_loop_inc", EndBlock->getParent(), EndBlock); + VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc", + EndBlock->getParent(), EndBlock); - BasicBlock *VectorLoopMismatchBlock = BasicBlock::Create( - Ctx, "mismatch_vec_loop_found", EndBlock->getParent(), EndBlock); + VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found", + EndBlock->getParent(), EndBlock); BasicBlock *LoopPreHeaderBlock = BasicBlock::Create( Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock); @@ -491,104 +613,8 @@ Value *LoopIdiomVectorize::expandFindMismatch( // processed in each iteration, etc. Builder.SetInsertPoint(VectorLoopPreheaderBlock); - // At this point we know two things must be true: - // 1. Start <= End - // 2. ExtMaxLen <= MinPageSize due to the page checks. - // Therefore, we know that we can use a 64-bit induction variable that - // starts from 0 -> ExtMaxLen and it will not overflow. - ScalableVectorType *PredVTy = - ScalableVectorType::get(Builder.getInt1Ty(), 16); - - Value *InitialPred = Builder.CreateIntrinsic( - Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd}); - - Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {}); - VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "", - /*HasNUW=*/true, /*HasNSW=*/true); - - Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(), - Builder.getInt1(false)); - - BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock); - Builder.Insert(JumpToVectorLoop); - - DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock, - VectorLoopStartBlock}}); - - // Set up the first vector loop block by creating the PHIs, doing the vector - // loads and comparing the vectors. - Builder.SetInsertPoint(VectorLoopStartBlock); - PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred"); - LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock); - PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index"); - VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock); - Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16); - Value *Passthru = ConstantInt::getNullValue(VectorLoadType); - - Value *VectorLhsGep = - Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds()); - Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep, - Align(1), LoopPred, Passthru); - - Value *VectorRhsGep = - Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds()); - Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep, - Align(1), LoopPred, Passthru); - - Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad); - VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse); - Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp); - BranchInst *VectorEarlyExit = BranchInst::Create( - VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes); - Builder.Insert(VectorEarlyExit); - - DTU.applyUpdates( - {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock}, - {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}}); - - // Increment the index counter and calculate the predicate for the next - // iteration of the loop. We branch back to the start of the loop if there - // is at least one active lane. - Builder.SetInsertPoint(VectorLoopIncBlock); - Value *NewVectorIndexPhi = - Builder.CreateAdd(VectorIndexPhi, VecLen, "", - /*HasNUW=*/true, /*HasNSW=*/true); - VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock); - Value *NewPred = - Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, - {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd}); - LoopPred->addIncoming(NewPred, VectorLoopIncBlock); - - Value *PredHasActiveLanes = - Builder.CreateExtractElement(NewPred, uint64_t(0)); - BranchInst *VectorLoopBranchBack = - BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes); - Builder.Insert(VectorLoopBranchBack); - - DTU.applyUpdates( - {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock}, - {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}}); - - // If we found a mismatch then we need to calculate which lane in the vector - // had a mismatch and add that on to the current loop index. - Builder.SetInsertPoint(VectorLoopMismatchBlock); - PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred"); - FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock); - PHINode *LastLoopPred = - Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred"); - LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock); - PHINode *VectorFoundIndex = - Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index"); - VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock); - - Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred); - Value *Ctz = Builder.CreateIntrinsic( - Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()}, - {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)}); - Ctz = Builder.CreateZExt(Ctz, I64Type); - Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "", - /*HasNUW=*/true, /*HasNSW=*/true); - Value *VectorLoopRes = Builder.CreateTrunc(VectorLoopRes64, ResType); + Value *VectorLoopRes = + createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd); Builder.Insert(BranchInst::Create(EndBlock));