[InstCombine] Pattern match minmax calls for unsigned saturation. #99250

Open
wants to merge 5 commits into base: main
Changes from 3 commits
60 changes: 42 additions & 18 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1117,68 +1117,92 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
: BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
}
/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
/// Match a [s|u]add_sat or [s|u]sub_sat which is using min/max to clamp the
/// value.
Instruction *InstCombinerImpl::matchAddSubSat(IntrinsicInst &MinMax1) {
Type *Ty = MinMax1.getType();

// We are looking for a tree of:
// max(INT_MIN, min(INT_MAX, add(sext(A), sext(B))))
// Where the min and max could be reversed
Instruction *MinMax2;
// 1. We are looking for a tree of signed saturation:
// smax(SINT_MIN, smin(SINT_MAX, add|sub(sext(A), sext(B))))
// Where the smin and smax could be reversed.
// 2. A tree of unsigned saturation:
// smax(UINT_MIN, smin(UINT_MAX, sub(zext(A), zext(B))))
// Where the smin and smax could be reversed.
// Or umin(UINT_MAX, add(zext(A), zext(B)))
Instruction *MinMax2 = nullptr;
BinaryOperator *AddSub;
const APInt *MinValue, *MaxValue;
bool IsUnsignedSaturate = false;
if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
return nullptr;
} else if (match(&MinMax1,
m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
return nullptr;
} else if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue)))) {
// Bail out if AddSub could be negative.
if (!isKnownNonNegative(AddSub, SQ.getWithInstruction(AddSub)))
Member

Is this check necessary? I don't see any llvm.assume in the alive2 proof.

Contributor Author

This check was used to reject "umin(UINT_MAX, sub) -> usub_sat".
I pushed a new update that checks for the BinOp opcode 'Add', so the !isKnownNonNegative check can be deleted.
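
For readers following the thread, here is a minimal standalone sketch of why the `sub` form has to be rejected under `umin`. It is plain C++ rather than LLVM code, it models 8-bit operands zero-extended to 32 bits, and every helper name in it (`uaddSat8`, `usubSat8`, `uminClamp`, `addPatternMatchesEverywhere`, `countSubMismatches`) is hypothetical and only there for illustration.

```cpp
#include <algorithm>
#include <cstdint>

// Reference 8-bit unsigned saturating add (hypothetical helper).
inline uint8_t uaddSat8(uint8_t A, uint8_t B) {
  unsigned Sum = unsigned(A) + unsigned(B);
  return uint8_t(std::min(Sum, 255u));
}

// Reference 8-bit unsigned saturating subtract (hypothetical helper).
inline uint8_t usubSat8(uint8_t A, uint8_t B) {
  return A > B ? uint8_t(A - B) : uint8_t(0);
}

// Models umin(UINT8_MAX, X) applied to the wide 32-bit value.
inline uint32_t uminClamp(uint32_t X) { return std::min<uint32_t>(X, 255u); }

// True if umin(255, zext(A) + zext(B)) equals uadd_sat(A, B) for all inputs.
inline bool addPatternMatchesEverywhere() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B)
      if (uminClamp(uint32_t(A) + uint32_t(B)) != uaddSat8(uint8_t(A), uint8_t(B)))
        return false;
  return true;
}

// Counts inputs where umin(255, zext(A) - zext(B)) differs from usub_sat(A, B).
inline unsigned countSubMismatches() {
  unsigned Mismatches = 0;
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B)
      if (uminClamp(uint32_t(A) - uint32_t(B)) != usubSat8(uint8_t(A), uint8_t(B)))
        ++Mismatches;
  return Mismatches;
}
```

In this model the `add` form agrees everywhere, while the `sub` form disagrees whenever `A < B`: with `A = 0, B = 1` the wide subtraction wraps to `0xFFFFFFFF`, `umin` clamps it to 255, and the saturating subtract would be 0.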

return nullptr;
IsUnsignedSaturate = true;
} else
return nullptr;

if (!IsUnsignedSaturate && MinValue && MinValue->isZero())
IsUnsignedSaturate = true;

// Check that the constants clamp a saturate, and that the new type would be
// sensible to convert to.
if (!(*MaxValue + 1).isPowerOf2() || -*MinValue != *MaxValue + 1)
if (!(*MaxValue + 1).isPowerOf2() ||
(!IsUnsignedSaturate && -*MinValue != *MaxValue + 1))
return nullptr;
// In what bitwidth can this be treated as saturating arithmetics?
unsigned NewBitWidth = (*MaxValue + 1).logBase2() + 1;
unsigned NewBitWidth =
(*MaxValue + 1).logBase2() + (IsUnsignedSaturate ? 0 : 1);
// FIXME: This isn't quite right for vectors, but using the scalar type is a
// good first approximation for what should be done there.
if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth))
return nullptr;

// Also make sure that the inner min/max and the add/sub have one use.
if (!MinMax2->hasOneUse() || !AddSub->hasOneUse())
if ((MinMax2 && !MinMax2->hasOneUse()) || !AddSub->hasOneUse())
return nullptr;

// Create the new type (which can be a vector type)
Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth);

Intrinsic::ID IntrinsicID;
if (AddSub->getOpcode() == Instruction::Add)
IntrinsicID = Intrinsic::sadd_sat;
IntrinsicID =
IsUnsignedSaturate ? Intrinsic::uadd_sat : Intrinsic::sadd_sat;
Contributor

I don't see any logic that's checking the right binop for the given minmax pattern. I.e., umin(UINT_MAX, add) -> uadd_sat is okay, but umin(UINT_MAX, sub) -> usub_sat isn't.

Contributor Author

@huihzhang, Jul 24, 2024

"umin(UINT_MAX, sub) -> usub_sat" was previously rejected by "!isKnownNonNegative".

-  if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue)))) {
-    // Bail out if AddSub could be negative.
-    if (!isKnownNonNegative(AddSub, SQ.getWithInstruction(AddSub)))
-      return nullptr;
+  if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue))) &&
+      AddSub->getOpcode() == Instruction::Add) {

In the new update, I check that the BinOp opcode equals 'Add', to make sure we don't accept the umin(UINT_MAX, sub) case.
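
As a companion to the point above, the subtract form is still reachable through the `smax`/`smin` tree described in the function's header comment (`smax(UINT_MIN, smin(UINT_MAX, sub(zext(A), zext(B))))`). The following is a small standalone sketch, in plain C++ rather than LLVM code, that checks this tree against an 8-bit unsigned saturating subtract. The names `usubSatRef`, `signedClampOfSub`, `signedClampOfSubReversed`, and `signedClampMatchesUSubSat` are hypothetical.

```cpp
#include <algorithm>
#include <cstdint>

// Reference 8-bit unsigned saturating subtract (hypothetical helper).
inline uint8_t usubSatRef(uint8_t A, uint8_t B) {
  return A > B ? uint8_t(A - B) : uint8_t(0);
}

// Models smax(0, smin(255, sub(zext(A), zext(B)))) in signed 32-bit arithmetic.
inline int32_t signedClampOfSub(uint8_t A, uint8_t B) {
  int32_t Diff = int32_t(A) - int32_t(B);
  return std::max<int32_t>(0, std::min<int32_t>(255, Diff));
}

// Same tree with smin and smax swapped: smin(255, smax(0, sub(...))).
inline int32_t signedClampOfSubReversed(uint8_t A, uint8_t B) {
  int32_t Diff = int32_t(A) - int32_t(B);
  return std::min<int32_t>(255, std::max<int32_t>(0, Diff));
}

// True if both clamp orders equal usub_sat(A, B) for every 8-bit input pair.
inline bool signedClampMatchesUSubSat() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B) {
      int32_t Expected = usubSatRef(uint8_t(A), uint8_t(B));
      if (signedClampOfSub(uint8_t(A), uint8_t(B)) != Expected ||
          signedClampOfSubReversed(uint8_t(A), uint8_t(B)) != Expected)
        return false;
    }
  return true;
}
```

Because the difference of two zero-extended 8-bit values stays within [-255, 255], the signed lower clamp at 0 is what provides the saturation here, which the unsigned `umin` alone cannot do.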

else if (AddSub->getOpcode() == Instruction::Sub)
IntrinsicID = Intrinsic::ssub_sat;
IntrinsicID =
IsUnsignedSaturate ? Intrinsic::usub_sat : Intrinsic::ssub_sat;
else
return nullptr;

// The two operands of the add/sub must be nsw-truncatable to the NewTy. This
// is usually achieved via a sext from a smaller type.
if (ComputeMaxSignificantBits(AddSub->getOperand(0), 0, AddSub) >
NewBitWidth ||
ComputeMaxSignificantBits(AddSub->getOperand(1), 0, AddSub) > NewBitWidth)
Value *Op0 = AddSub->getOperand(0);
Value *Op1 = AddSub->getOperand(1);
unsigned Op0MaxBitWidth =
IsUnsignedSaturate ? computeKnownBits(Op0, 0, AddSub).countMaxActiveBits()
: ComputeMaxSignificantBits(Op0, 0, AddSub);
unsigned Op1MaxBitWidth =
IsUnsignedSaturate ? computeKnownBits(Op1, 0, AddSub).countMaxActiveBits()
: ComputeMaxSignificantBits(Op1, 0, AddSub);
if (Op0MaxBitWidth > NewBitWidth || Op1MaxBitWidth > NewBitWidth)
return nullptr;

// Finally create and return the sat intrinsic, truncated to the new type
Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy);
Value *AT = Builder.CreateTrunc(AddSub->getOperand(0), NewTy);
Value *BT = Builder.CreateTrunc(AddSub->getOperand(1), NewTy);
Value *Sat = Builder.CreateCall(F, {AT, BT});
return CastInst::Create(Instruction::SExt, Sat, Ty);
return CastInst::Create(
IsUnsignedSaturate ? Instruction::ZExt : Instruction::SExt, Sat, Ty);
}
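
The constant check and bit-width computation in the function above can be sketched in isolation. This is a simplified standalone model using 64-bit integers instead of `APInt`; the function name `saturationBitWidth` and its "return 0 on rejection" convention are assumptions made for this example, and the `shouldChangeType`, one-use, and operand-width checks from the patch are deliberately left out.

```cpp
#include <cstdint>

// Returns the bit width in which the clamp [MinValue, MaxValue] acts as a
// saturating add/sub, or 0 if the constants do not describe a saturation.
// Hypothetical helper: mirrors only the constant checks from the patch.
inline unsigned saturationBitWidth(int64_t MinValue, int64_t MaxValue,
                                   bool IsUnsignedSaturate) {
  // Keep MaxValue + 1 representable in this 64-bit model.
  if (MaxValue < 0 || MaxValue == INT64_MAX)
    return 0;

  // MaxValue + 1 must be a power of two.
  uint64_t Limit = uint64_t(MaxValue) + 1;
  if ((Limit & (Limit - 1)) != 0)
    return 0;

  // Signed saturation additionally needs MinValue == -(MaxValue + 1).
  if (!IsUnsignedSaturate && MinValue != -(MaxValue + 1))
    return 0;

  // log2(MaxValue + 1), plus one sign bit in the signed case.
  unsigned Log2 = 0;
  while (Limit > 1) {
    Limit >>= 1;
    ++Log2;
  }
  return Log2 + (IsUnsignedSaturate ? 0 : 1);
}
```

For example, the signed clamp [-128, 127] and the unsigned clamp [0, 255] both map to 8 bits in this model, while [0, 254] and [-100, 127] are rejected.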


/// If we have a clamp pattern like max (min X, 42), 41 -- where the output
/// can only be one of two possible constant values -- turn that into a select
/// of constants.
@@ -1878,8 +1902,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *Sel = foldClampRangeOfTwo(II, Builder))
return Sel;

if (Instruction *SAdd = matchSAddSubSat(*II))
return SAdd;
if (Instruction *AddSubSat = matchAddSubSat(*II))
return AddSubSat;

if (Value *NewMinMax = reassociateMinMaxWithConstants(II, Builder, SQ))
return replaceInstUsesWith(*II, NewMinMax);
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -392,7 +392,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *narrowMathIfNoOverflow(BinaryOperator &I);
Instruction *narrowFunnelShift(TruncInst &Trunc);
Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
Instruction *matchSAddSubSat(IntrinsicInst &MinMax1);
Instruction *matchAddSubSat(IntrinsicInst &MinMax1);
Instruction *foldNot(BinaryOperator &I);
Instruction *foldBinOpOfDisplacedShifts(BinaryOperator &I);
