Skip to content

Commit

Permalink
[InstCombine] Improve bitfield addition (#33874)
Browse files Browse the repository at this point in the history
  • Loading branch information
ParkHanbum committed Jan 23, 2024
1 parent acd825c commit 1db7e0c
Showing 1 changed file with 146 additions and 0 deletions.
146 changes: 146 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3338,6 +3338,149 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
return foldAndOrOfICmpsUsingRanges(LHS, RHS, IsAnd);
}

// Field masks captured from a freshly matched pair of bitfield additions:
// the contiguous bit masks of the lower and the upper field whose adds are
// being combined.  Pointers reference APInts bound by the pattern matcher,
// so they are only valid while the matched instruction is alive.
struct BitFieldAddBitMask {
  const APInt *Lower;
  const APInt *Upper;
};
// Field masks for the case where the matched LHS was itself produced by a
// previous round of this fold: Lower/Upper describe the already-combined
// fields, and New is the mask of the additional field being merged in.
struct BitFieldOptBitMask {
  const APInt *Lower;
  const APInt *Upper;
  const APInt *New;
};
// Result of matching a bitfield-addition pattern: the common base value X,
// the (possibly merged) addend Y, and the field masks.  `opt` discriminates
// the union: OptMask is active when the matched pattern contained a
// previously optimized bitfield add, AddMask otherwise.  Both union members
// share the leading Lower/Upper pointer pair (common initial sequence).
struct BitFieldAddInfo {
  Value *X;
  Value *Y;
  bool opt;
  union {
    BitFieldAddBitMask AddMask;
    BitFieldOptBitMask OptMask;
  };
};

// Combine two independent bitfield additions, joined by a disjoint `or`,
// into a single wide addition.
//
// Conceptually the matched shape is:
//   ((X + Y) & LoMask) | (((X & UpMask) + UpY) [& UpMask2])
// i.e. one addition performed in a low field and one in the adjacent upper
// field of the same base value X.  It is rewritten to:
//   ((X & BitLoMask) + (Y & BitLoMask)) ^ ((X ^ Y) & BitUpMask)
// where BitUpMask holds the most-significant bit of each field.  Excluding
// each field's MSB from the add guarantees a carry can never propagate
// across a field boundary, and the excluded MSBs are then recomputed
// carry-free with XOR.
//
// Returns the replacement value, or nullptr if nothing matched.
static Value *foldBitFieldArithmetic(BinaryOperator &I,
                                     InstCombiner::BuilderTy &Builder) {
  // Only an `or disjoint` (operands share no set bits) may be reinterpreted
  // as an addition of its operands; bail out otherwise.
  auto *Disjoint = dyn_cast<PossiblyDisjointInst>(&I);
  if (!Disjoint || !Disjoint->isDisjoint())
    return nullptr;

  unsigned BitWidth = I.getType()->getScalarSizeInBits();
  // Merge the two matched addends (LoY for the low field, UpY for the upper
  // field) into the single RHS of the combined add.  This succeeds when
  // both are constants confined to their own field masks (so their sum has
  // no cross-field interaction), or when both are the same SSA value.
  auto AccumulateY = [&](Value *LoY, Value *UpY, APInt LoMask,
                         APInt UpMask) -> Value * {
    Value *Y = nullptr;
    auto CLoY = dyn_cast_or_null<Constant>(LoY);
    auto CUpY = dyn_cast_or_null<Constant>(UpY);
    // A mixed constant / non-constant pair cannot be merged.
    if ((CLoY == nullptr) ^ (CUpY == nullptr))
      return nullptr;

    if (CLoY && CUpY) {
      APInt IUpY = CUpY->getUniqueInteger();
      APInt ILoY = CLoY->getUniqueInteger();
      // Each constant must lie entirely inside its own field; otherwise the
      // merged constant would not reproduce the two independent adds.
      if (!(IUpY.isSubsetOf(UpMask) && ILoY.isSubsetOf(LoMask)))
        return nullptr;
      // The masks are disjoint, so this addition is exact (no carries).
      Y = ConstantInt::get(CLoY->getType(), ILoY + IUpY);
    } else if (LoY == UpY) {
      Y = LoY;
    }

    return Y;
  };

  // Try to recognize one of the supported bitfield-add shapes rooted at I.
  // Returns the decomposed operands and masks, or nullopt on failure.
  auto MatchBitFieldAdd =
      [&](BinaryOperator &I) -> std::optional<BitFieldAddInfo> {
    const APInt *OptLoMask, *OptUpMask, *LoMask, *UpMask, *UpMask2 = nullptr;
    Value *X, *Y, *UpY;
    // Upper-field add, with or without an explicit re-masking `and`:
    //   ((X & UpMask) + UpY) & UpMask2   -or-   (X & UpMask) + UpY
    // UpMask2 stays null when the unmasked alternative matches.
    auto BitFieldAddUpper = m_CombineOr(
        m_And(m_c_Add(m_And(m_Value(X), m_APInt(UpMask)), m_Value(UpY)),
              m_APInt(UpMask2)),
        m_c_Add(m_And(m_Value(X), m_APInt(UpMask)), m_Value(UpY)));
    // Upper add or'ed with a masked low-field add on the same X.
    auto BitFieldAdd =
        m_c_Or(BitFieldAddUpper,
               m_And(m_c_Add(m_Deferred(X), m_Value(Y)), m_APInt(LoMask)));
    // Variant where the low add binds X first and the upper add is the
    // plain (X + UpY) & UpMask form.
    auto BitFieldAddIC =
        m_c_Or(m_And(m_c_Add(m_Value(X), m_Value(Y)), m_APInt(LoMask)),
               m_And(m_c_Add(m_Deferred(X), m_Value(UpY)), m_APInt(UpMask)));
    // Shape produced by a previous application of this fold
    //   ((X & OptLoMask) + (Y [& OptLoMask])) ^ ((X [^ UpY]) & OptUpMask)
    // or'ed with yet another upper-field add — allows folding 3+ fields
    // iteratively.
    auto OptBitFieldAdd = m_c_Or(
        m_c_Xor(m_CombineOr(
                    m_c_Add(m_And(m_Value(X), m_APInt(OptLoMask)),
                            m_And(m_Value(Y), m_APInt(OptLoMask))),
                    m_c_Add(m_And(m_Value(X), m_APInt(OptLoMask)), m_Value(Y))),
                m_CombineOr(m_And(m_Deferred(X), m_APInt(OptUpMask)),
                            m_And(m_c_Xor(m_Deferred(X), m_Value(UpY)),
                                  m_APInt(OptUpMask)))),
        BitFieldAddUpper);

    if (match(&I, BitFieldAdd) || match(&I, BitFieldAddIC)) {
      // Mask = all bits strictly above UpMask's highest set bit.
      APInt Mask = APInt::getBitsSet(BitWidth, BitWidth - UpMask->countl_zero(),
                                     BitWidth);
      // Require: a re-mask (if present) that matches the upper field; at
      // least two bits per field (the fold needs a non-MSB part to add in);
      // both fields contiguous (shifted masks); fields disjoint; and
      // LoMask | UpMask | Mask tiling the whole width, i.e. the two fields
      // are adjacent and start at bit 0.  The XOR equals OR here because
      // the three masks are pairwise disjoint.
      if (!((UpMask2 == nullptr || *UpMask == *UpMask2) &&
            (LoMask->popcount() >= 2 && UpMask->popcount() >= 2) &&
            (LoMask->isShiftedMask() && UpMask->isShiftedMask()) &&
            ((*LoMask & *UpMask) == 0) &&
            ((Mask ^ *LoMask ^ *UpMask).isAllOnes())))
        return std::nullopt;

      if (!(Y = AccumulateY(Y, UpY, *LoMask, *UpMask)))
        return std::nullopt;

      return {{X, Y, false, {{LoMask, UpMask}}}};
    }

    if (match(&I, OptBitFieldAdd)) {
      // Bits above the already-folded region / above the new field.
      APInt Mask = APInt::getBitsSet(
          BitWidth, BitWidth - OptUpMask->countl_zero(), BitWidth);
      APInt Mask2 = APInt::getBitsSet(
          BitWidth, BitWidth - UpMask->countl_zero(), BitWidth);
      // Analogous structural checks for the pre-optimized shape: the new
      // field must be a contiguous mask of >= 2 bits, disjoint from the
      // folded region, and the combined masks must again tile the width.
      // NOTE(review): the (~OptLoMask ^ Mask) == OptUpMask identity ties
      // OptUpMask to exactly the MSB positions excluded by OptLoMask —
      // presumably guaranteeing the XOR really carries the field MSBs;
      // confirm against the emitting code below.
      if (!((UpMask2 == nullptr || *UpMask == *UpMask2) &&
            (UpMask->isShiftedMask() && UpMask->popcount() >= 2) &&
            ((*UpMask & (*OptLoMask | *OptUpMask)) == 0) &&
            ((~*OptLoMask ^ Mask) == *OptUpMask) &&
            (Mask2 ^ *UpMask ^ (*OptLoMask ^ *OptUpMask)).isAllOnes()))
        return std::nullopt;

      // The folded region's effective low mask is OptLoMask + OptUpMask.
      if (!(Y = AccumulateY(Y, UpY, (*OptLoMask + *OptUpMask), *UpMask)))
        return std::nullopt;

      struct BitFieldAddInfo Info = {X, Y, true, {{OptLoMask, OptUpMask}}};
      Info.OptMask.New = UpMask;
      return {Info};
    }

    return std::nullopt;
  };

  auto Info = MatchBitFieldAdd(I);
  if (Info) {
    Value *X = Info->X;
    Value *Y = Info->Y;
    // Build the add/xor masks: BitLoMask covers every field bit except each
    // field's MSB; BitUpMask holds exactly those MSBs.
    APInt BitLoMask, BitUpMask;
    if (Info->opt) {
      // Extend the previously folded masks by the new field, moving the new
      // field's MSB from the add side to the xor side.
      unsigned NewHiBit = BitWidth - (Info->OptMask.New->countl_zero() + 1);
      BitLoMask = *Info->OptMask.Lower | *Info->OptMask.New;
      BitLoMask.clearBit(NewHiBit);
      BitUpMask = *Info->OptMask.Upper;
      BitUpMask.setBit(NewHiBit);
    } else {
      // Fresh match: strip the MSB of each of the two fields from the add
      // mask and collect those two bits as the xor mask.
      unsigned LowerHiBit = BitWidth - (Info->AddMask.Lower->countl_zero() + 1);
      unsigned UpperHiBit = BitWidth - (Info->AddMask.Upper->countl_zero() + 1);
      BitLoMask = *Info->AddMask.Lower | *Info->AddMask.Upper;
      BitLoMask.clearBit(LowerHiBit);
      BitLoMask.clearBit(UpperHiBit);
      BitUpMask = APInt::getOneBitSet(BitWidth, LowerHiBit);
      BitUpMask.setBit(UpperHiBit);
    }

    // Emit ((X & lo) + (Y & lo)) ^ ((X ^ Y) & up).  The add is NUW: every
    // field's top bit is masked out of both operands, so no carry can reach
    // past the highest masked bit.
    auto AndXLower = Builder.CreateAnd(X, BitLoMask);
    auto AndYLower = Builder.CreateAnd(Y, BitLoMask);
    auto Add = Builder.CreateNUWAdd(AndXLower, AndYLower);
    auto Xor1 = Builder.CreateXor(X, Y);
    auto AndUpper = Builder.CreateAnd(Xor1, BitUpMask);
    auto Xor = Builder.CreateXor(Add, AndUpper);
    return Xor;
  }

  return nullptr;
}

// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
Expand Down Expand Up @@ -3911,6 +4054,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Instruction *Res = foldBitwiseLogicWithIntrinsics(I, Builder))
return Res;

if (Value *Res = foldBitFieldArithmetic(I, Builder))
return replaceInstUsesWith(I, Res);

return nullptr;
}

Expand Down

0 comments on commit 1db7e0c

Please sign in to comment.