Skip to content

Commit

Permalink
[ValueTracking] Consistently propagate DemandedElts is `ComputeNumS…
Browse files Browse the repository at this point in the history
…ignBits`

Summary: 

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60250924
  • Loading branch information
goldsteinn authored and yuxuanchen1997 committed Jul 25, 2024
1 parent 97b91e0 commit deb6463
Showing 1 changed file with 40 additions and 27 deletions.
67 changes: 40 additions & 27 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3801,7 +3801,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
default: break;
case Instruction::SExt:
Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits();
return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q) + Tmp;
return ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q) +
Tmp;

case Instruction::SDiv: {
const APInt *Denominator;
Expand All @@ -3813,7 +3814,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
break;

// Calculate the incoming numerator bits.
unsigned NumBits = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
unsigned NumBits =
ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);

// Add floor(log(C)) bits to the numerator bits.
return std::min(TyBits, NumBits + Denominator->logBase2());
Expand All @@ -3822,7 +3824,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
}

case Instruction::SRem: {
Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);

const APInt *Denominator;
// srem X, C -> we know that the result is within [-C+1,C) when C is a
Expand Down Expand Up @@ -3853,7 +3855,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
}

case Instruction::AShr: {
Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
// ashr X, C -> adds C sign bits. Vectors too.
const APInt *ShAmt;
if (match(U->getOperand(1), m_APInt(ShAmt))) {
Expand All @@ -3869,7 +3871,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
const APInt *ShAmt;
if (match(U->getOperand(1), m_APInt(ShAmt))) {
// shl destroys sign bits.
Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
if (ShAmt->uge(TyBits) || // Bad shift.
ShAmt->uge(Tmp)) break; // Shifted all sign bits out.
Tmp2 = ShAmt->getZExtValue();
Expand All @@ -3881,9 +3883,9 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
case Instruction::Or:
case Instruction::Xor: // NOT is handled here.
// Logical binary ops preserve the number of sign bits at the worst.
Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
if (Tmp != 1) {
Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
FirstAnswer = std::min(Tmp, Tmp2);
// We computed what we know about the sign bits as our first
// answer. Now proceed to the generic code that uses
Expand All @@ -3899,9 +3901,10 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
if (isSignedMinMaxClamp(U, X, CLow, CHigh))
return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits());

Tmp = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
if (Tmp == 1) break;
Tmp2 = ComputeNumSignBits(U->getOperand(2), Depth + 1, Q);
Tmp = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
if (Tmp == 1)
break;
Tmp2 = ComputeNumSignBits(U->getOperand(2), DemandedElts, Depth + 1, Q);
return std::min(Tmp, Tmp2);
}

Expand All @@ -3915,7 +3918,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1)))
if (CRHS->isAllOnesValue()) {
KnownBits Known(TyBits);
computeKnownBits(U->getOperand(0), Known, Depth + 1, Q);
computeKnownBits(U->getOperand(0), DemandedElts, Known, Depth + 1, Q);

// If the input is known to be 0 or 1, the output is 0/-1, which is
// all sign bits set.
Expand All @@ -3928,19 +3931,21 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
return Tmp;
}

Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
if (Tmp2 == 1) break;
Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
if (Tmp2 == 1)
break;
return std::min(Tmp, Tmp2) - 1;

case Instruction::Sub:
Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
if (Tmp2 == 1) break;
Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
if (Tmp2 == 1)
break;

// Handle NEG.
if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0)))
if (CLHS->isNullValue()) {
KnownBits Known(TyBits);
computeKnownBits(U->getOperand(1), Known, Depth + 1, Q);
computeKnownBits(U->getOperand(1), DemandedElts, Known, Depth + 1, Q);
// If the input is known to be 0 or 1, the output is 0/-1, which is
// all sign bits set.
if ((Known.Zero | 1).isAllOnes())
Expand All @@ -3957,17 +3962,22 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,

// Sub can have at most one carry bit. Thus we know that the output
// is, at worst, one more bit than the inputs.
Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
if (Tmp == 1) break;
Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
if (Tmp == 1)
break;
return std::min(Tmp, Tmp2) - 1;

case Instruction::Mul: {
// The output of the Mul can be at most twice the valid bits in the
// inputs.
unsigned SignBitsOp0 = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
if (SignBitsOp0 == 1) break;
unsigned SignBitsOp1 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q);
if (SignBitsOp1 == 1) break;
unsigned SignBitsOp0 =
ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
if (SignBitsOp0 == 1)
break;
unsigned SignBitsOp1 =
ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q);
if (SignBitsOp1 == 1)
break;
unsigned OutValidBits =
(TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1);
return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
Expand All @@ -3988,8 +3998,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
for (unsigned i = 0, e = NumIncomingValues; i != e; ++i) {
if (Tmp == 1) return Tmp;
RecQ.CxtI = PN->getIncomingBlock(i)->getTerminator();
Tmp = std::min(
Tmp, ComputeNumSignBits(PN->getIncomingValue(i), Depth + 1, RecQ));
Tmp = std::min(Tmp, ComputeNumSignBits(PN->getIncomingValue(i),
DemandedElts, Depth + 1, RecQ));
}
return Tmp;
}
Expand Down Expand Up @@ -4050,10 +4060,13 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
case Instruction::Call: {
if (const auto *II = dyn_cast<IntrinsicInst>(U)) {
switch (II->getIntrinsicID()) {
default: break;
default:
break;
case Intrinsic::abs:
Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);
if (Tmp == 1) break;
Tmp =
ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
if (Tmp == 1)
break;

// Absolute value reduces number of sign bits by at most 1.
return Tmp - 1;
Expand Down

0 comments on commit deb6463

Please sign in to comment.