From 1db7e0cb6a13985b279058754b4cb11718e31681 Mon Sep 17 00:00:00 2001
From: Hanbum Park
Date: Sat, 6 Jan 2024 16:44:52 +0900
Subject: [PATCH] [InstCombine] Improve bitfield addition (#33874)

Teach InstCombine to fold the expansion of a packed-bitfield addition,
in which each field is added separately and the disjoint results are
or'ed back together, into a single addition over the packed value: both
operands are masked so the top bit of every field is dropped (the add
can then never carry across a field boundary), and an xor of the
original operands restores each field's top bit. The folded form is
matched again when it is or'ed with the add of a further field, so
wider packs keep collapsing.

Proof: https://alive2.llvm.org/ce/z/RUL3YU

Fixes #33874
---
 .../InstCombine/InstCombineAndOrXor.cpp      | 146 ++++++++++++++++++
 1 file changed, 146 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 5fd944a859ef09..eea133704a2e78 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3338,6 +3338,149 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
   return foldAndOrOfICmpsUsingRanges(LHS, RHS, IsAnd);
 }
 
+struct BitFieldAddBitMask {
+  const APInt *Lower;
+  const APInt *Upper;
+};
+struct BitFieldOptBitMask {
+  const APInt *Lower;
+  const APInt *Upper;
+  const APInt *New;
+};
+struct BitFieldAddInfo {
+  Value *X;
+  Value *Y;
+  bool opt;
+  union {
+    BitFieldAddBitMask AddMask;
+    BitFieldOptBitMask OptMask;
+  };
+};
+
+static Value *foldBitFieldArithmetic(BinaryOperator &I,
+                                     InstCombiner::BuilderTy &Builder) {
+  auto *Disjoint = dyn_cast<PossiblyDisjointInst>(&I);
+  if (!Disjoint || !Disjoint->isDisjoint())
+    return nullptr;
+
+  unsigned BitWidth = I.getType()->getScalarSizeInBits();
+  auto AccumulateY = [&](Value *LoY, Value *UpY, APInt LoMask,
+                         APInt UpMask) -> Value * {
+    Value *Y = nullptr;
+    auto CLoY = dyn_cast_or_null<Constant>(LoY);
+    auto CUpY = dyn_cast_or_null<Constant>(UpY);
+    if ((CLoY == nullptr) ^ (CUpY == nullptr))
+      return nullptr;
+
+    if (CLoY && CUpY) {
+      APInt IUpY = CUpY->getUniqueInteger();
+      APInt ILoY = CLoY->getUniqueInteger();
+      if (!(IUpY.isSubsetOf(UpMask) && ILoY.isSubsetOf(LoMask)))
+        return nullptr;
+      Y = ConstantInt::get(CLoY->getType(), ILoY + IUpY);
+    } else if (LoY == UpY) {
+      Y = LoY;
+    }
+
+    return Y;
+  };
+
+  auto MatchBitFieldAdd =
+      [&](BinaryOperator &I) -> std::optional<BitFieldAddInfo> {
+    const APInt *OptLoMask, *OptUpMask, *LoMask, *UpMask, *UpMask2 = nullptr;
+    Value *X, *Y, *UpY;
+    auto BitFieldAddUpper = m_CombineOr(
+        m_And(m_c_Add(m_And(m_Value(X), m_APInt(UpMask)), m_Value(UpY)),
+              m_APInt(UpMask2)),
+        m_c_Add(m_And(m_Value(X), m_APInt(UpMask)), m_Value(UpY)));
+    auto BitFieldAdd =
+        m_c_Or(BitFieldAddUpper,
+               m_And(m_c_Add(m_Deferred(X), m_Value(Y)), m_APInt(LoMask)));
+    auto BitFieldAddIC =
+        m_c_Or(m_And(m_c_Add(m_Value(X), m_Value(Y)), m_APInt(LoMask)),
+               m_And(m_c_Add(m_Deferred(X), m_Value(UpY)), m_APInt(UpMask)));
+    auto OptBitFieldAdd = m_c_Or(
+        m_c_Xor(m_CombineOr(
+                    m_c_Add(m_And(m_Value(X), m_APInt(OptLoMask)),
+                            m_And(m_Value(Y), m_APInt(OptLoMask))),
+                    m_c_Add(m_And(m_Value(X), m_APInt(OptLoMask)), m_Value(Y))),
+                m_CombineOr(m_And(m_Deferred(X), m_APInt(OptUpMask)),
+                            m_And(m_c_Xor(m_Deferred(X), m_Value(UpY)),
+                                  m_APInt(OptUpMask)))),
+        BitFieldAddUpper);
+
+    if (match(&I, BitFieldAdd) || match(&I, BitFieldAddIC)) {
+      APInt Mask = APInt::getBitsSet(
+          BitWidth, BitWidth - UpMask->countl_zero(), BitWidth);
+      if (!((UpMask2 == nullptr || *UpMask == *UpMask2) &&
+            (LoMask->popcount() >= 2 && UpMask->popcount() >= 2) &&
+            (LoMask->isShiftedMask() && UpMask->isShiftedMask()) &&
+            ((*LoMask & *UpMask) == 0) &&
+            ((Mask ^ *LoMask ^ *UpMask).isAllOnes())))
+        return std::nullopt;
+
+      if (!(Y = AccumulateY(Y, UpY, *LoMask, *UpMask)))
+        return std::nullopt;
+
+      return {{X, Y, false, {{LoMask, UpMask}}}};
+    }
+
+    if (match(&I, OptBitFieldAdd)) {
+      APInt Mask = APInt::getBitsSet(
+          BitWidth, BitWidth - OptUpMask->countl_zero(), BitWidth);
+      APInt Mask2 = APInt::getBitsSet(
+          BitWidth, BitWidth - UpMask->countl_zero(), BitWidth);
+      if (!((UpMask2 == nullptr || *UpMask == *UpMask2) &&
+            (UpMask->isShiftedMask() && UpMask->popcount() >= 2) &&
+            ((*UpMask & (*OptLoMask | *OptUpMask)) == 0) &&
+            ((~*OptLoMask ^ Mask) == *OptUpMask) &&
+            (Mask2 ^ *UpMask ^ (*OptLoMask ^ *OptUpMask)).isAllOnes()))
+        return std::nullopt;
+
+      if (!(Y = AccumulateY(Y, UpY, (*OptLoMask + *OptUpMask), *UpMask)))
+        return std::nullopt;
+
+      struct BitFieldAddInfo Info = {X, Y, true, {{OptLoMask, OptUpMask}}};
+      Info.OptMask.New = UpMask;
+      return {Info};
+    }
+
+    return std::nullopt;
+  };
+
+  auto Info = MatchBitFieldAdd(I);
+  if (Info) {
+    Value *X = Info->X;
+    Value *Y = Info->Y;
+    APInt BitLoMask, BitUpMask;
+    if (Info->opt) {
+      unsigned NewHiBit = BitWidth - (Info->OptMask.New->countl_zero() + 1);
+      BitLoMask = *Info->OptMask.Lower | *Info->OptMask.New;
+      BitLoMask.clearBit(NewHiBit);
+      BitUpMask = *Info->OptMask.Upper;
+      BitUpMask.setBit(NewHiBit);
+    } else {
+      unsigned LowerHiBit = BitWidth - (Info->AddMask.Lower->countl_zero() + 1);
+      unsigned UpperHiBit = BitWidth - (Info->AddMask.Upper->countl_zero() + 1);
+      BitLoMask = *Info->AddMask.Lower | *Info->AddMask.Upper;
+      BitLoMask.clearBit(LowerHiBit);
+      BitLoMask.clearBit(UpperHiBit);
+      BitUpMask = APInt::getOneBitSet(BitWidth, LowerHiBit);
+      BitUpMask.setBit(UpperHiBit);
+    }
+
+    auto AndXLower = Builder.CreateAnd(X, BitLoMask);
+    auto AndYLower = Builder.CreateAnd(Y, BitLoMask);
+    auto Add = Builder.CreateNUWAdd(AndXLower, AndYLower);
+    auto Xor1 = Builder.CreateXor(X, Y);
+    auto AndUpper = Builder.CreateAnd(Xor1, BitUpMask);
+    auto Xor = Builder.CreateXor(Add, AndUpper);
+    return Xor;
+  }
+
+  return nullptr;
+}
+
 // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
 // here. We should standardize that construct where it is needed or choose some
 // other way to ensure that commutated variants of patterns are not missed.
@@ -3911,6 +4054,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *Res = foldBitwiseLogicWithIntrinsics(I, Builder))
     return Res;
 
+  if (Value *Res = foldBitFieldArithmetic(I, Builder))
+    return replaceInstUsesWith(I, Res);
+
   return nullptr;
 }
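
Editor's illustration (not part of the patch or its tests): the C++ below replays
the transform by hand for one case the matcher accepts, two 4-bit fields packed
into a uint8_t, i.e. LoMask = 0x0F and UpMask = 0xF0, for which the fold derives
BitLoMask = 0x77 and BitUpMask = 0x88. The function names and the exhaustive
check are made up for this sketch.

  #include <cassert>
  #include <cstdint>

  // "Before" shape, as matched by BitFieldAdd: each field is added on its
  // own and the two disjoint results are or'ed back together.
  uint8_t addFieldsNaive(uint8_t X, uint8_t Y) {
    uint8_t Lo = (X + Y) & 0x0F;           // low field, modulo 16
    uint8_t Hi = ((X & 0xF0) + Y) & 0xF0;  // high field, modulo 16 (shifted)
    return Lo | Hi;                        // the results stay disjoint
  }

  // "After" shape emitted by the fold: masking with 0x77 drops the top bit of
  // each field, so the single add can never carry across a field boundary
  // (this is the "add nuw"); the xor with 0x88 rebuilds those top bits.
  uint8_t addFieldsFolded(uint8_t X, uint8_t Y) {
    uint8_t Sum = (X & 0x77) + (Y & 0x77); // carry stays inside each field
    uint8_t Top = (X ^ Y) & 0x88;          // carry-free top bit of each field
    return Sum ^ Top;
  }

  int main() {
    // Exhaustively compare both shapes over all 65,536 input pairs.
    for (unsigned X = 0; X < 256; ++X)
      for (unsigned Y = 0; Y < 256; ++Y)
        assert(addFieldsNaive(X, Y) == addFieldsFolded(X, Y));
    return 0;
  }

The loop only spot-checks this 8-bit instance; the Alive2 link in the commit
message covers the proof at the IR level.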