Skip to content

Commit

Permalink
[X86] Convert shift+clamp -> avx2 shift folds to use SDPatternMatch::…
Browse files Browse the repository at this point in the history
…m_SetCC. NFC.
  • Loading branch information
RKSimon committed Jul 15, 2024
1 parent 054d7b1 commit c2580af
Showing 1 changed file with 21 additions and 29 deletions.
50 changes: 21 additions & 29 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46193,15 +46193,13 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
if (N->getOpcode() == ISD::VSELECT &&
(LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SHL) &&
supportedVectorVarShift(VT, Subtarget, LHS.getOpcode())) {
APInt SV;
using namespace llvm::SDPatternMatch;
// fold select(icmp_ult(amt,BW),shl(x,amt),0) -> avx2 psllv(x,amt)
// fold select(icmp_ult(amt,BW),srl(x,amt),0) -> avx2 psrlv(x,amt)
if (Cond.getOpcode() == ISD::SETCC &&
Cond.getOperand(0) == LHS.getOperand(1) &&
cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
SV == VT.getScalarSizeInBits()) {
if (ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
sd_match(Cond, m_SetCC(m_Specific(LHS.getOperand(1)),
m_SpecificInt(VT.getScalarSizeInBits()),
m_SpecificCondCode(ISD::SETULT)))) {
return DAG.getNode(LHS.getOpcode() == ISD::SRL ? X86ISD::VSRLV
: X86ISD::VSHLV,
DL, VT, LHS.getOperand(0), LHS.getOperand(1));
Expand Down Expand Up @@ -48020,10 +48018,12 @@ static SDValue combineShiftToPMULH(SDNode *N, SelectionDAG &DAG,

static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
using namespace llvm::SDPatternMatch;
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
EVT VT = N0.getValueType();
unsigned EltSizeInBits = VT.getScalarSizeInBits();
SDLoc DL(N);

// Exploits AVX2 VSHLV/VSRLV instructions for efficient unsigned vector shifts
Expand All @@ -48033,21 +48033,16 @@ static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG,
SDValue Cond = N0.getOperand(0);
SDValue N00 = N0.getOperand(1);
SDValue N01 = N0.getOperand(2);
APInt SV;
// fold shl(select(icmp_ult(amt,BW),x,0),amt) -> avx2 psllv(x,amt)
if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
SV == VT.getScalarSizeInBits()) {
if (ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
m_SpecificCondCode(ISD::SETULT)))) {
return DAG.getNode(X86ISD::VSHLV, DL, VT, N00, N1);
}
// fold shl(select(icmp_uge(amt,BW),0,x),amt) -> avx2 psllv(x,amt)
if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETUGE &&
ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
SV == VT.getScalarSizeInBits()) {
if (ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
m_SpecificCondCode(ISD::SETUGE)))) {
return DAG.getNode(X86ISD::VSHLV, DL, VT, N01, N1);
}
}
Expand Down Expand Up @@ -48160,9 +48155,11 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
using namespace llvm::SDPatternMatch;
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
EVT VT = N0.getValueType();
unsigned EltSizeInBits = VT.getScalarSizeInBits();
SDLoc DL(N);

if (SDValue V = combineShiftToPMULH(N, DAG, DL, Subtarget))
Expand All @@ -48175,21 +48172,16 @@ static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG,
SDValue Cond = N0.getOperand(0);
SDValue N00 = N0.getOperand(1);
SDValue N01 = N0.getOperand(2);
APInt SV;
// fold srl(select(icmp_ult(amt,BW),x,0),amt) -> avx2 psrlv(x,amt)
if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
SV == VT.getScalarSizeInBits()) {
if (ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
m_SpecificCondCode(ISD::SETULT)))) {
return DAG.getNode(X86ISD::VSRLV, DL, VT, N00, N1);
}
// fold srl(select(icmp_uge(amt,BW),0,x),amt) -> avx2 psrlv(x,amt)
if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETUGE &&
ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
SV == VT.getScalarSizeInBits()) {
if (ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
m_SpecificCondCode(ISD::SETUGE)))) {
return DAG.getNode(X86ISD::VSRLV, DL, VT, N01, N1);
}
}
Expand Down

0 comments on commit c2580af

Please sign in to comment.