Skip to content

Commit

Permalink
[GlobalIsel] Combine select of binops (llvm#76763)
Browse files Browse the repository at this point in the history
  • Loading branch information
tschuett authored Jan 6, 2024
1 parent 5b33cff commit 1687555
Show file tree
Hide file tree
Showing 4 changed files with 322 additions and 28 deletions.
3 changes: 3 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,9 @@ class CombinerHelper {

bool tryFoldSelectOfConstants(GSelect *Select, BuildFnTy &MatchInfo);

/// Try to fold select(cc, binop(), binop()) -> binop(select(), X)
bool tryFoldSelectOfBinOps(GSelect *Select, BuildFnTy &MatchInfo);

bool isOneOrOneSplat(Register Src, bool AllowUndefs);
bool isZeroOrZeroSplat(Register Src, bool AllowUndefs);
bool isConstantSplatVector(Register Src, int64_t SplatValue,
Expand Down
103 changes: 103 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,109 @@ class GVecReduce : public GenericMachineInstr {
}
};

/// Wrapper for a generic two-source operation, i.e., x = y op z.
///
/// The left-hand and right-hand sources are read from operands 1 and 2.
class GBinOp : public GenericMachineInstr {
public:
  /// Register of operand 1, the left-hand source.
  Register getLHSReg() const { return getReg(1); }
  /// Register of operand 2, the right-hand source.
  Register getRHSReg() const { return getReg(2); }

  /// Returns true when \p MI carries one of the integer, floating point, or
  /// logical binary opcodes enumerated below.
  static bool classof(const MachineInstr *MI) {
    const unsigned Opc = MI->getOpcode();
    switch (Opc) {
    default:
      return false;
    // Integer arithmetic and min/max.
    case TargetOpcode::G_ADD:
    case TargetOpcode::G_SUB:
    case TargetOpcode::G_MUL:
    case TargetOpcode::G_SDIV:
    case TargetOpcode::G_UDIV:
    case TargetOpcode::G_SREM:
    case TargetOpcode::G_UREM:
    case TargetOpcode::G_SMIN:
    case TargetOpcode::G_SMAX:
    case TargetOpcode::G_UMIN:
    case TargetOpcode::G_UMAX:
    // Floating point arithmetic and min/max.
    case TargetOpcode::G_FMINNUM:
    case TargetOpcode::G_FMAXNUM:
    case TargetOpcode::G_FMINNUM_IEEE:
    case TargetOpcode::G_FMAXNUM_IEEE:
    case TargetOpcode::G_FMINIMUM:
    case TargetOpcode::G_FMAXIMUM:
    case TargetOpcode::G_FADD:
    case TargetOpcode::G_FSUB:
    case TargetOpcode::G_FMUL:
    case TargetOpcode::G_FDIV:
    case TargetOpcode::G_FPOW:
    // Bitwise logic.
    case TargetOpcode::G_AND:
    case TargetOpcode::G_OR:
    case TargetOpcode::G_XOR:
      return true;
    }
  }
};

/// Wrapper for a generic integer binary operation.
class GIntBinOp : public GBinOp {
public:
  /// Returns true when \p MI carries one of the integer arithmetic or
  /// integer min/max opcodes enumerated below.
  static bool classof(const MachineInstr *MI) {
    const unsigned Opc = MI->getOpcode();
    switch (Opc) {
    default:
      return false;
    case TargetOpcode::G_ADD:
    case TargetOpcode::G_SUB:
    case TargetOpcode::G_MUL:
    case TargetOpcode::G_SDIV:
    case TargetOpcode::G_UDIV:
    case TargetOpcode::G_SREM:
    case TargetOpcode::G_UREM:
    case TargetOpcode::G_SMIN:
    case TargetOpcode::G_SMAX:
    case TargetOpcode::G_UMIN:
    case TargetOpcode::G_UMAX:
      return true;
    }
  }
};

/// Wrapper for a generic floating point binary operation.
class GFBinOp : public GBinOp {
public:
  /// Returns true when \p MI carries one of the floating point arithmetic or
  /// floating point min/max opcodes enumerated below.
  static bool classof(const MachineInstr *MI) {
    const unsigned Opc = MI->getOpcode();
    switch (Opc) {
    default:
      return false;
    case TargetOpcode::G_FMINNUM:
    case TargetOpcode::G_FMAXNUM:
    case TargetOpcode::G_FMINNUM_IEEE:
    case TargetOpcode::G_FMAXNUM_IEEE:
    case TargetOpcode::G_FMINIMUM:
    case TargetOpcode::G_FMAXIMUM:
    case TargetOpcode::G_FADD:
    case TargetOpcode::G_FSUB:
    case TargetOpcode::G_FMUL:
    case TargetOpcode::G_FDIV:
    case TargetOpcode::G_FPOW:
      return true;
    }
  }
};

/// Wrapper for a generic logical (bitwise) binary operation.
class GLogicalBinOp : public GBinOp {
public:
  /// Returns true when \p MI is a G_AND, G_OR, or G_XOR.
  static bool classof(const MachineInstr *MI) {
    const unsigned Opc = MI->getOpcode();
    switch (Opc) {
    default:
      return false;
    case TargetOpcode::G_AND:
    case TargetOpcode::G_OR:
    case TargetOpcode::G_XOR:
      return true;
    }
  }
};

} // namespace llvm

Expand Down
93 changes: 65 additions & 28 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6390,8 +6390,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isZero() && FalseValue.isOne()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Inner = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Inner, Cond);
auto Inner = B.buildNot(CondTy, Cond);
B.buildZExtOrTrunc(Dest, Inner);
};
return true;
Expand All @@ -6401,8 +6400,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isZero() && FalseValue.isAllOnes()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Inner = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Inner, Cond);
auto Inner = B.buildNot(CondTy, Cond);
B.buildSExtOrTrunc(Dest, Inner);
};
return true;
Expand All @@ -6412,8 +6410,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue - 1 == FalseValue) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Inner, Cond);
auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
B.buildAdd(Dest, Inner, False);
};
return true;
Expand All @@ -6423,8 +6420,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue + 1 == FalseValue) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildSExtOrTrunc(Inner, Cond);
auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
B.buildAdd(Dest, Inner, False);
};
return true;
Expand All @@ -6434,8 +6430,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Inner, Cond);
auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
// The shift amount must be scalar.
LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
Expand All @@ -6447,8 +6442,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isAllOnes()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildSExtOrTrunc(Inner, Cond);
auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
B.buildOr(Dest, Inner, False, Flags);
};
return true;
Expand All @@ -6458,10 +6452,8 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (FalseValue.isAllOnes()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Not = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Not, Cond);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildSExtOrTrunc(Inner, Not);
auto Not = B.buildNot(CondTy, Cond);
auto Inner = B.buildSExtOrTrunc(TrueTy, Not);
B.buildOr(Dest, Inner, True, Flags);
};
return true;
Expand Down Expand Up @@ -6496,8 +6488,7 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Ext, Cond);
auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
B.buildOr(DstReg, Ext, False, Flags);
};
return true;
Expand All @@ -6508,8 +6499,7 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Ext, Cond);
auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
B.buildAnd(DstReg, Ext, True);
};
return true;
Expand All @@ -6520,11 +6510,9 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
// First the not.
Register Inner = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Inner, Cond);
auto Inner = B.buildNot(CondTy, Cond);
// Then an ext to match the destination register.
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Ext, Inner);
auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
B.buildOr(DstReg, Ext, True, Flags);
};
return true;
Expand All @@ -6535,11 +6523,9 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
// First the not.
Register Inner = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Inner, Cond);
auto Inner = B.buildNot(CondTy, Cond);
// Then an ext to match the destination register.
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Ext, Inner);
auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
B.buildAnd(DstReg, Ext, False);
};
return true;
Expand All @@ -6548,6 +6534,54 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
return false;
}

/// Try to fold a select whose true and false values are defined by the same
/// kind of binary operation and share one operand:
///
///   select(cond, binop(x, y), binop(z, y)) --> binop(select(cond, x, z), y)
///   select(cond, binop(x, y), binop(x, z)) --> binop(x, select(cond, y, z))
///
/// \param Select    the select instruction to inspect.
/// \param MatchInfo set, on success, to a callback that builds the
///                  replacement select and binop; not written on failure.
/// \returns true if one of the two patterns matched.
bool CombinerHelper::tryFoldSelectOfBinOps(GSelect *Select,
                                           BuildFnTy &MatchInfo) {
  Register DstReg = Select->getReg(0);
  Register Cond = Select->getCondReg();
  Register False = Select->getFalseReg();
  Register True = Select->getTrueReg();
  LLT DstTy = MRI.getType(DstReg);

  GBinOp *LHS = getOpcodeDef<GBinOp>(True, MRI);
  GBinOp *RHS = getOpcodeDef<GBinOp>(False, MRI);

  // We need two binops of the same kind on the true/false registers.
  if (!LHS || !RHS || LHS->getOpcode() != RHS->getOpcode())
    return false;

  // NOTE(review): there is no one-use check on the two binops here. If they
  // have other users they stay alive and the fold adds instructions; consider
  // gating on that.

  // Note that there are no constraints on CondTy.
  //
  // The new binop may only carry flags that held for BOTH original binops.
  // The select's own flags must not be transferred: fast-math flags such as
  // ninf/nnan also constrain an instruction's arguments, and the new binop
  // and the new inner select see different arguments than the original
  // select did. E.g. select ninf c, (fdiv 1.0, inf), (fdiv 1.0, z) is 0.0
  // for a true c, whereas fdiv ninf 1.0, (select ninf c, inf, z) would be
  // poison. Dropping flags is always conservative.
  const unsigned Flags = LHS->getFlags() & RHS->getFlags();
  const unsigned Opcode = LHS->getOpcode();

  // Read the operands once so the build callbacks capture plain registers.
  const Register TrueLHS = LHS->getLHSReg();
  const Register TrueRHS = LHS->getRHSReg();
  const Register FalseLHS = RHS->getLHSReg();
  const Register FalseRHS = RHS->getRHSReg();

  // Fold select(cond, binop(x, y), binop(z, y))
  //  --> binop(select(cond, x, z), y)
  if (TrueRHS == FalseRHS) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      // The inner select is built without flags; see the comment above.
      auto Sel = B.buildSelect(DstTy, Cond, TrueLHS, FalseLHS);
      B.buildInstr(Opcode, {DstReg}, {Sel, TrueRHS}, Flags);
    };
    return true;
  }

  // Fold select(cond, binop(x, y), binop(x, z))
  //  --> binop(x, select(cond, y, z))
  if (TrueLHS == FalseLHS) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      // The inner select is built without flags; see the comment above.
      auto Sel = B.buildSelect(DstTy, Cond, TrueRHS, FalseRHS);
      B.buildInstr(Opcode, {DstReg}, {TrueLHS, Sel}, Flags);
    };
    return true;
  }

  // FIXME: use isCommutable().

  return false;
}

bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
GSelect *Select = cast<GSelect>(&MI);

Expand All @@ -6557,5 +6591,8 @@ bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
if (tryFoldBoolSelectToLogic(Select, MatchInfo))
return true;

if (tryFoldSelectOfBinOps(Select, MatchInfo))
return true;

return false;
}
Loading

0 comments on commit 1687555

Please sign in to comment.