[InstCombine] Improve bitfield addition #77184

Open
wants to merge 4 commits into
base: main
Changes from 3 commits
249 changes: 249 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3485,6 +3485,252 @@ static Value *foldOrOfInversions(BinaryOperator &I,
return nullptr;
}

struct BitFieldAddBitMask {
const APInt *Lower;
const APInt *Upper;
};
struct BitFieldOptBitMask {
const APInt *Lower;
const APInt *Upper;
const APInt *New;
};
struct BitFieldAddInfo {
Value *X;
Value *Y;
bool opt;
union {
BitFieldAddBitMask AddMask;
BitFieldOptBitMask OptMask;
};
};

/// A bitfield operation consists of three steps:
/// 1. extracting the bits
/// 2. performing the operation
/// 3. eliminating the bits beyond the specified range
///
/// Depending on the location of the bitfield on which we want to perform
/// the operation, all or only some of these steps are performed.
///
/// Consider:
/// %narrow = add i8 %y, %x
/// %bf.value = and i8 %narrow, 7
/// %bf.lshr = and i8 %x, 24
/// %bf.lshr1244 = add i8 %bf.lshr, %y
/// %bf.shl = and i8 %bf.lshr1244, 24
/// %bf.set20 = or disjoint i8 %bf.value, %bf.shl
///
/// This example shows a bitfield operation that handles bits 0-2 first and
/// bits 3-4 second. The first operation (bits 0-2) skips step 1, which is
/// unnecessary because that field sits at the bottom of the bitfield. The
/// second operation (bits 3-4) performs all three steps described above.
///
/// After the operation for each bitfield is completed, all bits are collected
/// through the `or disjoint` operation and the result is returned.
///
/// The optimization opportunity is to reduce these three steps. The example
/// below uses constants to make the transformation easier to follow.
///
/// Consider:
///   (first)         (second)        (final)
///   ????????(x)     ????????(x)     00000???
/// + 00000001      & 00011000      | 000??000
/// ----------      ----------      ----------
///   0000????        000??000      = 000?????
/// & 00000111      + 00001000
/// = 00000???      ----------
///                   00???000
///                 & 00011000
///                 ----------
///                 = 000??000
///
/// Optimized:
///   (first)               (second)         (final)
///   000????? (x)          000????? (x)     000????? (x&11) + 9
/// & 00001011            & 00010100       ^ 000?0?00 (x&20)
/// ----------            ----------       ----------
///   0000?0?? (x & 11)   = 000?0?00       = 000?????
/// + 00001001
/// ----------
/// = 000????? (x&11) + 9
///
/// 1. Extract each bitfield, excluding its highest bit.
/// 2. Add the sum of all values to be added to each bitfield.
/// 3. Extract the highest bit of each bitfield.
/// 4. XOR the results of steps 2 and 3.
///
/// The most important step is step 4. The XOR combines the pre-extracted
/// highest bit of each bitfield with the value produced by the addition.
/// This yields the correct sum in the highest bit of each bitfield without
/// having to mask off the overflowed bit.
static Value *foldBitFieldArithmetic(BinaryOperator &I,
Collaborator

To address the reported issues, this has been written in terms of improving bitfield math, but I'm not certain if we should be addressing this in terms of more generic canonicalizations or not - are we likely to hit sub-parts of this in other places do you think?

Contributor Author

Honestly, I have no idea. I focused on how to implement your optimization idea within the limits of my LLVM knowledge. Could you give me some advice so that I can think about it? What do you mean by "more generic canonicalizations"?

InstCombiner::BuilderTy &Builder) {
auto *Disjoint = dyn_cast<PossiblyDisjointInst>(&I);
if (!Disjoint || !Disjoint->isDisjoint())
return nullptr;

unsigned BitWidth = I.getType()->getScalarSizeInBits();

// If the operands of the bitfield operation are constants, their sum is
// computed and returned. If the operand is not a constant, the operand itself
// is returned. If this is not a bitfield operation, null is returned.
auto AccumulateY = [&](Value *LoY, Value *UpY, const APInt LoMask,
const APInt UpMask) -> Value * {
Value *Y = nullptr;
auto *CLoY = dyn_cast_or_null<Constant>(LoY);
auto *CUpY = dyn_cast_or_null<Constant>(UpY);
// If one operand is a constant, the other must be a constant as well.
if ((CLoY == nullptr) ^ (CUpY == nullptr))
return nullptr;

if (CLoY && CUpY) {
APInt IUpY = CUpY->getUniqueInteger();
APInt ILoY = CLoY->getUniqueInteger();
// Each operand's bits must lie within its own field.
if (!(IUpY.isSubsetOf(UpMask) && ILoY.isSubsetOf(LoMask)))
return nullptr;
Y = ConstantInt::get(CLoY->getType(), ILoY + IUpY);
} else if (LoY == UpY) {
Y = LoY;
}

return Y;
};

// Check whether this `or disjoint` instruction is a bitfield operation.
// If it is, the information needed to optimize the bitfield operation is
// extracted and returned as BitFieldAddInfo.
auto MatchBitFieldAdd =
[&](BinaryOperator &I) -> std::optional<BitFieldAddInfo> {
const APInt *OptLoMask, *OptUpMask, *LoMask, *UpMask, *UpMask2 = nullptr;
Value *X, *Y, *UpY;

// The bitfield has more than two members.
// ((X&UpMask)+UpY)&UpMask2 | (X&UpMask)+UpY
auto BitFieldAddUpper = m_CombineOr(
m_And(m_c_Add(m_And(m_Value(X), m_APInt(UpMask)), m_Value(UpY)),
m_APInt(UpMask2)),
m_c_Add(m_And(m_Value(X), m_APInt(UpMask)), m_Value(UpY)));
// An upper member combined with the bottom member of the bitfield.
// BitFieldAddUpper | (X+Y)&LoMask
auto BitFieldAdd =
m_c_Or(BitFieldAddUpper,
m_And(m_c_Add(m_Deferred(X), m_Value(Y)), m_APInt(LoMask)));
// The bitfield has only two members.
// (X+Y)&LoMask | (X+UpY)&UpMask
auto BitFieldAddIC =
m_c_Or(m_And(m_c_Add(m_Value(X), m_Value(Y)), m_APInt(LoMask)),
m_And(m_c_Add(m_Deferred(X), m_Value(UpY)), m_APInt(UpMask)));
// Matches `or OptimizedBitfield, BitFieldAddUpper`, where
// OptUpMask = the highest bit of each bitfield
// OptLoMask = all bits of each bitfield except its highest bit
// BitFieldAddUpper | ((X&OptLoMask)+Y) ^ (X&OptUpMask)
auto OptBitFieldAdd = m_c_Or(
m_c_Xor(m_CombineOr(
// When Y is not a constant.
m_c_Add(m_And(m_Value(X), m_APInt(OptLoMask)),
m_And(m_Value(Y), m_APInt(OptLoMask))),
// When Y is a constant, it can be accumulated.
m_c_Add(m_And(m_Value(X), m_APInt(OptLoMask)), m_Value(Y))),
// If Y is a constant, X^Y&OptUpMask can be pre-computed and
// OptUpMask is its result.
m_CombineOr(m_And(m_Deferred(X), m_APInt(OptUpMask)),
m_And(m_c_Xor(m_Deferred(X), m_Value(UpY)),
m_APInt(OptUpMask)))),
Contributor

Do you mean to be overwriting OptUpMask here and OptLoMask above?

Contributor Author

Yes, but I made a mistake and fixed it

BitFieldAddUpper);

// Match bitfield operation.
if (match(&I, BitFieldAdd) || match(&I, BitFieldAddIC)) {
APInt Mask = APInt::getBitsSet(BitWidth, BitWidth - UpMask->countl_zero(),
BitWidth);

if (!((UpMask2 == nullptr || *UpMask == *UpMask2) &&
(LoMask->popcount() >= 2 && UpMask->popcount() >= 2) &&
(LoMask->isShiftedMask() && UpMask->isShiftedMask()) &&
// Lo & Hi mask must have no common bits
((*LoMask & *UpMask) == 0) &&
// These masks must fill all bits while having no common bits.
((Mask ^ *LoMask ^ *UpMask).isAllOnes())))
return std::nullopt;

if (!(Y = AccumulateY(Y, UpY, *LoMask, *UpMask)))
return std::nullopt;

return {{X, Y, false, {{LoMask, UpMask}}}};
}

// Match an already optimized bitfield operation.
if (match(&I, OptBitFieldAdd)) {
APInt Mask = APInt::getBitsSet(
BitWidth, BitWidth - OptUpMask->countl_zero(), BitWidth);
APInt Mask2 = APInt::getBitsSet(
BitWidth, BitWidth - UpMask->countl_zero(), BitWidth);

// OptLoMask : includes the bits of each bitfield member except the
// highest bit of each member.
// OptUpMask : includes only the highest bit of each member.
if (!((UpMask2 == nullptr || *UpMask == *UpMask2) &&
(UpMask->isShiftedMask() && UpMask->popcount() >= 2) &&
// The masks must have no common bits if this is a bitfield operation.
((*UpMask & (*OptLoMask | *OptUpMask)) == 0) &&
// NOT(OptLoMask) must equal OptUpMask.
((~*OptLoMask ^ Mask) == *OptUpMask) &&
// These masks must fill all bits while having no common bits.
(Mask2 ^ *UpMask ^ (*OptLoMask ^ *OptUpMask)).isAllOnes()))
return std::nullopt;

if (!(Y = AccumulateY(Y, UpY, (*OptLoMask + *OptUpMask), *UpMask)))
return std::nullopt;

struct BitFieldAddInfo Info = {X, Y, true, {{OptLoMask, OptUpMask}}};
Info.OptMask.New = UpMask;
return {Info};
}

return std::nullopt;
};

if (std::optional<BitFieldAddInfo> Info = MatchBitFieldAdd(I)) {
Value *X = Info->X;
Value *Y = Info->Y;
APInt BitLoMask, BitUpMask;
if (Info->opt) {
unsigned NewHiBit = BitWidth - (Info->OptMask.New->countl_zero() + 1);
// BitLoMask includes the bits of OptMask.New except its highest bit.
BitLoMask = *Info->OptMask.Lower | *Info->OptMask.New;
BitLoMask.clearBit(NewHiBit);
// BitUpMask adds the highest bit of OptMask.New to OptMask.Upper.
BitUpMask = *Info->OptMask.Upper;
BitUpMask.setBit(NewHiBit);
} else {
// For a plain bitfield operation, create the new optimized bitfield masks.
unsigned LowerHiBit = BitWidth - (Info->AddMask.Lower->countl_zero() + 1);
unsigned UpperHiBit = BitWidth - (Info->AddMask.Upper->countl_zero() + 1);
// BitLoMask includes all bits of each bitfield except its highest bit.
BitLoMask = *Info->AddMask.Lower | *Info->AddMask.Upper;
BitLoMask.clearBit(LowerHiBit);
BitLoMask.clearBit(UpperHiBit);
// BitUpMask includes only the highest bit of each bitfield.
BitUpMask = APInt::getOneBitSet(BitWidth, LowerHiBit);
BitUpMask.setBit(UpperHiBit);
}

// Create optimized bitfield operation logic using the created bitmask.
Value *AndXLower = Builder.CreateAnd(X, BitLoMask);
Value *AndYLower = Builder.CreateAnd(Y, BitLoMask);
Value *Add = Builder.CreateNUWAdd(AndXLower, AndYLower);
Value *Xor1 = Builder.CreateXor(X, Y);
Value *AndUpper = Builder.CreateAnd(Xor1, BitUpMask);
Value *Xor = Builder.CreateXor(Add, AndUpper);
return Xor;
}

return nullptr;
}

// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
@@ -4034,6 +4280,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Value *V = SimplifyAddWithRemainder(I))
return replaceInstUsesWith(I, V);

if (Value *Res = foldBitFieldArithmetic(I, Builder))
return replaceInstUsesWith(I, Res);

return nullptr;
}
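
As a sanity check on the rewrite that `foldBitFieldArithmetic` emits, the i8 example from the doc comment can be brute-forced outside of LLVM. The sketch below is not part of the patch: it uses plain `uint8_t` arithmetic in place of `APInt`, and the helper names (`highestBit`, `originalForm`, `optimizedForm`, `countMismatches`, `constantExampleHolds`) are hypothetical. It assumes the masks 7 and 24 from the example and derives the two new masks the same way the non-optimized branch does, by moving the highest bit of each field into the upper mask.

```cpp
#include <cassert>
#include <cstdint>

// Masks from the i8 example: the lower field occupies bits 0-2 (mask 7)
// and the upper field occupies bits 3-4 (mask 24).
constexpr uint8_t LoMask = 7;
constexpr uint8_t UpMask = 24;

// Hypothetical helper: returns the highest set bit of a non-zero mask.
constexpr uint8_t highestBit(uint8_t Mask) {
  uint8_t Bit = 0x80;
  while (!(Bit & Mask))
    Bit >>= 1;
  return Bit;
}

// BitUpMask holds the highest bit of each field; BitLoMask holds the rest.
constexpr uint8_t BitUpMask = highestBit(LoMask) | highestBit(UpMask);
constexpr uint8_t BitLoMask = (LoMask | UpMask) & ~BitUpMask;

// The pattern as it appears in the input IR: one masked add per field,
// joined by an `or` whose operands share no bits.
uint8_t originalForm(uint8_t X, uint8_t Y) {
  uint8_t Lower = static_cast<uint8_t>((X + Y) & LoMask);
  uint8_t Upper = static_cast<uint8_t>(((X & UpMask) + Y) & UpMask);
  return static_cast<uint8_t>(Lower | Upper);
}

// The rewritten form: a single add over the low bits of every field, then
// an XOR that folds the highest bit of each field back in.
uint8_t optimizedForm(uint8_t X, uint8_t Y) {
  uint8_t Add = static_cast<uint8_t>((X & BitLoMask) + (Y & BitLoMask));
  return static_cast<uint8_t>(Add ^ ((X ^ Y) & BitUpMask));
}

// Counts the (X, Y) pairs on which the two forms disagree.
unsigned countMismatches() {
  unsigned Mismatches = 0;
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned Y = 0; Y < 256; ++Y)
      if (originalForm(static_cast<uint8_t>(X), static_cast<uint8_t>(Y)) !=
          optimizedForm(static_cast<uint8_t>(X), static_cast<uint8_t>(Y)))
        ++Mismatches;
  return Mismatches;
}

// Checks the constant example from the doc comment:
// ((x + 1) & 7) | (((x & 24) + 8) & 24)  versus  ((x & 11) + 9) ^ (x & 20).
bool constantExampleHolds() {
  for (unsigned V = 0; V < 256; ++V) {
    uint8_t X = static_cast<uint8_t>(V);
    uint8_t Original =
        static_cast<uint8_t>(((X + 1) & 7) | (((X & 24) + 8) & 24));
    uint8_t Optimized = static_cast<uint8_t>(((X & 11) + 9) ^ (X & 20));
    if (Original != Optimized)
      return false;
  }
  return true;
}
```

For these masks the derived values are `BitLoMask == 11` and `BitUpMask == 20`, which match the constants in the doc comment. The sum `(X & 11) + (Y & 11)` is at most 22, so it cannot wrap in i8, which is consistent with the patch building the add with `CreateNUWAdd`. This only exercises one mask pair; it does not say anything about the pattern-matching conditions themselves.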
