Skip to content

Commit

Permalink
[LoopIdiomVectorize][NFC] Factoring out the part that handles vectori…
Browse files Browse the repository at this point in the history
…zation strategy (#94682)

To pave the way for porting LIV to RISC-V, which uses VP intrinsics for
vectors.

NFC.
  • Loading branch information
mshockwave authored Jul 3, 2024
1 parent 0856064 commit de5ff38
Showing 1 changed file with 133 additions and 107 deletions.
240 changes: 133 additions & 107 deletions llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ class LoopIdiomVectorize {
const TargetTransformInfo *TTI;
const DataLayout *DL;

// Blocks that will be used for inserting vectorized code.
BasicBlock *EndBlock = nullptr;
BasicBlock *VectorLoopPreheaderBlock = nullptr;
BasicBlock *VectorLoopStartBlock = nullptr;
BasicBlock *VectorLoopMismatchBlock = nullptr;
BasicBlock *VectorLoopIncBlock = nullptr;

public:
explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI,
const TargetTransformInfo *TTI,
Expand All @@ -95,9 +102,16 @@ class LoopIdiomVectorize {
SmallVectorImpl<BasicBlock *> &ExitBlocks);

bool recognizeByteCompare();

Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
Instruction *Index, Value *Start, Value *MaxLen);

Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
GetElementPtrInst *GEPA,
GetElementPtrInst *GEPB, Value *ExtStart,
Value *ExtEnd);

void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
Value *Start, bool IncIdx, BasicBlock *FoundBB,
Expand Down Expand Up @@ -331,6 +345,115 @@ bool LoopIdiomVectorize::recognizeByteCompare() {
return true;
}

Value *LoopIdiomVectorize::createMaskedFindMismatch(
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
Type *I64Type = Builder.getInt64Ty();
Type *ResType = Builder.getInt32Ty();
Type *LoadType = Builder.getInt8Ty();
Value *PtrA = GEPA->getPointerOperand();
Value *PtrB = GEPB->getPointerOperand();

// At this point we know two things must be true:
// 1. Start <= End
// 2. ExtMaxLen <= MinPageSize due to the page checks.
// Therefore, we know that we can use a 64-bit induction variable that
// starts from 0 -> ExtMaxLen and it will not overflow.
ScalableVectorType *PredVTy =
ScalableVectorType::get(Builder.getInt1Ty(), 16);

Value *InitialPred = Builder.CreateIntrinsic(
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});

Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
/*HasNUW=*/true, /*HasNSW=*/true);

Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
Builder.getInt1(false));

BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
Builder.Insert(JumpToVectorLoop);

DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
VectorLoopStartBlock}});

// Set up the first vector loop block by creating the PHIs, doing the vector
// loads and comparing the vectors.
Builder.SetInsertPoint(VectorLoopStartBlock);
PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
Value *Passthru = ConstantInt::getNullValue(VectorLoadType);

Value *VectorLhsGep =
Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
Align(1), LoopPred, Passthru);

Value *VectorRhsGep =
Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
Align(1), LoopPred, Passthru);

Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
BranchInst *VectorEarlyExit = BranchInst::Create(
VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
Builder.Insert(VectorEarlyExit);

DTU.applyUpdates(
{{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});

// Increment the index counter and calculate the predicate for the next
// iteration of the loop. We branch back to the start of the loop if there
// is at least one active lane.
Builder.SetInsertPoint(VectorLoopIncBlock);
Value *NewVectorIndexPhi =
Builder.CreateAdd(VectorIndexPhi, VecLen, "",
/*HasNUW=*/true, /*HasNSW=*/true);
VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
Value *NewPred =
Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
{PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
LoopPred->addIncoming(NewPred, VectorLoopIncBlock);

Value *PredHasActiveLanes =
Builder.CreateExtractElement(NewPred, uint64_t(0));
BranchInst *VectorLoopBranchBack =
BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
Builder.Insert(VectorLoopBranchBack);

DTU.applyUpdates(
{{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
{DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});

// If we found a mismatch then we need to calculate which lane in the vector
// had a mismatch and add that on to the current loop index.
Builder.SetInsertPoint(VectorLoopMismatchBlock);
PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
PHINode *LastLoopPred =
Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
PHINode *VectorFoundIndex =
Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);

Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
Value *Ctz = Builder.CreateIntrinsic(
Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
{PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
Ctz = Builder.CreateZExt(Ctz, I64Type);
Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
/*HasNUW=*/true, /*HasNSW=*/true);
return Builder.CreateTrunc(VectorLoopRes64, ResType);
}

Value *LoopIdiomVectorize::expandFindMismatch(
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
Expand All @@ -345,8 +468,7 @@ Value *LoopIdiomVectorize::expandFindMismatch(
Type *ResType = Builder.getInt32Ty();

// Split block in the original loop preheader.
BasicBlock *EndBlock =
SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");

// Create the blocks that we're going to need:
// 1. A block for checking the zero-extended length exceeds 0
Expand All @@ -370,17 +492,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
BasicBlock *MemCheckBlock = BasicBlock::Create(
Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);

BasicBlock *VectorLoopPreheaderBlock = BasicBlock::Create(
VectorLoopPreheaderBlock = BasicBlock::Create(
Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock);

BasicBlock *VectorLoopStartBlock = BasicBlock::Create(
Ctx, "mismatch_vec_loop", EndBlock->getParent(), EndBlock);
VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop",
EndBlock->getParent(), EndBlock);

BasicBlock *VectorLoopIncBlock = BasicBlock::Create(
Ctx, "mismatch_vec_loop_inc", EndBlock->getParent(), EndBlock);
VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc",
EndBlock->getParent(), EndBlock);

BasicBlock *VectorLoopMismatchBlock = BasicBlock::Create(
Ctx, "mismatch_vec_loop_found", EndBlock->getParent(), EndBlock);
VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found",
EndBlock->getParent(), EndBlock);

BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);
Expand Down Expand Up @@ -491,104 +613,8 @@ Value *LoopIdiomVectorize::expandFindMismatch(
// processed in each iteration, etc.
Builder.SetInsertPoint(VectorLoopPreheaderBlock);

// At this point we know two things must be true:
// 1. Start <= End
// 2. ExtMaxLen <= MinPageSize due to the page checks.
// Therefore, we know that we can use a 64-bit induction variable that
// starts from 0 -> ExtMaxLen and it will not overflow.
ScalableVectorType *PredVTy =
ScalableVectorType::get(Builder.getInt1Ty(), 16);

Value *InitialPred = Builder.CreateIntrinsic(
Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});

Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
/*HasNUW=*/true, /*HasNSW=*/true);

Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
Builder.getInt1(false));

BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
Builder.Insert(JumpToVectorLoop);

DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
VectorLoopStartBlock}});

// Set up the first vector loop block by creating the PHIs, doing the vector
// loads and comparing the vectors.
Builder.SetInsertPoint(VectorLoopStartBlock);
PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
Value *Passthru = ConstantInt::getNullValue(VectorLoadType);

Value *VectorLhsGep =
Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
Align(1), LoopPred, Passthru);

Value *VectorRhsGep =
Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
Align(1), LoopPred, Passthru);

Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
BranchInst *VectorEarlyExit = BranchInst::Create(
VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
Builder.Insert(VectorEarlyExit);

DTU.applyUpdates(
{{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});

// Increment the index counter and calculate the predicate for the next
// iteration of the loop. We branch back to the start of the loop if there
// is at least one active lane.
Builder.SetInsertPoint(VectorLoopIncBlock);
Value *NewVectorIndexPhi =
Builder.CreateAdd(VectorIndexPhi, VecLen, "",
/*HasNUW=*/true, /*HasNSW=*/true);
VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
Value *NewPred =
Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
{PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
LoopPred->addIncoming(NewPred, VectorLoopIncBlock);

Value *PredHasActiveLanes =
Builder.CreateExtractElement(NewPred, uint64_t(0));
BranchInst *VectorLoopBranchBack =
BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
Builder.Insert(VectorLoopBranchBack);

DTU.applyUpdates(
{{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
{DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});

// If we found a mismatch then we need to calculate which lane in the vector
// had a mismatch and add that on to the current loop index.
Builder.SetInsertPoint(VectorLoopMismatchBlock);
PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
PHINode *LastLoopPred =
Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
PHINode *VectorFoundIndex =
Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);

Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
Value *Ctz = Builder.CreateIntrinsic(
Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
{PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
Ctz = Builder.CreateZExt(Ctz, I64Type);
Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
/*HasNUW=*/true, /*HasNSW=*/true);
Value *VectorLoopRes = Builder.CreateTrunc(VectorLoopRes64, ResType);
Value *VectorLoopRes =
createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);

Builder.Insert(BranchInst::Create(EndBlock));

Expand Down

0 comments on commit de5ff38

Please sign in to comment.