Skip to content

Commit

Permalink
[InstCombine] Reduce multiplicands of even numbers when a shift is involved
Browse files Browse the repository at this point in the history

We can improve analysis and codegen, and enable other folds, if we can take expressions like (x * 6) >> 2 and replace them with (x * 3) >> 1 (assuming no overflow, of course).

Because each right shift by one is a division by 2, we can replace a multiplication by an even number with that number divided by 2 and require one less shift, and keep going until the multiplicand becomes 0 or an odd number.

Alive2 Proofs:
https://alive2.llvm.org/ce/z/C9FvwB
https://alive2.llvm.org/ce/z/7Zsx3b
  • Loading branch information
AreaZR committed Jun 10, 2024
1 parent b042630 commit 013e4e4
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 20 deletions.
53 changes: 43 additions & 10 deletions llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1504,16 +1504,21 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
// able to invert the transform and perf may suffer with an extra mul
// instruction.
if (Op0->hasOneUse()) {
APInt NewMulC = MulC->lshr(ShAmtC);
// if c is divisible by (1 << ShAmtC):
// lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC)
if (MulC->eq(NewMulC.shl(ShAmtC))) {
auto *NewMul =
BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
assert(ShAmtC != 0 &&
"lshr X, 0 should be handled by simplifyLShrInst.");
NewMul->setHasNoSignedWrap(true);
return NewMul;
assert(ShAmtC != 0 &&
"lshr X, 0 should be handled by simplifyLShrInst.");
unsigned CommonZeros = std::min(MulC->countr_zero(), ShAmtC);
if (CommonZeros != 0) {
APInt NewMulC = MulC->lshr(CommonZeros);
unsigned NewShAmtC = ShAmtC - CommonZeros;

// We can reduce expressions such as like lshr (mul nuw x, 6), 2 ->
// lshr (mul nuw nsw x, 3), 1
auto *NewMul = Builder.CreateMul(X, ConstantInt::get(Ty, NewMulC), "",
/*NUW=*/true, /*NSW=*/true);
auto *NewLshr = BinaryOperator::CreateLShr(
NewMul, ConstantInt::get(Ty, NewShAmtC));
NewLshr->copyIRFlags(&I); // We can preserve 'exact'-ness.
return NewLshr;
}
}
}
Expand Down Expand Up @@ -1712,6 +1717,34 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum));
}

if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(ShOp1))))) {
unsigned CommonZeros = std::min(ShOp1->countr_zero(), ShAmt);
if (CommonZeros != 0) {
APInt NewMulC = ShOp1->ashr(CommonZeros);
unsigned NewShAmtC = ShAmt - CommonZeros;
// if c is divisible by (1 << ShAmtC):
// ashr (mul nsw x, MulC), ShAmtC -> mul nsw x, (MulC >> ShAmtC)
if (NewShAmtC == 0) {
auto *NewMul =
BinaryOperator::CreateNSWMul(X, ConstantInt::get(Ty, NewMulC));
NewMul->setHasNoUnsignedWrap(
cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap());
return NewMul;
}

// We can reduce expressions such as ashr (mul nsw x, 6), 2 -> ashr (mul
// nsw x, 3), 1
auto *NewMul = Builder.CreateMul(
X, ConstantInt::get(Ty, NewMulC), "",
/*NUW*/ cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap(),
/*NSW*/ true);
auto *NewAshr =
BinaryOperator::CreateAShr(NewMul, ConstantInt::get(Ty, NewShAmtC));
NewAshr->copyIRFlags(&I); // We can preserve 'exact'-ness.
return NewAshr;
}
}

if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) &&
(Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) {
// ashr (sext X), C --> sext (ashr X, C')
Expand Down
8 changes: 4 additions & 4 deletions llvm/test/Transforms/InstCombine/ashr-lshr.ll
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,8 @@ define i32 @lshr_mul_times_3_div_2_exact(i32 %x) {

define i32 @reduce_shift(i32 %x) {
; CHECK-LABEL: @reduce_shift(
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 12
; CHECK-NEXT: [[SHR:%.*]] = ashr i32 [[MUL]], 4
; CHECK-NEXT: [[TMP1:%.*]] = mul nsw i32 [[X:%.*]], 3
; CHECK-NEXT: [[SHR:%.*]] = ashr i32 [[TMP1]], 2
; CHECK-NEXT: ret i32 [[SHR]]
;
%mul = mul nsw i32 %x, 12
Expand Down Expand Up @@ -898,8 +898,8 @@ define i32 @reduce_shift_wrong_mul(i32 %x) {

define i32 @reduce_shift_exact(i32 %x) {
; CHECK-LABEL: @reduce_shift_exact(
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 12
; CHECK-NEXT: [[SHR:%.*]] = ashr exact i32 [[MUL]], 4
; CHECK-NEXT: [[TMP1:%.*]] = mul nsw i32 [[X:%.*]], 3
; CHECK-NEXT: [[SHR:%.*]] = ashr exact i32 [[TMP1]], 2
; CHECK-NEXT: ret i32 [[SHR]]
;
%mul = mul nsw i32 %x, 12
Expand Down
12 changes: 6 additions & 6 deletions llvm/test/Transforms/InstCombine/lshr.ll
Original file line number Diff line number Diff line change
Expand Up @@ -702,8 +702,8 @@ define i32 @shl_add_lshr_neg(i32 %x, i32 %y, i32 %z) {

define i32 @mul_splat_fold_wrong_mul_const(i32 %x) {
; CHECK-LABEL: @mul_splat_fold_wrong_mul_const(
; CHECK-NEXT: [[M:%.*]] = mul nuw i32 [[X:%.*]], 65538
; CHECK-NEXT: [[T:%.*]] = lshr i32 [[M]], 16
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[X:%.*]], 15
; CHECK-NEXT: [[T:%.*]] = add nuw nsw i32 [[TMP1]], [[X]]
; CHECK-NEXT: ret i32 [[T]]
;
%m = mul nuw i32 %x, 65538
Expand Down Expand Up @@ -1528,8 +1528,8 @@ define <2 x i8> @bool_add_lshr_vec_wrong_shift_amt(<2 x i1> %a, <2 x i1> %b) {

define i32 @reduce_shift(i32 %x) {
; CHECK-LABEL: @reduce_shift(
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i32 [[X:%.*]], 12
; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[MUL]], 4
; CHECK-NEXT: [[TMP1:%.*]] = mul nuw nsw i32 [[X:%.*]], 3
; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[TMP1]], 2
; CHECK-NEXT: ret i32 [[SHR]]
;
%mul = mul nuw i32 %x, 12
Expand Down Expand Up @@ -1565,8 +1565,8 @@ define i32 @reduce_shift_wrong_mul(i32 %x) {

define i32 @reduce_shift_exact(i32 %x) {
; CHECK-LABEL: @reduce_shift_exact(
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i32 [[X:%.*]], 12
; CHECK-NEXT: [[SHR:%.*]] = lshr exact i32 [[MUL]], 4
; CHECK-NEXT: [[TMP1:%.*]] = mul nuw nsw i32 [[X:%.*]], 3
; CHECK-NEXT: [[SHR:%.*]] = lshr exact i32 [[TMP1]], 2
; CHECK-NEXT: ret i32 [[SHR]]
;
%mul = mul nuw i32 %x, 12
Expand Down

0 comments on commit 013e4e4

Please sign in to comment.