[InstCombine] Pattern match minmax calls for unsigned saturation. #99250

Open
wants to merge 5 commits into
base: main
Changes from 4 commits
102 changes: 71 additions & 31 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1117,68 +1117,108 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
: BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
}
/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
/// Match a [s|u]add_sat or [s|u]sub_sat which is using min/max to clamp the
/// value.
Instruction *InstCombinerImpl::matchAddSubSat(IntrinsicInst &MinMax1) {
Type *Ty = MinMax1.getType();

// We are looking for a tree of:
// max(INT_MIN, min(INT_MAX, add(sext(A), sext(B))))
// Where the min and max could be reversed
Instruction *MinMax2;
// 1. We are looking for a tree of signed saturation:
// smax(SINT_MIN, smin(SINT_MAX, add|sub(sext(A), sext(B))))
// Where the smin and smax could be reversed.
// 2. A tree of unsigned saturation:
// smax(UINT_MIN, sub(zext(A), zext(B)))
// Or umin(UINT_MAX, add(zext(A), zext(B))).
Instruction *MinMax2 = nullptr;
BinaryOperator *AddSub;
const APInt *MinValue, *MaxValue;
if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
const APInt *MinValue = nullptr, *MaxValue = nullptr;
bool IsUnsignedSaturate = false;
// Pattern match for unsigned saturation.
if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue)))) {
// Bail out if AddSub could be negative.
if (!isKnownNonNegative(AddSub, SQ.getWithInstruction(AddSub)))
Review comment (Member):
Is this check necessary? I don't see any llvm.assume in the alive2 proof.

Reply (Contributor Author):
This check was used to reject "umin(UINT_MAX, sub) -> usub_sat".
I pushed a new update that checks that the BinOp opcode is 'Add', so the !isKnownNonNegative check can be deleted.

return nullptr;
} else if (match(&MinMax1,
m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
IsUnsignedSaturate = true;
} else if (match(&MinMax1, m_SMax(m_BinOp(AddSub), m_APInt(MinValue)))) {
if (!MinValue->isZero())
return nullptr;
} else
return nullptr;
IsUnsignedSaturate = true;
} else {
// Pattern match for signed saturation.
if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
return nullptr;
} else if (match(&MinMax1,
m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
return nullptr;
} else
return nullptr;
}

// Check that the constants clamp a saturate, and that the new type would be
// sensible to convert to.
if (!(*MaxValue + 1).isPowerOf2() || -*MinValue != *MaxValue + 1)
if ((MaxValue && !(*MaxValue + 1).isPowerOf2()) ||
(!IsUnsignedSaturate && -*MinValue != *MaxValue + 1))
return nullptr;
// In what bitwidth can this be treated as saturating arithmetics?
unsigned NewBitWidth = (*MaxValue + 1).logBase2() + 1;

// Trying to decide the bitwidth for saturating arithmetics.
Value *Op0 = AddSub->getOperand(0);
Value *Op1 = AddSub->getOperand(1);
unsigned Op0MaxBitWidth =
IsUnsignedSaturate ? computeKnownBits(Op0, 0, AddSub).countMaxActiveBits()
: ComputeMaxSignificantBits(Op0, 0, AddSub);
unsigned Op1MaxBitWidth =
IsUnsignedSaturate ? computeKnownBits(Op1, 0, AddSub).countMaxActiveBits()
: ComputeMaxSignificantBits(Op1, 0, AddSub);
unsigned NewBitWidth = IsUnsignedSaturate
? std::max(Op0MaxBitWidth, Op1MaxBitWidth)
: (*MaxValue + 1).logBase2() + 1;
Review comment (Contributor):
The computeKnownBits logic should come after the following checks, which appear further down:

  if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth))
    return nullptr;

  // Also make sure that the inner min/max and the add/sub have one use.

Reply (Contributor Author):
The problem is that when folding "smax(UINT_MIN, sub(zext(A), zext(B))) -> usub_sat", MaxValue is not given, so I was using computeKnownBits to determine NewBitWidth.
I pushed a new update that first sets NewBitWidth to half of the bitwidth of MinMax1 when MaxValue is not given.
It then uses the results from computeKnownBits to try to reduce NewBitWidth further, and checks for legality.
Please let me know whether this approach is more sensible.
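
For readers following the bitwidth discussion, the constant checks in the diff can be modelled in plain C++. This is only a sketch of the arithmetic, not LLVM code: `activeBits`, `signedSatWidth`, and `unsignedSatWidth` are hypothetical helper names, and 64-bit integers stand in for `APInt`. A signed clamp `[MinValue, MaxValue]` corresponds to width `logBase2(MaxValue + 1) + 1`, while an unsigned clamp `umin(X, MaxValue)` corresponds to width `logBase2(MaxValue + 1)`.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Number of bits needed to represent V as an unsigned value (0 for V == 0).
// Hypothetical stand-in for an "active bits" query; not an LLVM API.
unsigned activeBits(uint64_t V) {
  unsigned N = 0;
  while (V != 0) {
    ++N;
    V >>= 1;
  }
  return N;
}

bool isPowerOf2(uint64_t V) { return V != 0 && (V & (V - 1)) == 0; }

// Signed clamp [Min, Max]: returns N when Max == 2^(N-1) - 1 and
// Min == -2^(N-1), otherwise 0. Mirrors
//   (*MaxValue + 1).isPowerOf2() && -*MinValue == *MaxValue + 1
// with NewBitWidth = logBase2(MaxValue + 1) + 1.
unsigned signedSatWidth(int64_t Min, int64_t Max) {
  if (Max < 0 || Max == std::numeric_limits<int64_t>::max())
    return 0;
  uint64_t MaxPlus1 = uint64_t(Max) + 1;
  if (!isPowerOf2(MaxPlus1))
    return 0;
  if (Min != -int64_t(MaxPlus1))
    return 0;
  return activeBits(MaxPlus1); // equals logBase2(MaxPlus1) + 1
}

// Unsigned clamp umin(X, Max): returns N when Max == 2^N - 1, otherwise 0.
// Here the width is logBase2(MaxValue + 1), one less than the signed case.
unsigned unsignedSatWidth(uint64_t Max) {
  if (Max == std::numeric_limits<uint64_t>::max())
    return 0;
  uint64_t MaxPlus1 = Max + 1;
  if (!isPowerOf2(MaxPlus1))
    return 0;
  return activeBits(MaxPlus1) - 1; // equals logBase2(MaxPlus1)
}
```

When no MaxValue is present, as in the smax(0, sub) pattern, there is no constant to derive the width from, which is why the thread turns to operand known bits instead.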


if (!IsUnsignedSaturate) {
// The two operands of the add/sub must be nsw-truncatable to type with
// NewBitWidth. This is usually achieved via a sext from a smaller type.
if (Op0MaxBitWidth > NewBitWidth || Op1MaxBitWidth > NewBitWidth)
return nullptr;
} else {
// Bail out if NewBitWidth is not smaller than the bitwidth of MinMax1.
if (NewBitWidth == Ty->getScalarType()->getIntegerBitWidth())
return nullptr;
// Bail out if MaxValue is not a valid unsigned saturating maximum value.
if (MaxValue && (*MaxValue + 1).logBase2() != NewBitWidth)
return nullptr;
}

// FIXME: This isn't quite right for vectors, but using the scalar type is a
// good first approximation for what should be done there.
if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth))
return nullptr;

// Also make sure that the inner min/max and the add/sub have one use.
if (!MinMax2->hasOneUse() || !AddSub->hasOneUse())
if ((MinMax2 && !MinMax2->hasOneUse()) || !AddSub->hasOneUse())
return nullptr;

// Create the new type (which can be a vector type)
Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth);

Intrinsic::ID IntrinsicID;
if (AddSub->getOpcode() == Instruction::Add)
IntrinsicID = Intrinsic::sadd_sat;
IntrinsicID =
IsUnsignedSaturate ? Intrinsic::uadd_sat : Intrinsic::sadd_sat;
Review comment (Contributor):
I don't see any logic that checks for the right binop for the given minmax pattern, i.e. umin(UINT_MAX, add) -> uadd_sat is okay, but umin(UINT_MAX, sub) -> usub_sat isn't.

Reply (Contributor Author, @huihzhang, Jul 24, 2024):
"umin(UINT_MAX, sub) -> usub_sat" was previously rejected by "!isKnownNonNegative".

-  if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue)))) {
-    // Bail out if AddSub could be negative.
-    if (!isKnownNonNegative(AddSub, SQ.getWithInstruction(AddSub)))
-      return nullptr;
+  if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue))) &&
+      AddSub->getOpcode() == Instruction::Add) {

In the new update, I check that the BinOp opcode equals 'Add', to make sure we don't accept the umin(UINT_MAX, sub) case.
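
To make the rejected case concrete, here is a small plain C++ model (not LLVM code) of why umin(UINT_MAX, sub) must not become usub_sat. It assumes i8 operands zero-extended to i32. `modelUSubSat8` and `uminOfSub` are hypothetical names used only for this illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Plain C++ model of an 8-bit unsigned saturating subtract.
uint8_t modelUSubSat8(uint8_t A, uint8_t B) {
  return A > B ? uint8_t(A - B) : uint8_t(0);
}

// umin(255, sub(zext(A), zext(B))) evaluated in 32 bits.
uint32_t uminOfSub(uint8_t A, uint8_t B) {
  // The subtraction wraps modulo 2^32 when A < B, giving a huge unsigned value.
  uint32_t Diff = uint32_t(A) - uint32_t(B);
  return std::min<uint32_t>(255u, Diff);
}
```

For A = 0 and B = 1, the wrapped difference is clamped to 255 by the umin, while a saturating subtract yields 0, so the two forms disagree whenever A < B.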

else if (AddSub->getOpcode() == Instruction::Sub)
IntrinsicID = Intrinsic::ssub_sat;
IntrinsicID =
IsUnsignedSaturate ? Intrinsic::usub_sat : Intrinsic::ssub_sat;
else
return nullptr;

// The two operands of the add/sub must be nsw-truncatable to the NewTy. This
// is usually achieved via a sext from a smaller type.
if (ComputeMaxSignificantBits(AddSub->getOperand(0), 0, AddSub) >
NewBitWidth ||
ComputeMaxSignificantBits(AddSub->getOperand(1), 0, AddSub) > NewBitWidth)
return nullptr;

// Finally create and return the sat intrinsic, truncated to the new type
Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy);
Value *AT = Builder.CreateTrunc(AddSub->getOperand(0), NewTy);
Value *BT = Builder.CreateTrunc(AddSub->getOperand(1), NewTy);
Value *Sat = Builder.CreateCall(F, {AT, BT});
return CastInst::Create(Instruction::SExt, Sat, Ty);
return CastInst::Create(
IsUnsignedSaturate ? Instruction::ZExt : Instruction::SExt, Sat, Ty);
}
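
As a sanity check on the patterns named in the function's comments, the following sketch exhaustively compares each clamp tree against the extended 8-bit saturating result. It is a plain C++ model under the assumption of i8 operands widened to i32. It does not call LLVM, and the helper names (`uaddSat8`, `usubSat8`, `saddSat8`, `ssubSat8`, `allFoldsHold`) are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Plain C++ models of the 8-bit saturating operations (illustrative only).
static uint8_t uaddSat8(uint8_t A, uint8_t B) {
  unsigned Sum = unsigned(A) + unsigned(B);
  return uint8_t(std::min(Sum, 255u));
}
static uint8_t usubSat8(uint8_t A, uint8_t B) {
  return A > B ? uint8_t(A - B) : uint8_t(0);
}
static int8_t saddSat8(int8_t A, int8_t B) {
  return int8_t(std::max(-128, std::min(127, int(A) + int(B))));
}
static int8_t ssubSat8(int8_t A, int8_t B) {
  return int8_t(std::max(-128, std::min(127, int(A) - int(B))));
}

// Returns true if every clamp tree, evaluated in 32 bits, equals the
// extended 8-bit saturating result for all pairs of 8-bit operands.
bool allFoldsHold() {
  for (int I = 0; I < 256; ++I) {
    for (int J = 0; J < 256; ++J) {
      // Unsigned patterns: operands are zero-extended from 8 to 32 bits.
      uint8_t UA = uint8_t(I), UB = uint8_t(J);
      uint32_t ZA = UA, ZB = UB;
      // umin(255, add(zext(A), zext(B))) vs zext(uadd_sat(A, B))
      if (std::min<uint32_t>(255u, ZA + ZB) != uint32_t(uaddSat8(UA, UB)))
        return false;
      // smax(0, sub(zext(A), zext(B))) vs zext(usub_sat(A, B))
      int32_t Diff = int32_t(ZA) - int32_t(ZB);
      if (std::max<int32_t>(0, Diff) != int32_t(usubSat8(UA, UB)))
        return false;

      // Signed patterns: operands are sign-extended from 8 to 32 bits.
      int8_t SA = int8_t(I - 128), SB = int8_t(J - 128);
      int32_t XA = SA, XB = SB;
      // smax(-128, smin(127, add(sext(A), sext(B)))) vs sext(sadd_sat(A, B))
      int32_t ClampAdd = std::max<int32_t>(-128, std::min<int32_t>(127, XA + XB));
      if (ClampAdd != int32_t(saddSat8(SA, SB)))
        return false;
      // smax(-128, smin(127, sub(sext(A), sext(B)))) vs sext(ssub_sat(A, B))
      int32_t ClampSub = std::max<int32_t>(-128, std::min<int32_t>(127, XA - XB));
      if (ClampSub != int32_t(ssubSat8(SA, SB)))
        return false;
    }
  }
  return true;
}
```

A brute-force check like this covers only one operand width and is no substitute for an alive2 proof. It is a quick way to see which binop belongs with which min/max.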


/// If we have a clamp pattern like max (min X, 42), 41 -- where the output
/// can only be one of two possible constant values -- turn that into a select
/// of constants.
@@ -1878,8 +1918,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (Instruction *Sel = foldClampRangeOfTwo(II, Builder))
return Sel;

if (Instruction *SAdd = matchSAddSubSat(*II))
return SAdd;
if (Instruction *AddSubSat = matchAddSubSat(*II))
return AddSubSat;

if (Value *NewMinMax = reassociateMinMaxWithConstants(II, Builder, SQ))
return replaceInstUsesWith(*II, NewMinMax);
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -392,7 +392,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *narrowMathIfNoOverflow(BinaryOperator &I);
Instruction *narrowFunnelShift(TruncInst &Trunc);
Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
Instruction *matchSAddSubSat(IntrinsicInst &MinMax1);
Instruction *matchAddSubSat(IntrinsicInst &MinMax1);
Instruction *foldNot(BinaryOperator &I);
Instruction *foldBinOpOfDisplacedShifts(BinaryOperator &I);
