Skip to content

Commit

Permalink
[InstCombine] Improve bitfield addition (#33874)
Browse files Browse the repository at this point in the history
  • Loading branch information
ParkHanbum committed Jul 18, 2024
1 parent 97e293f commit 9c5c9a1
Show file tree
Hide file tree
Showing 2 changed files with 291 additions and 52 deletions.
249 changes: 249 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3485,6 +3485,252 @@ static Value *foldOrOfInversions(BinaryOperator &I,
return nullptr;
}

struct BitFieldAddBitMask {
const APInt *Lower;
const APInt *Upper;
};
struct BitFieldOptBitMask {
const APInt *Lower;
const APInt *Upper;
const APInt *New;
};
struct BitFieldAddInfo {
Value *X;
Value *Y;
bool opt;
union {
BitFieldAddBitMask AddMask;
BitFieldOptBitMask OptMask;
};
};

/// Bitfield operation is consisted of three-step as following,
/// 1. extracting the bits
/// 2. performing operations
/// 3. eliminating the bits beyond the specified range
///
/// Depending on the location of the bitfield on which we want to perform
/// the operation, all or only some of these steps are performed.
///
/// Consider:
/// %narrow = add i8 %y, %x
/// %bf.value = and i8 %narrow, 7
/// %bf.lshr = and i8 %x, 24
/// %bf.lshr1244 = add i8 %bf.lshr, %y
/// %bf.shl = and i8 %bf.lshr1244, 24
/// %bf.set20 = or disjoint i8 %bf.value, %bf.shl
///
/// This example show us bitfield operation that doing 0-3 bit first, 4-5 bit
/// second. as you can see, first 0-3 bitfield operation do not proceed step 1,
/// it is not necessary because it located bottom of bitfield. after that,
/// second 4-5 bit operation proceed 3-step as above described.
///
/// After the operation for each bitfield is completed, all bits are collected
/// through the `or disjoint` operation and the result is returned.
///
/// Our optimizing oppotunity is reducing 3-step of bitfield operation.
/// We show you optimized example with constant for more intuitive describing.
///
/// Consider:
/// (first) (second) (final)
/// ????????(x) ????????(x) 00000???
/// + 00000001 & 00011000 | 000??000
/// ---------- ---------- ----------
/// 0000???? 000??000 = 000?????
/// & 00000111 + 00001000
/// = 00000??? ----------
/// 00???000
/// & 00011000
/// ----------
/// = 000??000
///
/// Optimized:
/// (first) (second) (final)
/// 000????? (x) 000????? (x) 000????? (x&11) + 9
/// & 00001011 & 00010100 ^ 000?0?00 (x&20)
/// ---------- ---------- ----------
/// 0000?0?? (x & 11) = 000?0?00 = 000?????
/// + 00001001
/// ----------
/// = 000????? (x&11) + 9
///
/// 1. Extract each bitfield exclude high bit.
/// 2. Add sum of all values to be added to each bitfield.
/// 3. Extract high bits of each bitfield.
/// 4. Perform ExcludeOR with 2 and 3.
///
/// The most important logic here is part 4. ExclusiveOR operation is performed
/// on the highest bit of each pre-extracted bit field and the value after the
/// addition operation. Through this, we can obtain normally addition perfomed
/// results for the highest bit of the bitfield without removing the overflowed
/// bit.
static Value *foldBitFieldArithmetic(BinaryOperator &I,
InstCombiner::BuilderTy &Builder) {
auto *Disjoint = dyn_cast<PossiblyDisjointInst>(&I);
if (!Disjoint || !Disjoint->isDisjoint())
return nullptr;

unsigned BitWidth = I.getType()->getScalarSizeInBits();

// If operand of bitfield operation is a constant, sum of the constants is
// computed and returned. if operand is not a constant, operand is
// returned. if this operation is not a bitfield operation, null is returned.
auto AccumulateY = [&](Value *LoY, Value *UpY, APInt LoMask,
APInt UpMask) -> Value * {
Value *Y = nullptr;
auto *CLoY = dyn_cast_or_null<Constant>(LoY);
auto *CUpY = dyn_cast_or_null<Constant>(UpY);
// If one of operand is constant, other also must be constant.
if ((CLoY == nullptr) ^ (CUpY == nullptr))
return nullptr;

if (CLoY && CUpY) {
APInt IUpY = CUpY->getUniqueInteger();
APInt ILoY = CLoY->getUniqueInteger();
// Each operands bits must in range of its own field.
if (!(IUpY.isSubsetOf(UpMask) && ILoY.isSubsetOf(LoMask)))
return nullptr;
Y = ConstantInt::get(CLoY->getType(), ILoY + IUpY);
} else if (LoY == UpY) {
Y = LoY;
}

return Y;
};

// Perform whether this `OR disjoint` instruction is bitfield operation
// In the case of bitfield operation, the information necessary
// to optimize the bitfield operation is extracted and returned as
// BitFieldAddInfo.
auto MatchBitFieldAdd =
[&](BinaryOperator &I) -> std::optional<BitFieldAddInfo> {
const APInt *OptLoMask, *OptUpMask, *LoMask, *UpMask, *UpMask2 = nullptr;
Value *X, *Y, *UpY;

// Bitfield has more than 2 member.
// ((X&UpMask)+UpY)&UpMask2 | (X&UpMask)+UpY
auto BitFieldAddUpper = m_CombineOr(
m_And(m_c_Add(m_And(m_Value(X), m_APInt(UpMask)), m_Value(UpY)),
m_APInt(UpMask2)),
m_c_Add(m_And(m_Value(X), m_APInt(UpMask)), m_Value(UpY)));
// Bitfield has more than 2 member but bottom bitfield
// BitFieldAddUpper | (X+Y)&LoMask
auto BitFieldAdd =
m_c_Or(BitFieldAddUpper,
m_And(m_c_Add(m_Deferred(X), m_Value(Y)), m_APInt(LoMask)));
// When bitfield has only 2 member
// (X+Y)&HiMask | (X+UpY)&LoMask
auto BitFieldAddIC =
m_c_Or(m_And(m_c_Add(m_Value(X), m_Value(Y)), m_APInt(LoMask)),
m_And(m_c_Add(m_Deferred(X), m_Value(UpY)), m_APInt(UpMask)));
// When `Or optimized-bitfield, BitFieldAddUpper` matched
// OptUpMask = highest bits of each bitfield
// OptLoMask = all bit of bitfield excluded highest bit
// BitFieldAddUpper | ((X&OptLoMask)+Y) ^ ((X&OptUpMask))
auto OptBitFieldAdd = m_c_Or(
m_c_Xor(m_CombineOr(
// When Y is not the constant.
m_c_Add(m_And(m_Value(X), m_APInt(OptLoMask)),
m_And(m_Value(Y), m_APInt(OptLoMask))),
// When Y is Constant, it can be accumulated.
m_c_Add(m_And(m_Value(X), m_APInt(OptLoMask)), m_Value(Y))),
// If Y is a constant, X^Y&OptUpMask can be pre-computed and
// OptUpMask is its result.
m_CombineOr(m_And(m_Deferred(X), m_APInt(OptUpMask)),
m_And(m_c_Xor(m_Deferred(X), m_Value(UpY)),
m_APInt(OptUpMask)))),
BitFieldAddUpper);

// Match bitfield operation.
if (match(&I, BitFieldAdd) || match(&I, BitFieldAddIC)) {
APInt Mask = APInt::getBitsSet(BitWidth, BitWidth - UpMask->countl_zero(),
BitWidth);

if (!((UpMask2 == nullptr || *UpMask == *UpMask2) &&
(LoMask->popcount() >= 2 && UpMask->popcount() >= 2) &&
(LoMask->isShiftedMask() && UpMask->isShiftedMask()) &&
// Lo & Hi mask must have no common bits
((*LoMask & *UpMask) == 0) &&
// These masks must fill all bits while having no common bits.
((Mask ^ *LoMask ^ *UpMask).isAllOnes())))
return std::nullopt;

if (!(Y = AccumulateY(Y, UpY, *LoMask, *UpMask)))
return std::nullopt;

return {{X, Y, false, {{LoMask, UpMask}}}};
}

// Match already optimized bitfield operation.
if (match(&I, OptBitFieldAdd)) {
APInt Mask = APInt::getBitsSet(
BitWidth, BitWidth - OptUpMask->countl_zero(), BitWidth);
APInt Mask2 = APInt::getBitsSet(
BitWidth, BitWidth - UpMask->countl_zero(), BitWidth);

// OptLoMask : includes bits of each bit field member, but excludes
// highest bit of each bit field.
// OptHiMask : includes bits only highest bit of each member.
if (!((UpMask2 == nullptr || *UpMask == *UpMask2) &&
(UpMask->isShiftedMask() && UpMask->popcount() >= 2) &&
// must have no common bit if this operation is bitfield
((*UpMask & (*OptLoMask | *OptUpMask)) == 0) &&
// NOT(OptLoMask) must be equals OptUpMask
((~*OptLoMask ^ Mask) == *OptUpMask) &&
// These masks must fill all bits while having no common bits.
(Mask2 ^ *UpMask ^ (*OptLoMask ^ *OptUpMask)).isAllOnes()))
return std::nullopt;

if (!(Y = AccumulateY(Y, UpY, (*OptLoMask + *OptUpMask), *UpMask)))
return std::nullopt;

struct BitFieldAddInfo Info = {X, Y, true, {{OptLoMask, OptUpMask}}};
Info.OptMask.New = UpMask;
return {Info};
}

return std::nullopt;
};

if (std::optional<BitFieldAddInfo> Info = MatchBitFieldAdd(I)) {
Value *X = Info->X;
Value *Y = Info->Y;
APInt BitLoMask, BitUpMask;
if (Info->opt) {
unsigned NewHiBit = BitWidth - (Info->OptMask.New->countl_zero() + 1);
// BitLoMask inlude bits of OptMask.New exclude its highest bit
BitLoMask = *Info->OptMask.Lower | *Info->OptMask.New;
BitLoMask.clearBit(NewHiBit);
// BitUpMask only include highest bit of OptMask.New
BitUpMask = *Info->OptMask.Upper;
BitUpMask.setBit(NewHiBit);
} else {
// In case BitField operation, we create new optmized bitfield mask.
unsigned LowerHiBit = BitWidth - (Info->AddMask.Lower->countl_zero() + 1);
unsigned UpperHiBit = BitWidth - (Info->AddMask.Upper->countl_zero() + 1);
// BitLoMask include all bits of each bitfield but exclude its highest
// bits
BitLoMask = *Info->AddMask.Lower | *Info->AddMask.Upper;
BitLoMask.clearBit(LowerHiBit);
BitLoMask.clearBit(UpperHiBit);
// BitUpMask only include highest bit of each bitfield.
BitUpMask = APInt::getOneBitSet(BitWidth, LowerHiBit);
BitUpMask.setBit(UpperHiBit);
}

// Create optimized bitfield operation logic using the created bitmask.
Value *AndXLower = Builder.CreateAnd(X, BitLoMask);
Value *AndYLower = Builder.CreateAnd(Y, BitLoMask);
Value *Add = Builder.CreateNUWAdd(AndXLower, AndYLower);
Value *Xor1 = Builder.CreateXor(X, Y);
Value *AndUpper = Builder.CreateAnd(Xor1, BitUpMask);
Value *Xor = Builder.CreateXor(Add, AndUpper);
return Xor;
}

return nullptr;
}

// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
Expand Down Expand Up @@ -4034,6 +4280,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Value *V = SimplifyAddWithRemainder(I))
return replaceInstUsesWith(I, V);

if (Value *Res = foldBitFieldArithmetic(I, Builder))
return replaceInstUsesWith(I, Res);

return nullptr;
}

Expand Down
Loading

0 comments on commit 9c5c9a1

Please sign in to comment.