diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index b715ab6eabf702..f54de030d33445 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -3801,7 +3801,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, default: break; case Instruction::SExt: Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits(); - return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q) + Tmp; + return ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q) + + Tmp; case Instruction::SDiv: { const APInt *Denominator; @@ -3813,7 +3814,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, break; // Calculate the incoming numerator bits. - unsigned NumBits = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + unsigned NumBits = + ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q); // Add floor(log(C)) bits to the numerator bits. return std::min(TyBits, NumBits + Denominator->logBase2()); @@ -3822,7 +3824,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, } case Instruction::SRem: { - Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q); const APInt *Denominator; // srem X, C -> we know that the result is within [-C+1,C) when C is a @@ -3853,7 +3855,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, } case Instruction::AShr: { - Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q); // ashr X, C -> adds C sign bits. Vectors too. const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { @@ -3869,7 +3871,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { // shl destroys sign bits. - Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q); if (ShAmt->uge(TyBits) || // Bad shift. ShAmt->uge(Tmp)) break; // Shifted all sign bits out. Tmp2 = ShAmt->getZExtValue(); @@ -3881,9 +3883,9 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, case Instruction::Or: case Instruction::Xor: // NOT is handled here. // Logical binary ops preserve the number of sign bits at the worst. - Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q); if (Tmp != 1) { - Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); + Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q); FirstAnswer = std::min(Tmp, Tmp2); // We computed what we know about the sign bits as our first // answer. Now proceed to the generic code that uses @@ -3899,9 +3901,10 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, if (isSignedMinMaxClamp(U, X, CLow, CHigh)) return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits()); - Tmp = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); - if (Tmp == 1) break; - Tmp2 = ComputeNumSignBits(U->getOperand(2), Depth + 1, Q); + Tmp = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q); + if (Tmp == 1) + break; + Tmp2 = ComputeNumSignBits(U->getOperand(2), DemandedElts, Depth + 1, Q); return std::min(Tmp, Tmp2); } @@ -3915,7 +3918,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, if (const auto *CRHS = dyn_cast(U->getOperand(1))) if (CRHS->isAllOnesValue()) { KnownBits Known(TyBits); - computeKnownBits(U->getOperand(0), Known, Depth + 1, Q); + computeKnownBits(U->getOperand(0), DemandedElts, Known, Depth + 1, Q); // If the input is known to be 0 or 1, the output is 0/-1, which is // all sign bits set. @@ -3928,19 +3931,21 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, return Tmp; } - Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); - if (Tmp2 == 1) break; + Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q); + if (Tmp2 == 1) + break; return std::min(Tmp, Tmp2) - 1; case Instruction::Sub: - Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); - if (Tmp2 == 1) break; + Tmp2 = ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q); + if (Tmp2 == 1) + break; // Handle NEG. if (const auto *CLHS = dyn_cast(U->getOperand(0))) if (CLHS->isNullValue()) { KnownBits Known(TyBits); - computeKnownBits(U->getOperand(1), Known, Depth + 1, Q); + computeKnownBits(U->getOperand(1), DemandedElts, Known, Depth + 1, Q); // If the input is known to be 0 or 1, the output is 0/-1, which is // all sign bits set. if ((Known.Zero | 1).isAllOnes()) @@ -3957,17 +3962,22 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, // Sub can have at most one carry bit. Thus we know that the output // is, at worst, one more bit than the inputs. - Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); - if (Tmp == 1) break; + Tmp = ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q); + if (Tmp == 1) + break; return std::min(Tmp, Tmp2) - 1; case Instruction::Mul: { // The output of the Mul can be at most twice the valid bits in the // inputs. - unsigned SignBitsOp0 = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); - if (SignBitsOp0 == 1) break; - unsigned SignBitsOp1 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); - if (SignBitsOp1 == 1) break; + unsigned SignBitsOp0 = + ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q); + if (SignBitsOp0 == 1) + break; + unsigned SignBitsOp1 = + ComputeNumSignBits(U->getOperand(1), DemandedElts, Depth + 1, Q); + if (SignBitsOp1 == 1) + break; unsigned OutValidBits = (TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1); return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1; @@ -3988,8 +3998,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, for (unsigned i = 0, e = NumIncomingValues; i != e; ++i) { if (Tmp == 1) return Tmp; RecQ.CxtI = PN->getIncomingBlock(i)->getTerminator(); - Tmp = std::min( - Tmp, ComputeNumSignBits(PN->getIncomingValue(i), Depth + 1, RecQ)); + Tmp = std::min(Tmp, ComputeNumSignBits(PN->getIncomingValue(i), + DemandedElts, Depth + 1, RecQ)); } return Tmp; } @@ -4050,10 +4060,13 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, case Instruction::Call: { if (const auto *II = dyn_cast(U)) { switch (II->getIntrinsicID()) { - default: break; + default: + break; case Intrinsic::abs: - Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); - if (Tmp == 1) break; + Tmp = + ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q); + if (Tmp == 1) + break; // Absolute value reduces number of sign bits by at most 1. return Tmp - 1;