Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GlobalIsel] Combine G_ADD and G_SUB with constants #97771

Merged
merged 6 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,16 @@ class CombinerHelper {

bool matchCastOfSelect(const MachineInstr &Cast, const MachineInstr &SelectMI,
BuildFnTy &MatchInfo);
bool matchFoldAPlusC1MinusC2(const MachineInstr &MI, BuildFnTy &MatchInfo);

bool matchFoldC2MinusAPlusC1(const MachineInstr &MI, BuildFnTy &MatchInfo);

bool matchFoldAMinusC1MinusC2(const MachineInstr &MI, BuildFnTy &MatchInfo);

bool matchFoldC1Minus2MinusC2(const MachineInstr &MI, BuildFnTy &MatchInfo);

// fold ((A-C1)+C2) -> (A+(C2-C1))
bool matchFoldAMinusC1PlusC2(const MachineInstr &MI, BuildFnTy &MatchInfo);

private:
/// Checks for legality of an indexed variant of \p LdSt.
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "llvm/IR/DebugLoc.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Casting.h"

#include <cstdint>

namespace llvm {
Expand Down Expand Up @@ -178,6 +179,9 @@ std::optional<APInt> getIConstantVRegVal(Register VReg,
std::optional<int64_t> getIConstantVRegSExtVal(Register VReg,
const MachineRegisterInfo &MRI);

/// \p VReg is defined by a G_CONSTANT, return the corresponding value.
APInt getIConstantFromReg(Register VReg, const MachineRegisterInfo &MRI);

/// Simple struct used to hold a constant integer value and a virtual
/// register.
struct ValueAndVReg {
Expand Down
57 changes: 56 additions & 1 deletion llvm/include/llvm/Target/GlobalISel/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -1747,6 +1747,56 @@ def APlusBMinusCPlusA : GICombineRule<
(G_ADD $root, $A, $sub1)),
(apply (G_SUB $root, $B, $C))>;

// fold (A+C1)-C2 -> A+(C1-C2)
def APlusC1MinusC2: GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_CONSTANT $c2, $imm2),
(G_CONSTANT $c1, $imm1),
(G_ADD $add, $A, $c1),
(G_SUB $root, $add, $c2):$root,
[{ return Helper.matchFoldAPlusC1MinusC2(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

// fold C2-(A+C1) -> (C2-C1)-A
def C2MinusAPlusC1: GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_CONSTANT $c2, $imm2),
(G_CONSTANT $c1, $imm1),
(G_ADD $add, $A, $c1),
(G_SUB $root, $c2, $add):$root,
[{ return Helper.matchFoldC2MinusAPlusC1(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

// fold (A-C1)-C2 -> A-(C1+C2)
def AMinusC1MinusC2: GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_CONSTANT $c2, $imm2),
(G_CONSTANT $c1, $imm1),
(G_SUB $sub1, $A, $c1),
(G_SUB $root, $sub1, $c2):$root,
[{ return Helper.matchFoldAMinusC1MinusC2(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

// fold (C1-A)-C2 -> (C1-C2)-A
def C1Minus2MinusC2: GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_CONSTANT $c2, $imm2),
(G_CONSTANT $c1, $imm1),
(G_SUB $sub1, $c1, $A),
(G_SUB $root, $sub1, $c2):$root,
[{ return Helper.matchFoldC1Minus2MinusC2(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

// fold ((A-C1)+C2) -> (A+(C2-C1))
def AMinusC1PlusC2: GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_CONSTANT $c2, $imm2),
(G_CONSTANT $c1, $imm1),
(G_SUB $sub, $A, $c1),
(G_ADD $root, $sub, $c2):$root,
[{ return Helper.matchFoldAMinusC1PlusC2(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

def integer_reassoc_combines: GICombineGroup<[
ZeroMinusAPlusB,
APlusZeroMinusB,
Expand All @@ -1755,7 +1805,12 @@ def integer_reassoc_combines: GICombineGroup<[
AMinusBPlusCMinusA,
AMinusBPlusBMinusC,
APlusBMinusAplusC,
APlusBMinusCPlusA
APlusBMinusCPlusA,
APlusC1MinusC2,
C2MinusAPlusC1,
AMinusC1MinusC2,
C1Minus2MinusC2,
AMinusC1PlusC2
]>;

def freeze_of_non_undef_non_poison : GICombineRule<
Expand Down
115 changes: 115 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7433,3 +7433,118 @@ void CombinerHelper::applyExpandFPowI(MachineInstr &MI, int64_t Exponent) {
Builder.buildCopy(Dst, *Res);
MI.eraseFromParent();
}

bool CombinerHelper::matchFoldAPlusC1MinusC2(const MachineInstr &MI,
BuildFnTy &MatchInfo) {
// fold (A+C1)-C2 -> A+(C1-C2)
const GSub *Sub = cast<GSub>(&MI);
GAdd *Add = cast<GAdd>(MRI.getVRegDef(Sub->getLHSReg()));

if (!MRI.hasOneNonDBGUse(Add->getReg(0)))
return false;

APInt C2 = getIConstantFromReg(Sub->getRHSReg(), MRI);
APInt C1 = getIConstantFromReg(Add->getRHSReg(), MRI);

Register Dst = Sub->getReg(0);
LLT DstTy = MRI.getType(Dst);

MatchInfo = [=](MachineIRBuilder &B) {
auto Const = B.buildConstant(DstTy, C1 - C2);
B.buildAdd(Dst, Add->getLHSReg(), Const);
};

return true;
}

bool CombinerHelper::matchFoldC2MinusAPlusC1(const MachineInstr &MI,
BuildFnTy &MatchInfo) {
// fold C2-(A+C1) -> (C2-C1)-A
const GSub *Sub = cast<GSub>(&MI);
GAdd *Add = cast<GAdd>(MRI.getVRegDef(Sub->getRHSReg()));

if (!MRI.hasOneNonDBGUse(Add->getReg(0)))
return false;

APInt C2 = getIConstantFromReg(Sub->getLHSReg(), MRI);
APInt C1 = getIConstantFromReg(Add->getRHSReg(), MRI);

Register Dst = Sub->getReg(0);
LLT DstTy = MRI.getType(Dst);

MatchInfo = [=](MachineIRBuilder &B) {
auto Const = B.buildConstant(DstTy, C2 - C1);
B.buildSub(Dst, Const, Add->getLHSReg());
};

return true;
}

bool CombinerHelper::matchFoldAMinusC1MinusC2(const MachineInstr &MI,
BuildFnTy &MatchInfo) {
// fold (A-C1)-C2 -> A-(C1+C2)
const GSub *Sub1 = cast<GSub>(&MI);
GSub *Sub2 = cast<GSub>(MRI.getVRegDef(Sub1->getLHSReg()));

if (!MRI.hasOneNonDBGUse(Sub2->getReg(0)))
return false;

APInt C2 = getIConstantFromReg(Sub1->getRHSReg(), MRI);
APInt C1 = getIConstantFromReg(Sub2->getRHSReg(), MRI);

Register Dst = Sub1->getReg(0);
LLT DstTy = MRI.getType(Dst);

MatchInfo = [=](MachineIRBuilder &B) {
auto Const = B.buildConstant(DstTy, C1 + C2);
B.buildSub(Dst, Sub2->getLHSReg(), Const);
};

return true;
}

bool CombinerHelper::matchFoldC1Minus2MinusC2(const MachineInstr &MI,
BuildFnTy &MatchInfo) {
// fold (C1-A)-C2 -> (C1-C2)-A
const GSub *Sub1 = cast<GSub>(&MI);
GSub *Sub2 = cast<GSub>(MRI.getVRegDef(Sub1->getLHSReg()));

if (!MRI.hasOneNonDBGUse(Sub2->getReg(0)))
return false;

APInt C2 = getIConstantFromReg(Sub1->getRHSReg(), MRI);
APInt C1 = getIConstantFromReg(Sub2->getLHSReg(), MRI);

Register Dst = Sub1->getReg(0);
LLT DstTy = MRI.getType(Dst);

MatchInfo = [=](MachineIRBuilder &B) {
auto Const = B.buildConstant(DstTy, C1 - C2);
B.buildSub(Dst, Const, Sub2->getRHSReg());
};

return true;
}

bool CombinerHelper::matchFoldAMinusC1PlusC2(const MachineInstr &MI,
BuildFnTy &MatchInfo) {
// fold ((A-C1)+C2) -> (A+(C2-C1))
const GAdd *Add = cast<GAdd>(&MI);
GSub *Sub = cast<GSub>(MRI.getVRegDef(Add->getLHSReg()));

if (!MRI.hasOneNonDBGUse(Sub->getReg(0)))
return false;

APInt C2 = getIConstantFromReg(Add->getRHSReg(), MRI);
APInt C1 = getIConstantFromReg(Sub->getRHSReg(), MRI);

Register Dst = Add->getReg(0);
LLT DstTy = MRI.getType(Dst);

MatchInfo = [=](MachineIRBuilder &B) {
auto Const = B.buildConstant(DstTy, C2 - C1);
B.buildAdd(Dst, Sub->getLHSReg(), Const);
};

return true;
}
7 changes: 7 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,13 @@ std::optional<APInt> llvm::getIConstantVRegVal(Register VReg,
return ValAndVReg->Value;
}

APInt llvm::getIConstantFromReg(Register Reg, const MachineRegisterInfo &MRI) {
MachineInstr *Const = MRI.getVRegDef(Reg);
assert((Const && Const->getOpcode() == TargetOpcode::G_CONSTANT) &&
"expected a G_CONSTANT on Reg");
return Const->getOperand(1).getCImm()->getValue();
}

std::optional<int64_t>
llvm::getIConstantVRegSExtVal(Register VReg, const MachineRegisterInfo &MRI) {
std::optional<APInt> Val = getIConstantVRegVal(VReg, MRI);
Expand Down
Loading
Loading