From c2580afed7e55f13762d56400dc346f222ea5884 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim
Date: Mon, 15 Jul 2024 11:42:12 +0100
Subject: [PATCH] [X86] Convert shift+clamp -> avx2 shift folds to use
 SDPatternMatch::m_SetCC. NFC.

---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 50 +++++++++++--------------
 1 file changed, 21 insertions(+), 29 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a731541ca7778e..91a5526a82bbec 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -46193,15 +46193,13 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
   if (N->getOpcode() == ISD::VSELECT &&
       (LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SHL) &&
       supportedVectorVarShift(VT, Subtarget, LHS.getOpcode())) {
-    APInt SV;
+    using namespace llvm::SDPatternMatch;
     // fold select(icmp_ult(amt,BW),shl(x,amt),0) -> avx2 psllv(x,amt)
     // fold select(icmp_ult(amt,BW),srl(x,amt),0) -> avx2 psrlv(x,amt)
-    if (Cond.getOpcode() == ISD::SETCC &&
-        Cond.getOperand(0) == LHS.getOperand(1) &&
-        cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
-        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
-        ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
-        SV == VT.getScalarSizeInBits()) {
+    if (ISD::isConstantSplatVectorAllZeros(RHS.getNode()) &&
+        sd_match(Cond, m_SetCC(m_Specific(LHS.getOperand(1)),
+                               m_SpecificInt(VT.getScalarSizeInBits()),
+                               m_SpecificCondCode(ISD::SETULT)))) {
       return DAG.getNode(LHS.getOpcode() == ISD::SRL ? X86ISD::VSRLV
                                                      : X86ISD::VSHLV,
                          DL, VT, LHS.getOperand(0), LHS.getOperand(1));
@@ -48020,10 +48018,12 @@ static SDValue combineShiftToPMULH(SDNode *N, SelectionDAG &DAG,
 
 static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG,
                                 const X86Subtarget &Subtarget) {
+  using namespace llvm::SDPatternMatch;
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
   EVT VT = N0.getValueType();
+  unsigned EltSizeInBits = VT.getScalarSizeInBits();
   SDLoc DL(N);
 
   // Exploits AVX2 VSHLV/VSRLV instructions for efficient unsigned vector shifts
@@ -48033,21 +48033,16 @@ static SDValue combineShiftLeft(SDNode *N, SelectionDAG &DAG,
     SDValue Cond = N0.getOperand(0);
     SDValue N00 = N0.getOperand(1);
     SDValue N01 = N0.getOperand(2);
-    APInt SV;
     // fold shl(select(icmp_ult(amt,BW),x,0),amt) -> avx2 psllv(x,amt)
-    if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
-        cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
-        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
-        ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
-        SV == VT.getScalarSizeInBits()) {
+    if (ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
+        sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
+                               m_SpecificCondCode(ISD::SETULT)))) {
       return DAG.getNode(X86ISD::VSHLV, DL, VT, N00, N1);
     }
     // fold shl(select(icmp_uge(amt,BW),0,x),amt) -> avx2 psllv(x,amt)
-    if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
-        cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETUGE &&
-        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
-        ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
-        SV == VT.getScalarSizeInBits()) {
+    if (ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
+        sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
+                               m_SpecificCondCode(ISD::SETUGE)))) {
       return DAG.getNode(X86ISD::VSHLV, DL, VT, N01, N1);
     }
   }
@@ -48160,9 +48155,11 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
 
 static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG,
                                         TargetLowering::DAGCombinerInfo &DCI,
                                         const X86Subtarget &Subtarget) {
+  using namespace llvm::SDPatternMatch;
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   EVT VT = N0.getValueType();
+  unsigned EltSizeInBits = VT.getScalarSizeInBits();
   SDLoc DL(N);
 
   if (SDValue V = combineShiftToPMULH(N, DAG, DL, Subtarget))
@@ -48175,21 +48172,16 @@ static SDValue combineShiftRightLogical(SDNode *N, SelectionDAG &DAG,
     SDValue Cond = N0.getOperand(0);
     SDValue N00 = N0.getOperand(1);
     SDValue N01 = N0.getOperand(2);
-    APInt SV;
     // fold srl(select(icmp_ult(amt,BW),x,0),amt) -> avx2 psrlv(x,amt)
-    if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
-        cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETULT &&
-        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
-        ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
-        SV == VT.getScalarSizeInBits()) {
+    if (ISD::isConstantSplatVectorAllZeros(N01.getNode()) &&
+        sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
+                               m_SpecificCondCode(ISD::SETULT)))) {
       return DAG.getNode(X86ISD::VSRLV, DL, VT, N00, N1);
     }
     // fold srl(select(icmp_uge(amt,BW),0,x),amt) -> avx2 psrlv(x,amt)
-    if (Cond.getOpcode() == ISD::SETCC && Cond.getOperand(0) == N1 &&
-        cast<CondCodeSDNode>(Cond.getOperand(2))->get() == ISD::SETUGE &&
-        ISD::isConstantSplatVector(Cond.getOperand(1).getNode(), SV) &&
-        ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
-        SV == VT.getScalarSizeInBits()) {
+    if (ISD::isConstantSplatVectorAllZeros(N00.getNode()) &&
+        sd_match(Cond, m_SetCC(m_Specific(N1), m_SpecificInt(EltSizeInBits),
+                               m_SpecificCondCode(ISD::SETUGE)))) {
       return DAG.getNode(X86ISD::VSRLV, DL, VT, N01, N1);
     }
   }
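
For readers unfamiliar with SDPatternMatch, here is a minimal sketch of the
matcher idiom the hunks above standardize on. Cond, Amt and EltSizeInBits are
illustrative placeholders for the condition node, shift-amount operand and
element bit width used in each fold; only matchers already appearing in this
patch are used:

    using namespace llvm::SDPatternMatch;

    // True iff Cond is a setcc comparing Amt (unsigned less-than) against a
    // constant splat of the element bit width, i.e. the "amt < BW" clamp
    // guard that the manual opcode/operand/cond-code checks used to verify.
    bool IsUltClamp =
        sd_match(Cond, m_SetCC(m_Specific(Amt), m_SpecificInt(EltSizeInBits),
                               m_SpecificCondCode(ISD::SETULT)));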