Skip to content

Commit

Permalink
[ValueTracking] Consistently propagate DemandedElts in `isKnownNonZ…
Browse files Browse the repository at this point in the history
…ero`

Summary: 

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60250837
  • Loading branch information
goldsteinn authored and yuxuanchen1997 committed Jul 25, 2024
1 parent c98d4a2 commit 97b91e0
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 39 deletions.
90 changes: 56 additions & 34 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,21 @@ bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
return computeKnownBits(V, Depth, SQ).isNegative();
}

static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
static bool isKnownNonEqual(const Value *V1, const Value *V2,
const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q);

bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
const DataLayout &DL, AssumptionCache *AC,
const Instruction *CxtI, const DominatorTree *DT,
bool UseInstrInfo) {
assert(V1->getType() == V2->getType() &&
"Testing equality of non-equal types!");
auto *FVTy = dyn_cast<FixedVectorType>(V1->getType());
APInt DemandedElts =
FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
return ::isKnownNonEqual(
V1, V2, 0,
V1, V2, DemandedElts, 0,
SimplifyQuery(DL, DT, AC, safeCxtI(V2, V1, CxtI), UseInstrInfo));
}

Expand Down Expand Up @@ -2654,7 +2660,7 @@ static bool isNonZeroSub(const APInt &DemandedElts, unsigned Depth,
if (C->isNullValue() && isKnownNonZero(Y, DemandedElts, Q, Depth))
return true;

return ::isKnownNonEqual(X, Y, Depth, Q);
return ::isKnownNonEqual(X, Y, DemandedElts, Depth, Q);
}

static bool isNonZeroMul(const APInt &DemandedElts, unsigned Depth,
Expand Down Expand Up @@ -2778,8 +2784,11 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
// This all implies the 2 i16 elements are non-zero.
Type *FromTy = I->getOperand(0)->getType();
if ((FromTy->isIntOrIntVectorTy() || FromTy->isPtrOrPtrVectorTy()) &&
(BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0)
(BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0) {
if (match(I, m_ElementWiseBitCast(m_Value())))
return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
return isKnownNonZero(I->getOperand(0), Q, Depth);
}
} break;
case Instruction::IntToPtr:
// Note that we have to take special care to avoid looking through
Expand All @@ -2788,21 +2797,21 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
if (!isa<ScalableVectorType>(I->getType()) &&
Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
return isKnownNonZero(I->getOperand(0), Q, Depth);
return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
break;
case Instruction::PtrToInt:
// Similar to int2ptr above, we can look through ptr2int here if the cast
// is a no-op or an extend and not a truncate.
if (!isa<ScalableVectorType>(I->getType()) &&
Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
return isKnownNonZero(I->getOperand(0), Q, Depth);
return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
break;
case Instruction::Trunc:
// nuw/nsw trunc preserves zero/non-zero status of input.
if (auto *TI = dyn_cast<TruncInst>(I))
if (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap())
return isKnownNonZero(TI->getOperand(0), Q, Depth);
return isKnownNonZero(TI->getOperand(0), DemandedElts, Q, Depth);
break;

case Instruction::Sub:
Expand All @@ -2823,13 +2832,13 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
case Instruction::SExt:
case Instruction::ZExt:
// ext X != 0 if X != 0.
return isKnownNonZero(I->getOperand(0), Q, Depth);
return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);

case Instruction::Shl: {
// shl nsw/nuw can't remove any non-zero bits.
const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I);
if (Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO))
return isKnownNonZero(I->getOperand(0), Q, Depth);
return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);

// shl X, Y != 0 if X is odd. Note that the value of the shift is undefined
// if the lowest bit is shifted off the end.
Expand All @@ -2845,7 +2854,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
// shr exact can only shift out zero bits.
const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(I);
if (BO->isExact())
return isKnownNonZero(I->getOperand(0), Q, Depth);
return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);

// shr X, Y != 0 if X is negative. Note that the value of the shift is not
// defined if the sign bit is shifted off the end.
Expand Down Expand Up @@ -3100,6 +3109,8 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
/*NSW=*/true, /* NUW=*/false);
// Vec reverse preserves zero/non-zero status from input vec.
case Intrinsic::vector_reverse:
return isKnownNonZero(II->getArgOperand(0), DemandedElts.reverseBits(),
Q, Depth);
// umin/umax/smin/smax/or of all non-zero elements is always non-zero.
case Intrinsic::vector_reduce_or:
case Intrinsic::vector_reduce_umax:
Expand Down Expand Up @@ -3424,7 +3435,8 @@ getInvertibleOperands(const Operator *Op1,
/// Only handle a small subset of binops where (binop V2, X) with non-zero X
/// implies V2 != V1.
static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
unsigned Depth, const SimplifyQuery &Q) {
const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q) {
const BinaryOperator *BO = dyn_cast<BinaryOperator>(V1);
if (!BO)
return false;
Expand All @@ -3444,39 +3456,43 @@ static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
Op = BO->getOperand(0);
else
return false;
return isKnownNonZero(Op, Q, Depth + 1);
return isKnownNonZero(Op, DemandedElts, Q, Depth + 1);
}
return false;
}

/// Return true if V2 == V1 * C, where V1 is known non-zero, C is not 0/1 and
/// the multiplication is nuw or nsw.
static bool isNonEqualMul(const Value *V1, const Value *V2, unsigned Depth,
static bool isNonEqualMul(const Value *V1, const Value *V2,
const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q) {
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
const APInt *C;
return match(OBO, m_Mul(m_Specific(V1), m_APInt(C))) &&
(OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
!C->isZero() && !C->isOne() && isKnownNonZero(V1, Q, Depth + 1);
!C->isZero() && !C->isOne() &&
isKnownNonZero(V1, DemandedElts, Q, Depth + 1);
}
return false;
}

/// Return true if V2 == V1 << C, where V1 is known non-zero, C is not 0 and
/// the shift is nuw or nsw.
static bool isNonEqualShl(const Value *V1, const Value *V2, unsigned Depth,
static bool isNonEqualShl(const Value *V1, const Value *V2,
const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q) {
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(V2)) {
const APInt *C;
return match(OBO, m_Shl(m_Specific(V1), m_APInt(C))) &&
(OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
!C->isZero() && isKnownNonZero(V1, Q, Depth + 1);
!C->isZero() && isKnownNonZero(V1, DemandedElts, Q, Depth + 1);
}
return false;
}

static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2,
unsigned Depth, const SimplifyQuery &Q) {
const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q) {
// Check two PHIs are in same block.
if (PN1->getParent() != PN2->getParent())
return false;
Expand All @@ -3498,14 +3514,15 @@ static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2,

SimplifyQuery RecQ = Q;
RecQ.CxtI = IncomBB->getTerminator();
if (!isKnownNonEqual(IV1, IV2, Depth + 1, RecQ))
if (!isKnownNonEqual(IV1, IV2, DemandedElts, Depth + 1, RecQ))
return false;
UsedFullRecursion = true;
}
return true;
}

static bool isNonEqualSelect(const Value *V1, const Value *V2, unsigned Depth,
static bool isNonEqualSelect(const Value *V1, const Value *V2,
const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q) {
const SelectInst *SI1 = dyn_cast<SelectInst>(V1);
if (!SI1)
Expand All @@ -3516,12 +3533,12 @@ static bool isNonEqualSelect(const Value *V1, const Value *V2, unsigned Depth,
const Value *Cond2 = SI2->getCondition();
if (Cond1 == Cond2)
return isKnownNonEqual(SI1->getTrueValue(), SI2->getTrueValue(),
Depth + 1, Q) &&
DemandedElts, Depth + 1, Q) &&
isKnownNonEqual(SI1->getFalseValue(), SI2->getFalseValue(),
Depth + 1, Q);
DemandedElts, Depth + 1, Q);
}
return isKnownNonEqual(SI1->getTrueValue(), V2, Depth + 1, Q) &&
isKnownNonEqual(SI1->getFalseValue(), V2, Depth + 1, Q);
return isKnownNonEqual(SI1->getTrueValue(), V2, DemandedElts, Depth + 1, Q) &&
isKnownNonEqual(SI1->getFalseValue(), V2, DemandedElts, Depth + 1, Q);
}

// Check to see if A is both a GEP and is the incoming value for a PHI in the
Expand Down Expand Up @@ -3577,7 +3594,8 @@ static bool isNonEqualPointersWithRecursiveGEP(const Value *A, const Value *B,
}

/// Return true if it is known that V1 != V2.
static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
static bool isKnownNonEqual(const Value *V1, const Value *V2,
const APInt &DemandedElts, unsigned Depth,
const SimplifyQuery &Q) {
if (V1 == V2)
return false;
Expand All @@ -3595,40 +3613,44 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
auto *O2 = dyn_cast<Operator>(V2);
if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) {
if (auto Values = getInvertibleOperands(O1, O2))
return isKnownNonEqual(Values->first, Values->second, Depth + 1, Q);
return isKnownNonEqual(Values->first, Values->second, DemandedElts,
Depth + 1, Q);

if (const PHINode *PN1 = dyn_cast<PHINode>(V1)) {
const PHINode *PN2 = cast<PHINode>(V2);
// FIXME: This is missing a generalization to handle the case where one is
// a PHI and another one isn't.
if (isNonEqualPHIs(PN1, PN2, Depth, Q))
if (isNonEqualPHIs(PN1, PN2, DemandedElts, Depth, Q))
return true;
};
}

if (isModifyingBinopOfNonZero(V1, V2, Depth, Q) ||
isModifyingBinopOfNonZero(V2, V1, Depth, Q))
if (isModifyingBinopOfNonZero(V1, V2, DemandedElts, Depth, Q) ||
isModifyingBinopOfNonZero(V2, V1, DemandedElts, Depth, Q))
return true;

if (isNonEqualMul(V1, V2, Depth, Q) || isNonEqualMul(V2, V1, Depth, Q))
if (isNonEqualMul(V1, V2, DemandedElts, Depth, Q) ||
isNonEqualMul(V2, V1, DemandedElts, Depth, Q))
return true;

if (isNonEqualShl(V1, V2, Depth, Q) || isNonEqualShl(V2, V1, Depth, Q))
if (isNonEqualShl(V1, V2, DemandedElts, Depth, Q) ||
isNonEqualShl(V2, V1, DemandedElts, Depth, Q))
return true;

if (V1->getType()->isIntOrIntVectorTy()) {
// Are any known bits in V1 contradictory to known bits in V2? If V1
// has a known zero where V2 has a known one, they must not be equal.
KnownBits Known1 = computeKnownBits(V1, Depth, Q);
KnownBits Known1 = computeKnownBits(V1, DemandedElts, Depth, Q);
if (!Known1.isUnknown()) {
KnownBits Known2 = computeKnownBits(V2, Depth, Q);
KnownBits Known2 = computeKnownBits(V2, DemandedElts, Depth, Q);
if (Known1.Zero.intersects(Known2.One) ||
Known2.Zero.intersects(Known1.One))
return true;
}
}

if (isNonEqualSelect(V1, V2, Depth, Q) || isNonEqualSelect(V2, V1, Depth, Q))
if (isNonEqualSelect(V1, V2, DemandedElts, Depth, Q) ||
isNonEqualSelect(V2, V1, DemandedElts, Depth, Q))
return true;

if (isNonEqualPointersWithRecursiveGEP(V1, V2, Q) ||
Expand All @@ -3640,7 +3662,7 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
// Check PtrToInt type matches the pointer size.
if (match(V1, m_PtrToIntSameSize(Q.DL, m_Value(A))) &&
match(V2, m_PtrToIntSameSize(Q.DL, m_Value(B))))
return isKnownNonEqual(A, B, Depth + 1, Q);
return isKnownNonEqual(A, B, DemandedElts, Depth + 1, Q);

return false;
}
Expand Down
6 changes: 1 addition & 5 deletions llvm/test/Analysis/ValueTracking/known-non-zero.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1522,11 +1522,7 @@ define <4 x i1> @vec_reverse_non_zero_fail(<4 x i8> %xx) {

define i1 @vec_reverse_non_zero_demanded(<4 x i8> %xx) {
; CHECK-LABEL: @vec_reverse_non_zero_demanded(
; CHECK-NEXT: [[X:%.*]] = add nuw <4 x i8> [[XX:%.*]], <i8 1, i8 0, i8 0, i8 0>
; CHECK-NEXT: [[REV:%.*]] = call <4 x i8> @llvm.vector.reverse.v4i8(<4 x i8> [[X]])
; CHECK-NEXT: [[ELE:%.*]] = extractelement <4 x i8> [[REV]], i64 3
; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[ELE]], 0
; CHECK-NEXT: ret i1 [[R]]
; CHECK-NEXT: ret i1 false
;
%x = add nuw <4 x i8> %xx, <i8 1, i8 0, i8 0, i8 0>
%rev = call <4 x i8> @llvm.vector.reverse(<4 x i8> %x)
Expand Down

0 comments on commit 97b91e0

Please sign in to comment.