Skip to content

Commit

Permalink
[InstCombine] Improve bitfield addition (#33874)
Browse files Browse the repository at this point in the history
  • Loading branch information
ParkHanbum committed Feb 3, 2024
1 parent 459e768 commit a10b733
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 52 deletions.
146 changes: 146 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3379,6 +3379,149 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
return foldAndOrOfICmpsUsingRanges(LHS, RHS, IsAnd);
}

// Shifted masks selecting the two bitfields of a freshly matched
// bitfield addition: Lower selects the low field, Upper the high field.
struct BitFieldAddBitMask {
  const APInt *Lower;
  const APInt *Upper;
};
// Masks for re-matching the output of a previous round of this fold:
// Lower/Upper are the add-mask/xor-mask of the earlier fold, and New is the
// mask of the additional bitfield being folded in on top of it.
struct BitFieldOptBitMask {
  const APInt *Lower;
  const APInt *Upper;
  const APInt *New;
};
// Result of matching a bitfield-add pattern. X and Y are the two addends;
// opt discriminates the union: false -> AddMask (fresh match),
// true -> OptMask (extension of an already-folded result).
struct BitFieldAddInfo {
  Value *X;
  Value *Y;
  bool opt;
  union {
    BitFieldAddBitMask AddMask;
    BitFieldOptBitMask OptMask;
  };
};

// Fold the per-field additions of adjacent bitfields, recombined with a
// disjoint 'or', into a single masked addition:
//
//   ((X + Y) & LoMask) | (((X & UpMask) + Y') & UpMask)
//
// is rewritten to
//
//   ((X & BitLoMask) + (Y & BitLoMask)) ^ ((X ^ Y) & BitUpMask)
//
// where BitLoMask is the union of the field masks with each field's highest
// bit cleared, and BitUpMask holds exactly those highest bits (which are
// combined via xor instead of participating in the add).
// Returns the replacement value, or nullptr if the pattern does not apply.
static Value *foldBitFieldArithmetic(BinaryOperator &I,
                                     InstCombiner::BuilderTy &Builder) {
  // Only an 'or' whose operands are known to have no common set bits can be
  // treated as the recombination of independent bitfields.
  auto *Disjoint = dyn_cast<PossiblyDisjointInst>(&I);
  if (!Disjoint || !Disjoint->isDisjoint())
    return nullptr;

  unsigned BitWidth = I.getType()->getScalarSizeInBits();
  // Merge the two per-field addends LoY/UpY into one combined addend:
  // either both are constants (each must fit inside its field mask; the
  // merged addend is their sum) or they are the same Value. Returns nullptr
  // when they cannot be merged (e.g. one constant, one not).
  auto AccumulateY = [&](Value *LoY, Value *UpY, APInt LoMask,
                         APInt UpMask) -> Value * {
    Value *Y = nullptr;
    auto CLoY = dyn_cast_or_null<Constant>(LoY);
    auto CUpY = dyn_cast_or_null<Constant>(UpY);
    // Exactly one of the two being constant cannot be merged.
    if ((CLoY == nullptr) ^ (CUpY == nullptr))
      return nullptr;

    if (CLoY && CUpY) {
      APInt IUpY = CUpY->getUniqueInteger();
      APInt ILoY = CLoY->getUniqueInteger();
      // Each constant addend must lie entirely within its own field.
      if (!(IUpY.isSubsetOf(UpMask) && ILoY.isSubsetOf(LoMask)))
        return nullptr;
      Y = ConstantInt::get(CLoY->getType(), ILoY + IUpY);
    } else if (LoY == UpY) {
      Y = LoY;
    }

    return Y;
  };

  // Match one of the supported bitfield-addition shapes and extract the
  // operands and field masks; std::nullopt when nothing matches or the masks
  // fail the validity checks.
  auto MatchBitFieldAdd =
      [&](BinaryOperator &I) -> std::optional<BitFieldAddInfo> {
    const APInt *OptLoMask, *OptUpMask, *LoMask, *UpMask, *UpMask2 = nullptr;
    Value *X, *Y, *UpY;
    // Upper-field addition: (X & UpMask) + UpY, optionally re-masked with
    // UpMask2 after the add.
    auto BitFieldAddUpper = m_CombineOr(
        m_And(m_c_Add(m_And(m_Value(X), m_APInt(UpMask)), m_Value(UpY)),
              m_APInt(UpMask2)),
        m_c_Add(m_And(m_Value(X), m_APInt(UpMask)), m_Value(UpY)));
    // Upper-field add or'd with the lower field computed as (X + Y) & LoMask.
    auto BitFieldAdd =
        m_c_Or(BitFieldAddUpper,
               m_And(m_c_Add(m_Deferred(X), m_Value(Y)), m_APInt(LoMask)));
    // Variant where both field adds are masked after the addition.
    auto BitFieldAddIC =
        m_c_Or(m_And(m_c_Add(m_Value(X), m_Value(Y)), m_APInt(LoMask)),
               m_And(m_c_Add(m_Deferred(X), m_Value(UpY)), m_APInt(UpMask)));
    // The shape this fold itself emits (masked add xor'd with masked high
    // bits) or'd with yet another upper-field add — matching it allows the
    // fold to be applied repeatedly for three or more fields.
    auto OptBitFieldAdd = m_c_Or(
        m_c_Xor(m_CombineOr(
                    m_c_Add(m_And(m_Value(X), m_APInt(OptLoMask)),
                            m_And(m_Value(Y), m_APInt(OptLoMask))),
                    m_c_Add(m_And(m_Value(X), m_APInt(OptLoMask)), m_Value(Y))),
                m_CombineOr(m_And(m_Deferred(X), m_APInt(OptUpMask)),
                            m_And(m_c_Xor(m_Deferred(X), m_Value(UpY)),
                                  m_APInt(OptUpMask)))),
        BitFieldAddUpper);

    if (match(&I, BitFieldAdd) || match(&I, BitFieldAddIC)) {
      // Mask = all bits strictly above UpMask.
      APInt Mask = APInt::getBitsSet(BitWidth, BitWidth - UpMask->countl_zero(),
                                     BitWidth);
      // Require: the optional re-mask (if present) equals UpMask; both fields
      // are contiguous masks of at least two bits; the fields do not overlap;
      // and LoMask, UpMask and the bits above them cover the full width.
      if (!((UpMask2 == nullptr || *UpMask == *UpMask2) &&
            (LoMask->popcount() >= 2 && UpMask->popcount() >= 2) &&
            (LoMask->isShiftedMask() && UpMask->isShiftedMask()) &&
            ((*LoMask & *UpMask) == 0) &&
            ((Mask ^ *LoMask ^ *UpMask).isAllOnes())))
        return std::nullopt;

      if (!(Y = AccumulateY(Y, UpY, *LoMask, *UpMask)))
        return std::nullopt;

      return {{X, Y, false, {{LoMask, UpMask}}}};
    }

    if (match(&I, OptBitFieldAdd)) {
      // Mask/Mask2 = all bits strictly above OptUpMask resp. UpMask.
      APInt Mask = APInt::getBitsSet(
          BitWidth, BitWidth - OptUpMask->countl_zero(), BitWidth);
      APInt Mask2 = APInt::getBitsSet(
          BitWidth, BitWidth - UpMask->countl_zero(), BitWidth);
      // Analogous validity checks for stacking one more field (UpMask) on
      // top of an earlier fold described by OptLoMask/OptUpMask.
      if (!((UpMask2 == nullptr || *UpMask == *UpMask2) &&
            (UpMask->isShiftedMask() && UpMask->popcount() >= 2) &&
            ((*UpMask & (*OptLoMask | *OptUpMask)) == 0) &&
            ((~*OptLoMask ^ Mask) == *OptUpMask) &&
            (Mask2 ^ *UpMask ^ (*OptLoMask ^ *OptUpMask)).isAllOnes()))
        return std::nullopt;

      // The earlier fold's addend already occupies OptLoMask|OptUpMask.
      if (!(Y = AccumulateY(Y, UpY, (*OptLoMask + *OptUpMask), *UpMask)))
        return std::nullopt;

      struct BitFieldAddInfo Info = {X, Y, true, {{OptLoMask, OptUpMask}}};
      Info.OptMask.New = UpMask;
      return {Info};
    }

    return std::nullopt;
  };

  auto Info = MatchBitFieldAdd(I);
  if (Info) {
    Value *X = Info->X;
    Value *Y = Info->Y;
    APInt BitLoMask, BitUpMask;
    if (Info->opt) {
      // Extending a previous fold: fold the new field's bits (minus its top
      // bit) into the add-mask and move that top bit into the xor-mask.
      unsigned NewHiBit = BitWidth - (Info->OptMask.New->countl_zero() + 1);
      BitLoMask = *Info->OptMask.Lower | *Info->OptMask.New;
      BitLoMask.clearBit(NewHiBit);
      BitUpMask = *Info->OptMask.Upper;
      BitUpMask.setBit(NewHiBit);
    } else {
      // Fresh fold: add-mask is both fields with each field's top bit
      // cleared; xor-mask is exactly those two top bits.
      unsigned LowerHiBit = BitWidth - (Info->AddMask.Lower->countl_zero() + 1);
      unsigned UpperHiBit = BitWidth - (Info->AddMask.Upper->countl_zero() + 1);
      BitLoMask = *Info->AddMask.Lower | *Info->AddMask.Upper;
      BitLoMask.clearBit(LowerHiBit);
      BitLoMask.clearBit(UpperHiBit);
      BitUpMask = APInt::getOneBitSet(BitWidth, LowerHiBit);
      BitUpMask.setBit(UpperHiBit);
    }

    // Emit ((X & BitLoMask) + (Y & BitLoMask)) ^ ((X ^ Y) & BitUpMask).
    // The top bit of every field is cleared in BitLoMask, so the add is
    // emitted with 'nuw'.
    auto AndXLower = Builder.CreateAnd(X, BitLoMask);
    auto AndYLower = Builder.CreateAnd(Y, BitLoMask);
    auto Add = Builder.CreateNUWAdd(AndXLower, AndYLower);
    auto Xor1 = Builder.CreateXor(X, Y);
    auto AndUpper = Builder.CreateAnd(Xor1, BitUpMask);
    auto Xor = Builder.CreateXor(Add, AndUpper);
    return Xor;
  }

  return nullptr;
}

// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
Expand Down Expand Up @@ -3945,6 +4088,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *C1 | *C2));
}

if (Value *Res = foldBitFieldArithmetic(I, Builder))
return replaceInstUsesWith(I, Res);

if (Instruction *Res = foldBitwiseLogicWithIntrinsics(I, Builder))
return Res;

Expand Down
94 changes: 42 additions & 52 deletions llvm/test/Transforms/InstCombine/or.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1907,12 +1907,12 @@ define i32 @test_or_add_xor(i32 %a, i32 %b, i32 %c) {
define i8 @src_2_bitfield_op(i8 %x, i8 %y) {
; CHECK-LABEL: @src_2_bitfield_op(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[NARROW:%.*]] = add i8 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[BF_VALUE:%.*]] = and i8 [[NARROW]], 7
; CHECK-NEXT: [[BF_LSHR:%.*]] = and i8 [[X]], 24
; CHECK-NEXT: [[BF_LSHR1228:%.*]] = add i8 [[BF_LSHR]], [[Y]]
; CHECK-NEXT: [[BF_SHL:%.*]] = and i8 [[BF_LSHR1228]], 24
; CHECK-NEXT: [[BF_SET20:%.*]] = or disjoint i8 [[BF_VALUE]], [[BF_SHL]]
; CHECK-NEXT: [[TMP0:%.*]] = and i8 [[X:%.*]], 11
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[Y:%.*]], 11
; CHECK-NEXT: [[TMP2:%.*]] = add nuw nsw i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = xor i8 [[X]], [[Y]]
; CHECK-NEXT: [[TMP4:%.*]] = and i8 [[TMP3]], 20
; CHECK-NEXT: [[BF_SET20:%.*]] = xor i8 [[TMP2]], [[TMP4]]
; CHECK-NEXT: ret i8 [[BF_SET20]]
;
entry:
Expand All @@ -1928,11 +1928,10 @@ entry:
define i8 @src_2_bitfield_const(i8 %x) {
; CHECK-LABEL: @src_2_bitfield_const(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[NARROW:%.*]] = add i8 [[X:%.*]], 1
; CHECK-NEXT: [[BF_VALUE:%.*]] = and i8 [[NARROW]], 7
; CHECK-NEXT: [[BF_LSHR1228:%.*]] = add i8 [[X]], 8
; CHECK-NEXT: [[BF_SHL:%.*]] = and i8 [[BF_LSHR1228]], 24
; CHECK-NEXT: [[BF_SET20:%.*]] = or disjoint i8 [[BF_VALUE]], [[BF_SHL]]
; CHECK-NEXT: [[TMP0:%.*]] = and i8 [[X:%.*]], 11
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i8 [[TMP0]], 9
; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[X]], 20
; CHECK-NEXT: [[BF_SET20:%.*]] = xor i8 [[TMP1]], [[TMP2]]
; CHECK-NEXT: ret i8 [[BF_SET20]]
;
entry:
Expand All @@ -1948,16 +1947,12 @@ entry:
define i8 @src_3_bitfield_op(i8 %x, i8 %y) {
; CHECK-LABEL: @src_3_bitfield_op(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[NARROW:%.*]] = add i8 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[BF_VALUE:%.*]] = and i8 [[NARROW]], 7
; CHECK-NEXT: [[BF_LSHR:%.*]] = and i8 [[X]], 24
; CHECK-NEXT: [[BF_LSHR1244:%.*]] = add i8 [[BF_LSHR]], [[Y]]
; CHECK-NEXT: [[BF_SHL:%.*]] = and i8 [[BF_LSHR1244]], 24
; CHECK-NEXT: [[BF_SET20:%.*]] = or disjoint i8 [[BF_VALUE]], [[BF_SHL]]
; CHECK-NEXT: [[BF_LSHR22:%.*]] = and i8 [[X]], -32
; CHECK-NEXT: [[BF_LSHR2547:%.*]] = add i8 [[BF_LSHR22]], [[Y]]
; CHECK-NEXT: [[BF_VALUE30:%.*]] = and i8 [[BF_LSHR2547]], -32
; CHECK-NEXT: [[BF_SET33:%.*]] = or disjoint i8 [[BF_SET20]], [[BF_VALUE30]]
; CHECK-NEXT: [[TMP0:%.*]] = and i8 [[X:%.*]], 107
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[Y:%.*]], 107
; CHECK-NEXT: [[TMP2:%.*]] = add nuw i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = xor i8 [[X]], [[Y]]
; CHECK-NEXT: [[TMP4:%.*]] = and i8 [[TMP3]], -108
; CHECK-NEXT: [[BF_SET33:%.*]] = xor i8 [[TMP2]], [[TMP4]]
; CHECK-NEXT: ret i8 [[BF_SET33]]
;
entry:
Expand All @@ -1977,14 +1972,10 @@ entry:
define i8 @src_3_bitfield_const(i8 %x) {
; CHECK-LABEL: @src_3_bitfield_const(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[NARROW:%.*]] = add i8 [[X:%.*]], 1
; CHECK-NEXT: [[BF_VALUE:%.*]] = and i8 [[NARROW]], 7
; CHECK-NEXT: [[BF_LSHR1244:%.*]] = add i8 [[X]], 8
; CHECK-NEXT: [[BF_SHL:%.*]] = and i8 [[BF_LSHR1244]], 24
; CHECK-NEXT: [[BF_SET20:%.*]] = or disjoint i8 [[BF_VALUE]], [[BF_SHL]]
; CHECK-NEXT: [[TMP0:%.*]] = and i8 [[X]], -32
; CHECK-NEXT: [[BF_VALUE30:%.*]] = add i8 [[TMP0]], 32
; CHECK-NEXT: [[BF_SET33:%.*]] = or disjoint i8 [[BF_SET20]], [[BF_VALUE30]]
; CHECK-NEXT: [[TMP0:%.*]] = and i8 [[X:%.*]], 107
; CHECK-NEXT: [[TMP1:%.*]] = add nuw i8 [[TMP0]], 41
; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[X]], -108
; CHECK-NEXT: [[BF_SET33:%.*]] = xor i8 [[TMP1]], [[TMP2]]
; CHECK-NEXT: ret i8 [[BF_SET33]]
;
entry:
Expand Down Expand Up @@ -2064,12 +2055,12 @@ entry:
define i8 @src_bit_arithmetic_bitsize_1_high(i8 %x, i8 %y) {
; CHECK-LABEL: @src_bit_arithmetic_bitsize_1_high(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[NARROW:%.*]] = add i8 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[BF_VALUE:%.*]] = and i8 [[NARROW]], 7
; CHECK-NEXT: [[BF_LSHR:%.*]] = and i8 [[X]], 120
; CHECK-NEXT: [[BF_LSHR1244:%.*]] = add i8 [[BF_LSHR]], [[Y]]
; CHECK-NEXT: [[BF_SHL:%.*]] = and i8 [[BF_LSHR1244]], 120
; CHECK-NEXT: [[BF_SET20:%.*]] = or disjoint i8 [[BF_VALUE]], [[BF_SHL]]
; CHECK-NEXT: [[TMP0:%.*]] = and i8 [[X:%.*]], 59
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[Y:%.*]], 59
; CHECK-NEXT: [[TMP2:%.*]] = add nuw nsw i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = xor i8 [[X]], [[Y]]
; CHECK-NEXT: [[TMP4:%.*]] = and i8 [[TMP3]], 68
; CHECK-NEXT: [[BF_SET20:%.*]] = xor i8 [[TMP2]], [[TMP4]]
; CHECK-NEXT: [[BF_LSHR22:%.*]] = and i8 [[X]], -128
; CHECK-NEXT: [[BF_LSHR2547:%.*]] = add i8 [[BF_LSHR22]], [[Y]]
; CHECK-NEXT: [[BF_VALUE30:%.*]] = and i8 [[BF_LSHR2547]], -128
Expand Down Expand Up @@ -2122,12 +2113,12 @@ entry:
define i8 @src_bit_arithmetic_bitmask_mid_over_high(i8 %x, i8 %y) {
; CHECK-LABEL: @src_bit_arithmetic_bitmask_mid_over_high(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[NARROW:%.*]] = add i8 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[BF_VALUE:%.*]] = and i8 [[NARROW]], 7
; CHECK-NEXT: [[BF_LSHR:%.*]] = and i8 [[X]], 56
; CHECK-NEXT: [[BF_LSHR1244:%.*]] = add i8 [[BF_LSHR]], [[Y]]
; CHECK-NEXT: [[BF_SHL:%.*]] = and i8 [[BF_LSHR1244]], 56
; CHECK-NEXT: [[BF_SET20:%.*]] = or disjoint i8 [[BF_VALUE]], [[BF_SHL]]
; CHECK-NEXT: [[TMP0:%.*]] = and i8 [[X:%.*]], 27
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[Y:%.*]], 27
; CHECK-NEXT: [[TMP2:%.*]] = add nuw nsw i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = xor i8 [[X]], [[Y]]
; CHECK-NEXT: [[TMP4:%.*]] = and i8 [[TMP3]], 36
; CHECK-NEXT: [[BF_SET20:%.*]] = xor i8 [[TMP2]], [[TMP4]]
; CHECK-NEXT: [[BF_LSHR22:%.*]] = and i8 [[X]], -32
; CHECK-NEXT: [[BF_LSHR2547:%.*]] = add i8 [[BF_LSHR22]], [[Y]]
; CHECK-NEXT: [[BF_VALUE30:%.*]] = and i8 [[BF_LSHR2547]], -32
Expand Down Expand Up @@ -2180,12 +2171,12 @@ entry:
define i8 @src_bit_arithmetic_bitmask_high_under_mid(i8 %x, i8 %y) {
; CHECK-LABEL: @src_bit_arithmetic_bitmask_high_under_mid(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[NARROW:%.*]] = add i8 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[BF_VALUE:%.*]] = and i8 [[NARROW]], 7
; CHECK-NEXT: [[BF_LSHR:%.*]] = and i8 [[X]], 24
; CHECK-NEXT: [[BF_LSHR1244:%.*]] = add i8 [[BF_LSHR]], [[Y]]
; CHECK-NEXT: [[BF_SHL:%.*]] = and i8 [[BF_LSHR1244]], 24
; CHECK-NEXT: [[BF_SET20:%.*]] = or disjoint i8 [[BF_VALUE]], [[BF_SHL]]
; CHECK-NEXT: [[TMP0:%.*]] = and i8 [[X:%.*]], 11
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[Y:%.*]], 11
; CHECK-NEXT: [[TMP2:%.*]] = add nuw nsw i8 [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = xor i8 [[X]], [[Y]]
; CHECK-NEXT: [[TMP4:%.*]] = and i8 [[TMP3]], 20
; CHECK-NEXT: [[BF_SET20:%.*]] = xor i8 [[TMP2]], [[TMP4]]
; CHECK-NEXT: [[BF_LSHR22:%.*]] = and i8 [[X]], -16
; CHECK-NEXT: [[BF_LSHR2547:%.*]] = add i8 [[BF_LSHR22]], [[Y]]
; CHECK-NEXT: [[BF_VALUE30:%.*]] = and i8 [[BF_LSHR2547]], -16
Expand Down Expand Up @@ -2287,11 +2278,10 @@ entry:
define i8 @src_bit_arithmetic_addition_under_bitmask_high(i8 %x) {
; CHECK-LABEL: @src_bit_arithmetic_addition_under_bitmask_high(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[NARROW:%.*]] = add i8 [[X:%.*]], 1
; CHECK-NEXT: [[BF_VALUE:%.*]] = and i8 [[NARROW]], 7
; CHECK-NEXT: [[BF_LSHR1244:%.*]] = add i8 [[X]], 8
; CHECK-NEXT: [[BF_SHL:%.*]] = and i8 [[BF_LSHR1244]], 24
; CHECK-NEXT: [[BF_SET20:%.*]] = or disjoint i8 [[BF_VALUE]], [[BF_SHL]]
; CHECK-NEXT: [[TMP0:%.*]] = and i8 [[X:%.*]], 11
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i8 [[TMP0]], 9
; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[X]], 20
; CHECK-NEXT: [[BF_SET20:%.*]] = xor i8 [[TMP1]], [[TMP2]]
; CHECK-NEXT: [[BF_LSHR22:%.*]] = and i8 [[X]], -32
; CHECK-NEXT: [[BF_SET33:%.*]] = or disjoint i8 [[BF_SET20]], [[BF_LSHR22]]
; CHECK-NEXT: ret i8 [[BF_SET33]]
Expand Down

0 comments on commit a10b733

Please sign in to comment.