From 013e4e49270a020e4d02962d37b4e407dd5d52dd Mon Sep 17 00:00:00 2001 From: Rose Date: Thu, 16 May 2024 13:54:17 -0400 Subject: [PATCH] [InstCombine] Reduce multiplicands of even numbers when a shift is involved We can improve analysis and codegen, and enable other folds, if we take expressions like (x * 6) >> 2 and replace them with (x * 3) >> 1 (assuming the multiplication does not overflow, of course). Because every right shift by one bit is a division by 2, we can replace a multiplication by an even constant with a multiplication by half that constant followed by a shift of one bit less, and keep going until the shift amount reaches 0 or the multiplicand becomes odd. Alive2 Proofs: https://alive2.llvm.org/ce/z/C9FvwB https://alive2.llvm.org/ce/z/7Zsx3b --- .../InstCombine/InstCombineShifts.cpp | 53 +++++++++++++++---- llvm/test/Transforms/InstCombine/ashr-lshr.ll | 8 +-- llvm/test/Transforms/InstCombine/lshr.ll | 12 ++--- 3 files changed, 53 insertions(+), 20 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 4a014ab6e044e5..a19c44026d716d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -1504,16 +1504,21 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { // able to invert the transform and perf may suffer with an extra mul // instruction. 
if (Op0->hasOneUse()) { - APInt NewMulC = MulC->lshr(ShAmtC); - // if c is divisible by (1 << ShAmtC): - // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC) - if (MulC->eq(NewMulC.shl(ShAmtC))) { - auto *NewMul = - BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC)); - assert(ShAmtC != 0 && - "lshr X, 0 should be handled by simplifyLShrInst."); - NewMul->setHasNoSignedWrap(true); - return NewMul; + assert(ShAmtC != 0 && + "lshr X, 0 should be handled by simplifyLShrInst."); + unsigned CommonZeros = std::min(MulC->countr_zero(), ShAmtC); + if (CommonZeros != 0) { + APInt NewMulC = MulC->lshr(CommonZeros); + unsigned NewShAmtC = ShAmtC - CommonZeros; + + // We can reduce expressions such as lshr (mul nuw x, 6), 2 -> + // lshr (mul nuw nsw x, 3), 1 + auto *NewMul = Builder.CreateMul(X, ConstantInt::get(Ty, NewMulC), "", + /*NUW=*/true, /*NSW=*/true); + auto *NewLshr = BinaryOperator::CreateLShr( + NewMul, ConstantInt::get(Ty, NewShAmtC)); + NewLshr->copyIRFlags(&I); // We can preserve 'exact'-ness. 
+ return NewLshr; } } } @@ -1712,6 +1717,34 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum)); } + if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(ShOp1))))) { + unsigned CommonZeros = std::min(ShOp1->countr_zero(), ShAmt); + if (CommonZeros != 0) { + APInt NewMulC = ShOp1->ashr(CommonZeros); + unsigned NewShAmtC = ShAmt - CommonZeros; + // if c is divisible by (1 << ShAmtC): + // ashr (mul nsw x, MulC), ShAmtC -> mul nsw x, (MulC >> ShAmtC) + if (NewShAmtC == 0) { + auto *NewMul = + BinaryOperator::CreateNSWMul(X, ConstantInt::get(Ty, NewMulC)); + NewMul->setHasNoUnsignedWrap( + cast(Op0)->hasNoUnsignedWrap()); + return NewMul; + } + + // We can reduce expressions such as ashr (mul nsw x, 6), 2 -> ashr (mul + // nsw x, 3), 1 + auto *NewMul = Builder.CreateMul( + X, ConstantInt::get(Ty, NewMulC), "", + /*NUW*/ cast(Op0)->hasNoUnsignedWrap(), + /*NSW*/ true); + auto *NewAshr = + BinaryOperator::CreateAShr(NewMul, ConstantInt::get(Ty, NewShAmtC)); + NewAshr->copyIRFlags(&I); // We can preserve 'exact'-ness. 
+ return NewAshr; + } + } + if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) && (Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) { // ashr (sext X), C --> sext (ashr X, C') diff --git a/llvm/test/Transforms/InstCombine/ashr-lshr.ll b/llvm/test/Transforms/InstCombine/ashr-lshr.ll index 36752036383ebd..d58f91be18b49e 100644 --- a/llvm/test/Transforms/InstCombine/ashr-lshr.ll +++ b/llvm/test/Transforms/InstCombine/ashr-lshr.ll @@ -629,8 +629,8 @@ define i32 @lshr_mul_times_3_div_2_exact(i32 %x) { define i32 @reduce_shift(i32 %x) { ; CHECK-LABEL: @reduce_shift( -; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 12 -; CHECK-NEXT: [[SHR:%.*]] = ashr i32 [[MUL]], 4 +; CHECK-NEXT: [[TMP1:%.*]] = mul nsw i32 [[X:%.*]], 3 +; CHECK-NEXT: [[SHR:%.*]] = ashr i32 [[TMP1]], 2 ; CHECK-NEXT: ret i32 [[SHR]] ; %mul = mul nsw i32 %x, 12 @@ -898,8 +898,8 @@ define i32 @reduce_shift_wrong_mul(i32 %x) { define i32 @reduce_shift_exact(i32 %x) { ; CHECK-LABEL: @reduce_shift_exact( -; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 12 -; CHECK-NEXT: [[SHR:%.*]] = ashr exact i32 [[MUL]], 4 +; CHECK-NEXT: [[TMP1:%.*]] = mul nsw i32 [[X:%.*]], 3 +; CHECK-NEXT: [[SHR:%.*]] = ashr exact i32 [[TMP1]], 2 ; CHECK-NEXT: ret i32 [[SHR]] ; %mul = mul nsw i32 %x, 12 diff --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll index dac700014de881..f5bf4c2fd0271c 100644 --- a/llvm/test/Transforms/InstCombine/lshr.ll +++ b/llvm/test/Transforms/InstCombine/lshr.ll @@ -702,8 +702,8 @@ define i32 @shl_add_lshr_neg(i32 %x, i32 %y, i32 %z) { define i32 @mul_splat_fold_wrong_mul_const(i32 %x) { ; CHECK-LABEL: @mul_splat_fold_wrong_mul_const( -; CHECK-NEXT: [[M:%.*]] = mul nuw i32 [[X:%.*]], 65538 -; CHECK-NEXT: [[T:%.*]] = lshr i32 [[M]], 16 +; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[X:%.*]], 15 +; CHECK-NEXT: [[T:%.*]] = add nuw nsw i32 [[TMP1]], [[X]] ; CHECK-NEXT: ret i32 [[T]] ; %m = mul nuw i32 %x, 65538 @@ -1528,8 +1528,8 @@ define <2 x i8> 
@bool_add_lshr_vec_wrong_shift_amt(<2 x i1> %a, <2 x i1> %b) { define i32 @reduce_shift(i32 %x) { ; CHECK-LABEL: @reduce_shift( -; CHECK-NEXT: [[MUL:%.*]] = mul nuw i32 [[X:%.*]], 12 -; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[MUL]], 4 +; CHECK-NEXT: [[TMP1:%.*]] = mul nuw nsw i32 [[X:%.*]], 3 +; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[TMP1]], 2 ; CHECK-NEXT: ret i32 [[SHR]] ; %mul = mul nuw i32 %x, 12 @@ -1565,8 +1565,8 @@ define i32 @reduce_shift_wrong_mul(i32 %x) { define i32 @reduce_shift_exact(i32 %x) { ; CHECK-LABEL: @reduce_shift_exact( -; CHECK-NEXT: [[MUL:%.*]] = mul nuw i32 [[X:%.*]], 12 -; CHECK-NEXT: [[SHR:%.*]] = lshr exact i32 [[MUL]], 4 +; CHECK-NEXT: [[TMP1:%.*]] = mul nuw nsw i32 [[X:%.*]], 3 +; CHECK-NEXT: [[SHR:%.*]] = lshr exact i32 [[TMP1]], 2 ; CHECK-NEXT: ret i32 [[SHR]] ; %mul = mul nuw i32 %x, 12