[InstCombine] Pattern match minmax calls for unsigned saturation. #99250

Open · wants to merge 5 commits into base: main
Conversation

@huihzhang (Contributor) commented Jul 16, 2024

Extend the facility from D68651 to match the following patterns for unsigned saturation:

  1. fold smax(0, sub(zext(A), zext(B))) into usub_sat;
  2. fold umin(UINT_MAX, add(zext(A), zext(B))) into uadd_sat.

Proofs:
uadd_sat: https://alive2.llvm.org/ce/z/v-LJZr
usub_sat: https://alive2.llvm.org/ce/z/Veoypa

This patch matches the following patterns for unsigned saturation:
1) fold smax(UINT_MIN, smin(UINT_MAX, sub(zext(A), zext(B)))) into usub_sat,
   where smin and smax could be reversed.
2) fold umin(UINT_MAX, add(zext(A), zext(B))) into uadd_sat.

Note that this patch extends the signed saturation (sadd|ssub_sat) pattern
matching from D68651.
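
For illustration, pattern 2) corresponds to IR like the following (a minimal sketch mirroring the uadd_sat32_min test added by this patch):

  %conv = zext i32 %a to i64
  %conv1 = zext i32 %b to i64
  %add = add i64 %conv1, %conv
  %min = call i64 @llvm.umin.i64(i64 %add, i64 4294967295)
  %conv2 = trunc i64 %min to i32

which this patch folds to a single saturating intrinsic call:

  %conv2 = call i32 @llvm.uadd.sat.i32(i32 %b, i32 %a)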
@llvmbot (Collaborator) commented Jul 16, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Huihui Zhang (huihzhang)

Changes

Extend the facility from D68651 to match the following patterns for unsigned saturation:

  1. fold smax(UINT_MIN, smin(UINT_MAX, sub(zext(A), zext(B)))) into usub_sat,
    where smin and smax could be reversed.
  2. fold umin(UINT_MAX, add(zext(A), zext(B))) into uadd_sat.

Patch is 24.63 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99250.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+39-18)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+1-1)
  • (added) llvm/test/Transforms/InstCombine/uaddsub_sat.ll (+499)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 467b291f9a4c3..bbb2f994e1aea 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1117,16 +1117,22 @@ static Instruction *moveAddAfterMinMax(IntrinsicInst *II,
   return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1))
                   : BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1));
 }
-/// Match a sadd_sat or ssub_sat which is using min/max to clamp the value.
-Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
+/// Match a [s|u]add_sat or [s|u]sub_sat which is using min/max to clamp the
+/// value.
+Instruction *InstCombinerImpl::matchAddSubSat(IntrinsicInst &MinMax1) {
   Type *Ty = MinMax1.getType();
 
-  // We are looking for a tree of:
-  // max(INT_MIN, min(INT_MAX, add(sext(A), sext(B))))
-  // Where the min and max could be reversed
-  Instruction *MinMax2;
+  // 1. We are looking for a tree of signed saturation:
+  //    smax(SINT_MIN, smin(SINT_MAX, add|sub(sext(A), sext(B))))
+  //    Where the smin and smax could be reversed.
+  // 2. A tree of unsigned saturation:
+  //    smax(UINT_MIN, smin(UINT_MAX, sub(zext(A), zext(B))))
+  //    Where the smin and smax could be reversed.
+  //    Or umin(UINT_MAX, add(zext(A), zext(B)))
+  Instruction *MinMax2 = nullptr;
   BinaryOperator *AddSub;
   const APInt *MinValue, *MaxValue;
+  bool IsUnsignedSaturate = false;
   if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
     if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
       return nullptr;
@@ -1134,22 +1140,29 @@ Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
                    m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
     if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
       return nullptr;
+  } else if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue)))) {
+    IsUnsignedSaturate = true;
   } else
     return nullptr;
 
+  if (!IsUnsignedSaturate && MinValue && MinValue->isZero())
+    IsUnsignedSaturate = true;
+
   // Check that the constants clamp a saturate, and that the new type would be
   // sensible to convert to.
-  if (!(*MaxValue + 1).isPowerOf2() || -*MinValue != *MaxValue + 1)
+  if (!(*MaxValue + 1).isPowerOf2() ||
+      (!IsUnsignedSaturate && -*MinValue != *MaxValue + 1))
     return nullptr;
   // In what bitwidth can this be treated as saturating arithmetics?
-  unsigned NewBitWidth = (*MaxValue + 1).logBase2() + 1;
+  unsigned NewBitWidth =
+      (*MaxValue + 1).logBase2() + (IsUnsignedSaturate ? 0 : 1);
   // FIXME: This isn't quite right for vectors, but using the scalar type is a
   // good first approximation for what should be done there.
   if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth))
     return nullptr;
 
   // Also make sure that the inner min/max and the add/sub have one use.
-  if (!MinMax2->hasOneUse() || !AddSub->hasOneUse())
+  if ((MinMax2 && !MinMax2->hasOneUse()) || !AddSub->hasOneUse())
     return nullptr;
 
   // Create the new type (which can be a vector type)
@@ -1157,17 +1170,25 @@ Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
 
   Intrinsic::ID IntrinsicID;
   if (AddSub->getOpcode() == Instruction::Add)
-    IntrinsicID = Intrinsic::sadd_sat;
+    IntrinsicID =
+        IsUnsignedSaturate ? Intrinsic::uadd_sat : Intrinsic::sadd_sat;
   else if (AddSub->getOpcode() == Instruction::Sub)
-    IntrinsicID = Intrinsic::ssub_sat;
+    IntrinsicID =
+        IsUnsignedSaturate ? Intrinsic::usub_sat : Intrinsic::ssub_sat;
   else
     return nullptr;
 
   // The two operands of the add/sub must be nsw-truncatable to the NewTy. This
   // is usually achieved via a sext from a smaller type.
-  if (ComputeMaxSignificantBits(AddSub->getOperand(0), 0, AddSub) >
-          NewBitWidth ||
-      ComputeMaxSignificantBits(AddSub->getOperand(1), 0, AddSub) > NewBitWidth)
+  Value *Op0 = AddSub->getOperand(0);
+  Value *Op1 = AddSub->getOperand(1);
+  unsigned Op0MaxBitWidth =
+      IsUnsignedSaturate ? computeKnownBits(Op0, 0, AddSub).countMaxActiveBits()
+                         : ComputeMaxSignificantBits(Op0, 0, AddSub);
+  unsigned Op1MaxBitWidth =
+      IsUnsignedSaturate ? computeKnownBits(Op1, 0, AddSub).countMaxActiveBits()
+                         : ComputeMaxSignificantBits(Op1, 0, AddSub);
+  if (Op0MaxBitWidth > NewBitWidth || Op1MaxBitWidth > NewBitWidth)
     return nullptr;
 
   // Finally create and return the sat intrinsic, truncated to the new type
@@ -1175,10 +1196,10 @@ Instruction *InstCombinerImpl::matchSAddSubSat(IntrinsicInst &MinMax1) {
   Value *AT = Builder.CreateTrunc(AddSub->getOperand(0), NewTy);
   Value *BT = Builder.CreateTrunc(AddSub->getOperand(1), NewTy);
   Value *Sat = Builder.CreateCall(F, {AT, BT});
-  return CastInst::Create(Instruction::SExt, Sat, Ty);
+  return CastInst::Create(
+      IsUnsignedSaturate ? Instruction::ZExt : Instruction::SExt, Sat, Ty);
 }
 
-
 /// If we have a clamp pattern like max (min X, 42), 41 -- where the output
 /// can only be one of two possible constant values -- turn that into a select
 /// of constants.
@@ -1878,8 +1899,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (Instruction *Sel = foldClampRangeOfTwo(II, Builder))
       return Sel;
 
-    if (Instruction *SAdd = matchSAddSubSat(*II))
-      return SAdd;
+    if (Instruction *AddSubSat = matchAddSubSat(*II))
+      return AddSubSat;
 
     if (Value *NewMinMax = reassociateMinMaxWithConstants(II, Builder, SQ))
       return replaceInstUsesWith(*II, NewMinMax);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 64fbcc80e0edf..b76d71da230a7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -392,7 +392,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *narrowMathIfNoOverflow(BinaryOperator &I);
   Instruction *narrowFunnelShift(TruncInst &Trunc);
   Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
-  Instruction *matchSAddSubSat(IntrinsicInst &MinMax1);
+  Instruction *matchAddSubSat(IntrinsicInst &MinMax1);
   Instruction *foldNot(BinaryOperator &I);
   Instruction *foldBinOpOfDisplacedShifts(BinaryOperator &I);
 
diff --git a/llvm/test/Transforms/InstCombine/uaddsub_sat.ll b/llvm/test/Transforms/InstCombine/uaddsub_sat.ll
new file mode 100644
index 0000000000000..469a651ef9a93
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/uaddsub_sat.ll
@@ -0,0 +1,499 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+
+define i32 @uadd_sat32(i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @uadd_sat32(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[B]], i32 [[A]])
+; CHECK-NEXT:    ret i32 [[TMP0]]
+;
+entry:
+  %conv = zext i32 %a to i64
+  %conv1 = zext i32 %b to i64
+  %add = add i64 %conv1, %conv
+  %0 = icmp ult i64 %add, 4294967295
+  %select = select i1 %0, i64 %add, i64 4294967295
+  %conv2 = trunc i64 %select to i32
+  ret i32 %conv2
+}
+
+define i32 @uadd_sat32_min(i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @uadd_sat32_min(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[B]], i32 [[A]])
+; CHECK-NEXT:    ret i32 [[TMP0]]
+;
+entry:
+  %conv = zext i32 %a to i64
+  %conv1 = zext i32 %b to i64
+  %add = add i64 %conv1, %conv
+  %min = call i64 @llvm.umin.i64(i64 %add, i64 4294967295)
+  %conv2 = trunc i64 %min to i32
+  ret i32 %conv2
+}
+
+define i32 @usub_sat32(i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @usub_sat32(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT:    ret i32 [[TMP0]]
+;
+entry:
+  %conv = zext i32 %a to i64
+  %conv1 = zext i32 %b to i64
+  %sub = sub i64 %conv, %conv1
+  %cmp4 = icmp sgt i64 %sub, 0
+  %cmp6 = icmp slt i64 %sub, 4294967295
+  %cond = select i1 %cmp6, i64 %sub, i64 4294967295
+  %cond11 = select i1 %cmp4, i64 %cond, i64 0
+  %conv12 = trunc i64 %cond11 to i32
+  ret i32 %conv12
+}
+
+define i32 @usub_sat32_minmax(i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @usub_sat32_minmax(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT:    ret i32 [[TMP0]]
+;
+entry:
+  %conv = zext i32 %a to i64
+  %conv1 = zext i32 %b to i64
+  %sub = sub i64 %conv, %conv1
+  %cond = call i64 @llvm.smin.i64(i64 %sub, i64 4294967295)
+  %cond11 = call i64 @llvm.smax.i64(i64 %cond, i64 0)
+  %conv12 = trunc i64 %cond11 to i32
+  ret i32 %conv12
+}
+
+define i16 @uadd_sat16(i16 %a, i16 %b) {
+; CHECK-LABEL: define i16 @uadd_sat16(
+; CHECK-SAME: i16 [[A:%.*]], i16 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i16 @llvm.uadd.sat.i16(i16 [[B]], i16 [[A]])
+; CHECK-NEXT:    ret i16 [[TMP0]]
+;
+entry:
+  %conv = zext i16 %a to i32
+  %conv1 = zext i16 %b to i32
+  %add = add i32 %conv1, %conv
+  %0 = icmp ult i32 %add, 65535
+  %select = select i1 %0, i32 %add, i32 65535
+  %conv2 = trunc i32 %select to i16
+  ret i16 %conv2
+}
+
+define i16 @uadd_sat16_min(i16 %a, i16 %b) {
+; CHECK-LABEL: define i16 @uadd_sat16_min(
+; CHECK-SAME: i16 [[A:%.*]], i16 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i16 @llvm.uadd.sat.i16(i16 [[B]], i16 [[A]])
+; CHECK-NEXT:    ret i16 [[TMP0]]
+;
+entry:
+  %conv = zext i16 %a to i32
+  %conv1 = zext i16 %b to i32
+  %add = add i32 %conv1, %conv
+  %min = call i32 @llvm.umin.i32(i32 %add, i32 65535)
+  %conv2 = trunc i32 %min to i16
+  ret i16 %conv2
+}
+
+define i16 @usub_sat16(i16 %a, i16 %b) {
+; CHECK-LABEL: define i16 @usub_sat16(
+; CHECK-SAME: i16 [[A:%.*]], i16 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[A]], i16 [[B]])
+; CHECK-NEXT:    ret i16 [[TMP0]]
+;
+entry:
+  %conv = zext i16 %a to i32
+  %conv1 = zext i16 %b to i32
+  %sub = sub i32 %conv, %conv1
+  %cmp4 = icmp sgt i32 %sub, 0
+  %cmp6 = icmp slt i32 %sub, 65535
+  %cond = select i1 %cmp6, i32 %sub, i32 65535
+  %cond11 = select i1 %cmp4, i32 %cond, i32 0
+  %conv12 = trunc i32 %cond11 to i16
+  ret i16 %conv12
+}
+
+define i16 @usub_sat16_minmax(i16 %a, i16 %b) {
+; CHECK-LABEL: define i16 @usub_sat16_minmax(
+; CHECK-SAME: i16 [[A:%.*]], i16 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i16 @llvm.usub.sat.i16(i16 [[A]], i16 [[B]])
+; CHECK-NEXT:    ret i16 [[TMP0]]
+;
+entry:
+  %conv = zext i16 %a to i32
+  %conv1 = zext i16 %b to i32
+  %sub = sub i32 %conv, %conv1
+  %cond = call i32 @llvm.smin.i32(i32 %sub, i32 65535)
+  %cond11 = call i32 @llvm.smax.i32(i32 %cond, i32 0)
+  %conv12 = trunc i32 %cond11 to i16
+  ret i16 %conv12
+}
+
+define i8 @uadd_sat8(i8 %a, i8 %b) {
+; CHECK-LABEL: define i8 @uadd_sat8(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i8 @llvm.uadd.sat.i8(i8 [[B]], i8 [[A]])
+; CHECK-NEXT:    ret i8 [[TMP0]]
+;
+entry:
+  %conv = zext i8 %a to i32
+  %conv1 = zext i8 %b to i32
+  %add = add i32 %conv1, %conv
+  %0 = icmp ult i32 %add, 255
+  %select = select i1 %0, i32 %add, i32 255
+  %conv2 = trunc i32 %select to i8
+  ret i8 %conv2
+}
+
+define i8 @uadd_sat8_min(i8 %a, i8 %b) {
+; CHECK-LABEL: define i8 @uadd_sat8_min(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i8 @llvm.uadd.sat.i8(i8 [[B]], i8 [[A]])
+; CHECK-NEXT:    ret i8 [[TMP0]]
+;
+entry:
+  %conv = zext i8 %a to i32
+  %conv1 = zext i8 %b to i32
+  %add = add i32 %conv1, %conv
+  %min = call i32 @llvm.umin.i32(i32 %add, i32 255)
+  %conv2 = trunc i32 %min to i8
+  ret i8 %conv2
+}
+
+define i8 @usub_sat8(i8 %a, i8 %b) {
+; CHECK-LABEL: define i8 @usub_sat8(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[A]], i8 [[B]])
+; CHECK-NEXT:    ret i8 [[TMP0]]
+;
+entry:
+  %conv = zext i8 %a to i32
+  %conv1 = zext i8 %b to i32
+  %sub = sub i32 %conv, %conv1
+  %cmp4 = icmp sgt i32 %sub, 0
+  %cmp6 = icmp slt i32 %sub, 255
+  %cond = select i1 %cmp6, i32 %sub, i32 255
+  %cond11 = select i1 %cmp4, i32 %cond, i32 0
+  %conv12 = trunc i32 %cond11 to i8
+  ret i8 %conv12
+}
+
+define i8 @usub_sat8_minmax(i8 %a, i8 %b) {
+; CHECK-LABEL: define i8 @usub_sat8_minmax(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i8 @llvm.usub.sat.i8(i8 [[A]], i8 [[B]])
+; CHECK-NEXT:    ret i8 [[TMP0]]
+;
+entry:
+  %conv = zext i8 %a to i32
+  %conv1 = zext i8 %b to i32
+  %sub = sub i32 %conv, %conv1
+  %cond = call i32 @llvm.smin.i32(i32 %sub, i32 255)
+  %cond11 = call i32 @llvm.smax.i32(i32 %cond, i32 0)
+  %conv12 = trunc i32 %cond11 to i8
+  ret i8 %conv12
+}
+
+define i64 @uadd_sat64(i64 %a, i64 %b) {
+; CHECK-LABEL: define i64 @uadd_sat64(
+; CHECK-SAME: i64 [[A:%.*]], i64 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.uadd.sat.i64(i64 [[B]], i64 [[A]])
+; CHECK-NEXT:    ret i64 [[TMP0]]
+;
+entry:
+  %conv = zext i64 %a to i65
+  %conv1 = zext i64 %b to i65
+  %add = add i65 %conv1, %conv
+  %0 = icmp ult i65 %add, 18446744073709551615
+  %select = select i1 %0, i65 %add, i65 18446744073709551615
+  %conv2 = trunc i65 %select to i64
+  ret i64 %conv2
+}
+
+define i64 @usub_sat64(i64 %a, i64 %b) {
+; CHECK-LABEL: define i64 @usub_sat64(
+; CHECK-SAME: i64 [[A:%.*]], i64 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i64 @llvm.usub.sat.i64(i64 [[A]], i64 [[B]])
+; CHECK-NEXT:    ret i64 [[TMP0]]
+;
+entry:
+  %conv = zext i64 %a to i128
+  %conv1 = zext i64 %b to i128
+  %sub = sub i128 %conv, %conv1
+  %cmp4 = icmp sgt i128 %sub, 0
+  %cmp6 = icmp slt i128 %sub, 18446744073709551615
+  %cond = select i1 %cmp6, i128 %sub, i128 18446744073709551615
+  %cond11 = select i1 %cmp4, i128 %cond, i128 0
+  %conv12 = trunc i128 %cond11 to i64
+  ret i64 %conv12
+}
+
+define <4 x i32> @uadd_satv4i32(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: define <4 x i32> @uadd_satv4i32(
+; CHECK-SAME: <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call <4 x i32> @llvm.uadd.sat.v4i32(<4 x i32> [[B]], <4 x i32> [[A]])
+; CHECK-NEXT:    ret <4 x i32> [[TMP0]]
+;
+entry:
+  %conv = zext <4 x i32> %a to <4 x i64>
+  %conv1 = zext <4 x i32> %b to <4 x i64>
+  %add = add <4 x i64> %conv1, %conv
+  %0 = icmp ult <4 x i64> %add, <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>
+  %select = select <4 x i1> %0, <4 x i64> %add, <4 x i64> <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>
+  %conv7 = trunc <4 x i64> %select to <4 x i32>
+  ret <4 x i32> %conv7
+}
+
+define <8 x i16> @uadd_satv8i16_minmax(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: define <8 x i16> @uadd_satv8i16_minmax(
+; CHECK-SAME: <8 x i16> [[A:%.*]], <8 x i16> [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call <8 x i16> @llvm.uadd.sat.v8i16(<8 x i16> [[B]], <8 x i16> [[A]])
+; CHECK-NEXT:    ret <8 x i16> [[TMP0]]
+;
+entry:
+  %conv = zext <8 x i16> %a to <8 x i32>
+  %conv1 = zext <8 x i16> %b to <8 x i32>
+  %add = add <8 x i32> %conv1, %conv
+  %select = call <8 x i32> @llvm.umin.v8i32(<8 x i32> %add, <8 x i32> <i32 65535, i32 65535, i32 65535, i32 65535, i32 65535, i32 65535, i32 65535, i32 65535>)
+  %conv7 = trunc <8 x i32> %select to <8 x i16>
+  ret <8 x i16> %conv7
+}
+
+define <16 x i8> @usub_satv16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: define <16 x i8> @usub_satv16i8(
+; CHECK-SAME: <16 x i8> [[A:%.*]], <16 x i8> [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call <16 x i8> @llvm.usub.sat.v16i8(<16 x i8> [[B]], <16 x i8> [[A]])
+; CHECK-NEXT:    ret <16 x i8> [[TMP0]]
+;
+entry:
+  %conv = zext <16 x i8> %a to <16 x i32>
+  %conv1 = zext <16 x i8> %b to <16 x i32>
+  %sub = sub <16 x i32> %conv1, %conv
+  %0 = icmp slt <16 x i32> %sub, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %select = select <16 x i1> %0, <16 x i32> %sub, <16 x i32> <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
+  %1 = icmp sgt <16 x i32> %select, <i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0>
+  %select8 = select <16 x i1> %1, <16 x i32> %select, <16 x i32> <i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0>
+  %conv7 = trunc <16 x i32> %select8 to <16 x i8>
+  ret <16 x i8> %conv7
+}
+
+define <2 x i64> @usub_satv2i64_minmax(<2 x i64> %a, <2 x i64> %b) {
+; CHECK-LABEL: define <2 x i64> @usub_satv2i64_minmax(
+; CHECK-SAME: <2 x i64> [[A:%.*]], <2 x i64> [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call <2 x i64> @llvm.usub.sat.v2i64(<2 x i64> [[B]], <2 x i64> [[A]])
+; CHECK-NEXT:    ret <2 x i64> [[TMP0]]
+;
+entry:
+  %conv = zext <2 x i64> %a to <2 x i128>
+  %conv1 = zext <2 x i64> %b to <2 x i128>
+  %sub = sub <2 x i128> %conv1, %conv
+  %select = call <2 x i128> @llvm.smin.v2i128(<2 x i128> %sub, <2 x i128> <i128 18446744073709551615, i128 18446744073709551615>)
+  %select8 = call <2 x i128> @llvm.smax.v2i128(<2 x i128> %select, <2 x i128> <i128 0, i128 0>)
+  %conv7 = trunc <2 x i128> %select8 to <2 x i64>
+  ret <2 x i64> %conv7
+}
+
+define i32 @uadd_sat32_extra_use_1(i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @uadd_sat32_extra_use_1(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[B]], i32 [[A]])
+; CHECK-NEXT:    [[SELECT:%.*]] = zext i32 [[TMP0]] to i64
+; CHECK-NEXT:    call void @use64(i64 [[SELECT]])
+; CHECK-NEXT:    ret i32 [[TMP0]]
+;
+entry:
+  %conv = zext i32 %a to i64
+  %conv1 = zext i32 %b to i64
+  %add = add i64 %conv1, %conv
+  %0 = icmp ult i64 %add, 4294967295
+  %select = select i1 %0, i64 %add, i64 4294967295
+  %conv7 = trunc i64 %select to i32
+  call void @use64(i64 %select)
+  ret i32 %conv7
+}
+
+define i32 @uadd_sat32_extra_use_2(i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @uadd_sat32_extra_use_2(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CONV:%.*]] = zext i32 [[A]] to i64
+; CHECK-NEXT:    [[CONV1:%.*]] = zext i32 [[B]] to i64
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[CONV1]], [[CONV]]
+; CHECK-NEXT:    [[SELECT:%.*]] = call i64 @llvm.umin.i64(i64 [[ADD]], i64 4294967295)
+; CHECK-NEXT:    [[CONV7:%.*]] = trunc nuw i64 [[SELECT]] to i32
+; CHECK-NEXT:    call void @use64(i64 [[ADD]])
+; CHECK-NEXT:    ret i32 [[CONV7]]
+;
+entry:
+  %conv = zext i32 %a to i64
+  %conv1 = zext i32 %b to i64
+  %add = add i64 %conv1, %conv
+  %0 = icmp ult i64 %add, 4294967295
+  %select = select i1 %0, i64 %add, i64 4294967295
+  %conv7 = trunc i64 %select to i32
+  call void @use64(i64 %add)
+  ret i32 %conv7
+}
+
+define i32 @usub_sat32_extra_use_3(i32 %a, i32 %b) {
+; CHECK-LABEL: define i32 @usub_sat32_extra_use_3(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CONV:%.*]] = zext i32 [[A]] to i64
+; CHECK-NEXT:    [[CONV1:%.*]] = zext i32 [[B]] to i64
+; CHECK-NEXT:    [[SUB:%.*]] = sub nsw i64 [[CONV]], [[CONV1]]
+; CHECK-NEXT:    [[COND:%.*]] = call i64 @llvm.smin.i64(i64 [[SUB]], i64 4294967295)
+; CHECK-NEXT:    [[COND11:%.*]] = call i64 @llvm.smax.i64(i64 [[COND]], i64 0)
+; CHECK-NEX...
[truncated]

@goldsteinn (Contributor):

Hi,

can you please add alive2 proofs? See: https://llvm.org/docs/InstCombineContributorGuide.html#proofs for more details.

Commit: Bail out if BinOp is not known non-negative.
@huihzhang (Contributor, Author):

Thanks @goldsteinn!
I caught a bug while doing alive2 testing; the fix is in commit #3.

When trying to fold "umin(UINT_MAX, BinOp(zext(A), zext(B)))" into uadd_sat, we need to make sure BinOp is known non-negative.

Please also see alive2 verification below:
uadd_sat: https://alive2.llvm.org/ce/z/v-LJZr
usub_sat: https://alive2.llvm.org/ce/z/h34acm
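
To see why the check matters (a worked example with illustrative values, not from the thread): take i8 operands A = 0 and B = 1.

  ; sub(zext(0), zext(1)) = -1, i.e. 0xFFFF when read as unsigned i16
  ; umin(0xFFFF, 255)     = 255, and trunc to i8 yields 255
  ; usub_sat i8 0, 1      = 0
  ; 255 != 0, so folding umin(UINT_MAX, sub(...)) to usub_sat is unsound

Hence the match must either prove the BinOp non-negative or restrict the umin pattern to add.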

@goldsteinn (Contributor):

The smin(UINT_MAX, sub(zext(A), zext(B))) is dead code: https://alive2.llvm.org/ce/z/AZ-puW
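
The range argument behind this observation (my gloss, assuming N-bit A and B zext'ed to 2N bits): the difference of two zexts always satisfies

  -(2^N - 1) <= zext(A) - zext(B) <= 2^N - 1

and the narrow type's UINT_MAX is exactly 2^N - 1, so smin against UINT_MAX never clamps anything.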

@huihzhang (Contributor, Author):

Thanks @dtcxzyw @goldsteinn for the feedback!
I will post an update that simplifies to the following pattern-match rules:

  1. fold smax(UINT_MIN, sub(zext(A), zext(B))) into usub_sat,
  2. fold umin(UINT_MAX, add(zext(A), zext(B))) into uadd_sat.

@dtcxzyw (Member) commented Jul 19, 2024

@huihzhang Please update the alive2 link for smax(UINT_MIN, sub(zext(A), zext(B))) -> usub_sat in the PR description.

Commit:
1. fold smax(UINT_MIN, sub(zext(A), zext(B))) into usub_sat;
2. fold umin(UINT_MAX, add(zext(A), zext(B))) into uadd_sat.
@huihzhang (Contributor, Author) commented Jul 22, 2024

@dtcxzyw @goldsteinn, I have pushed an update to use the simplified pattern-match rules for unsigned saturation.
Please take a look, thanks a lot!

The alive2 link for usub_sat is updated as well; I edited it into the very first comment.

  // Pattern match for unsigned saturation.
  if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue)))) {
    // Bail out if AddSub could be negative.
    if (!isKnownNonNegative(AddSub, SQ.getWithInstruction(AddSub)))
@dtcxzyw (Member):

Is this check necessary? I don't see any llvm.assume in the alive2 proof.

@huihzhang (Contributor, Author):

This check was used to reject "umin(UINT_MAX, sub) -> usub_sat".
I pushed a new update that checks for BinOp opcode 'Add', so the !isKnownNonNegative check can be deleted.

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (outdated review thread, resolved)
llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (outdated review thread, resolved)
                         : ComputeMaxSignificantBits(Op1, 0, AddSub);
  unsigned NewBitWidth = IsUnsignedSaturate
                             ? std::max(Op0MaxBitWidth, Op1MaxBitWidth)
                             : (*MaxValue + 1).logBase2() + 1;
@goldsteinn (Contributor):

The compute known bits stuff should be after the:

  if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth))
    return nullptr;

  // Also make sure that the inner min/max and the add/sub have one use.

Checks below.

@huihzhang (Contributor, Author):

The problem is that when trying to fold "smax(UINT_MIN, sub(zext(A), zext(B))) -> usub_sat", the MaxValue is not given; I was trying to use computeKnownBits to determine NewBitWidth.
I pushed a new update that first sets NewBitWidth to half the bit width of MinMax1 when MaxValue is not given, then uses the results from computeKnownBits to try to reduce NewBitWidth further, and checks for legality. A sketch of this approach follows.
Please let me know if this approach is more sensible.
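
(For reference, a rough sketch of that approach; variable names follow the patch, but this is illustrative and not the exact committed code:)

  unsigned NewBitWidth;
  if (MaxValue) {
    // A clamp constant is present: it determines the saturation width.
    NewBitWidth = (*MaxValue + 1).logBase2() + (IsUnsignedSaturate ? 0 : 1);
  } else {
    // smax(0, sub(zext(A), zext(B))): no upper clamp constant, so start
    // from half the width of the min/max result (the zext source type is
    // at most half as wide), then shrink using the operands' known bits
    // as computed above via computeKnownBits.
    NewBitWidth = Ty->getScalarSizeInBits() / 2;
    NewBitWidth =
        std::min(NewBitWidth, std::max(Op0MaxBitWidth, Op1MaxBitWidth));
  }
  if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(),
                        NewBitWidth))
    return nullptr;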

     return nullptr;

   // Create the new type (which can be a vector type)
   Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth);

   Intrinsic::ID IntrinsicID;
   if (AddSub->getOpcode() == Instruction::Add)
-    IntrinsicID = Intrinsic::sadd_sat;
+    IntrinsicID =
+        IsUnsignedSaturate ? Intrinsic::uadd_sat : Intrinsic::sadd_sat;
@goldsteinn (Contributor):

I don't see any logic that's checking the right binop for the given min/max pattern. I.e., umin(UINT_MAX, add) -> uadd_sat is okay, but umin(UINT_MAX, sub) -> usub_sat isn't.

@huihzhang (Contributor, Author) commented Jul 24, 2024:

"umin(UINT_MAX, sub) -> usub_sat" was previously rejected by "!isKnownNonNegative".

-  if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue)))) {
-    // Bail out if AddSub could be negative.
-    if (!isKnownNonNegative(AddSub, SQ.getWithInstruction(AddSub)))
-      return nullptr;
+  if (match(&MinMax1, m_UMin(m_BinOp(AddSub), m_APInt(MaxValue))) &&
+      AddSub->getOpcode() == Instruction::Add) {

In the new update, I check that the BinOp opcode equals 'Add', to make sure we don't accept the umin(UINT_MAX, sub) case.

@huihzhang (Contributor, Author):

Thanks a lot @dtcxzyw @goldsteinn for the review feedback!
I pushed an update to address the review comments.
My main concern is that when folding "smax(UINT_MIN, sub(zext(A), zext(B)))" into usub_sat, the MaxValue is not given.
I am now using a different approach: first estimate NewBitWidth, then use the results from computeKnownBits to try to reduce NewBitWidth further, and check for legality.
Please take a look and let me know if this approach is more sensible.

@dtcxzyw dtcxzyw requested a review from goldsteinn July 24, 2024 08:49