From 88bd507dc2dd9c235b54d718cf84e4ef80d94bc9 Mon Sep 17 00:00:00 2001 From: Noah Goldstein Date: Mon, 9 Sep 2024 11:07:38 -0700 Subject: [PATCH] [X86] Handle shifts + and in `LowerSELECTWithCmpZero` shifts are the same as sub where rhs == 0 is identity. and is the inverted case where: `SELECT (AND(X,1) == 0), (AND Y, Z), Y` -> `(AND Y, (OR NEG(AND(X, 1)), Z))` With -1 as the identity. Closes #107910 --- llvm/lib/Target/X86/X86ISelLowering.cpp | 60 ++++++++++++++----- .../pull-conditional-binop-through-shift.ll | 55 +++++++---------- 2 files changed, 68 insertions(+), 47 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 99477d76f50bce..a1d466eee691c9 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -24086,36 +24086,38 @@ static SDValue LowerSELECTWithCmpZero(SDValue CmpVal, SDValue LHS, SDValue RHS, if (X86CC == X86::COND_E && CmpVal.getOpcode() == ISD::AND && isOneConstant(CmpVal.getOperand(1))) { - auto SplatLSB = [&]() { + auto SplatLSB = [&](EVT SplatVT) { // we need mask of all zeros or ones with same size of the other // operands. SDValue Neg = CmpVal; - if (CmpVT.bitsGT(VT)) - Neg = DAG.getNode(ISD::TRUNCATE, DL, VT, CmpVal); - else if (CmpVT.bitsLT(VT)) + if (CmpVT.bitsGT(SplatVT)) + Neg = DAG.getNode(ISD::TRUNCATE, DL, SplatVT, CmpVal); + else if (CmpVT.bitsLT(SplatVT)) Neg = DAG.getNode( - ISD::AND, DL, VT, - DAG.getNode(ISD::ANY_EXTEND, DL, VT, CmpVal.getOperand(0)), - DAG.getConstant(1, DL, VT)); - return DAG.getNegative(Neg, DL, VT); // -(and (x, 0x1)) + ISD::AND, DL, SplatVT, + DAG.getNode(ISD::ANY_EXTEND, DL, SplatVT, CmpVal.getOperand(0)), + DAG.getConstant(1, DL, SplatVT)); + return DAG.getNegative(Neg, DL, SplatVT); // -(and (x, 0x1)) }; // SELECT (AND(X,1) == 0), 0, -1 -> NEG(AND(X,1)) if (isNullConstant(LHS) && isAllOnesConstant(RHS)) - return SplatLSB(); + return SplatLSB(VT); // SELECT (AND(X,1) == 0), C1, C2 -> XOR(C1,AND(NEG(AND(X,1)),XOR(C1,C2)) if (!Subtarget.canUseCMOV() && isa(LHS) && isa(RHS)) { - SDValue Mask = SplatLSB(); + SDValue Mask = SplatLSB(VT); SDValue Diff = DAG.getNode(ISD::XOR, DL, VT, LHS, RHS); SDValue Flip = DAG.getNode(ISD::AND, DL, VT, Mask, Diff); return DAG.getNode(ISD::XOR, DL, VT, LHS, Flip); } SDValue Src1, Src2; - auto isIdentityPattern = [&]() { + auto isIdentityPatternZero = [&]() { switch (RHS.getOpcode()) { + default: + break; case ISD::OR: case ISD::XOR: case ISD::ADD: @@ -24125,6 +24127,9 @@ static SDValue LowerSELECTWithCmpZero(SDValue CmpVal, SDValue LHS, SDValue RHS, return true; } break; + case ISD::SHL: + case ISD::SRA: + case ISD::SRL: case ISD::SUB: if (RHS.getOperand(0) == LHS) { Src1 = RHS.getOperand(1); @@ -24136,15 +24141,40 @@ static SDValue LowerSELECTWithCmpZero(SDValue CmpVal, SDValue LHS, SDValue RHS, return false; }; + auto isIdentityPatternOnes = [&]() { + switch (LHS.getOpcode()) { + default: + break; + case ISD::AND: + if (LHS.getOperand(0) == RHS || LHS.getOperand(1) == RHS) { + Src1 = LHS.getOperand(LHS.getOperand(0) == RHS ? 1 : 0); + Src2 = RHS; + return true; + } + break; + } + return false; + }; + // Convert 'identity' patterns (iff X is 0 or 1): // SELECT (AND(X,1) == 0), Y, (OR Y, Z) -> (OR Y, (AND NEG(AND(X,1)), Z)) // SELECT (AND(X,1) == 0), Y, (XOR Y, Z) -> (XOR Y, (AND NEG(AND(X,1)), Z)) // SELECT (AND(X,1) == 0), Y, (ADD Y, Z) -> (ADD Y, (AND NEG(AND(X,1)), Z)) // SELECT (AND(X,1) == 0), Y, (SUB Y, Z) -> (SUB Y, (AND NEG(AND(X,1)), Z)) - if (!Subtarget.canUseCMOV() && isIdentityPattern()) { - SDValue Mask = SplatLSB(); - SDValue And = DAG.getNode(ISD::AND, DL, VT, Mask, Src1); // Mask & z - return DAG.getNode(RHS.getOpcode(), DL, VT, Src2, And); // y Op And + // SELECT (AND(X,1) == 0), Y, (SHL Y, Z) -> (SHL Y, (AND NEG(AND(X,1)), Z)) + // SELECT (AND(X,1) == 0), Y, (SRA Y, Z) -> (SRA Y, (AND NEG(AND(X,1)), Z)) + // SELECT (AND(X,1) == 0), Y, (SRL Y, Z) -> (SRL Y, (AND NEG(AND(X,1)), Z)) + if (!Subtarget.canUseCMOV() && isIdentityPatternZero()) { + SDValue Mask = SplatLSB(Src1.getValueType()); + SDValue And = DAG.getNode(ISD::AND, DL, Src1.getValueType(), Mask, + Src1); // Mask & z + return DAG.getNode(RHS.getOpcode(), DL, VT, Src2, And); // y Op And + } + // SELECT (AND(X,1) == 0), (AND Y, Z), Y -> (AND Y, (OR NEG(AND(X, 1)), Z)) + if (!Subtarget.canUseCMOV() && isIdentityPatternOnes()) { + SDValue Mask = SplatLSB(VT); + SDValue Or = DAG.getNode(ISD::OR, DL, VT, Mask, Src1); // Mask | z + return DAG.getNode(LHS.getOpcode(), DL, VT, Src2, Or); // y Op Or } } diff --git a/llvm/test/CodeGen/X86/pull-conditional-binop-through-shift.ll b/llvm/test/CodeGen/X86/pull-conditional-binop-through-shift.ll index 59effaf11d4e7b..8c858e04de2a14 100644 --- a/llvm/test/CodeGen/X86/pull-conditional-binop-through-shift.ll +++ b/llvm/test/CodeGen/X86/pull-conditional-binop-through-shift.ll @@ -711,15 +711,15 @@ define i32 @shl_signbit_select_add(i32 %x, i1 %cond, ptr %dst) { ; ; X86-LABEL: shl_signbit_select_add: ; X86: # %bb.0: -; X86-NEXT: movl {{[0-9]+}}(%esp), %ecx +; X86-NEXT: movzbl {{[0-9]+}}(%esp), %ecx +; X86-NEXT: andb $1, %cl +; X86-NEXT: movl {{[0-9]+}}(%esp), %edx ; X86-NEXT: movl {{[0-9]+}}(%esp), %eax -; X86-NEXT: testb $1, {{[0-9]+}}(%esp) -; X86-NEXT: je .LBB24_2 -; X86-NEXT: # %bb.1: -; X86-NEXT: shll $4, %eax -; X86-NEXT: .LBB24_2: +; X86-NEXT: negb %cl +; X86-NEXT: andb $4, %cl +; X86-NEXT: shll %cl, %eax ; X86-NEXT: addl $123456, %eax # imm = 0x1E240 -; X86-NEXT: movl %eax, (%ecx) +; X86-NEXT: movl %eax, (%edx) ; X86-NEXT: retl %t0 = shl i32 %x, 4 %t1 = select i1 %cond, i32 %t0, i32 %x @@ -772,23 +772,15 @@ define i32 @lshr_signbit_select_add(i32 %x, i1 %cond, ptr %dst, i32 %y) { ; ; X86-LABEL: lshr_signbit_select_add: ; X86: # %bb.0: -; X86-NEXT: pushl %esi -; X86-NEXT: .cfi_def_cfa_offset 8 -; X86-NEXT: .cfi_offset %esi, -8 -; X86-NEXT: movl {{[0-9]+}}(%esp), %edx ; X86-NEXT: movzbl {{[0-9]+}}(%esp), %ecx -; X86-NEXT: movl {{[0-9]+}}(%esp), %esi -; X86-NEXT: movl %esi, %eax +; X86-NEXT: andb $1, %cl +; X86-NEXT: movl {{[0-9]+}}(%esp), %edx +; X86-NEXT: movl {{[0-9]+}}(%esp), %eax +; X86-NEXT: negb %cl +; X86-NEXT: andb {{[0-9]+}}(%esp), %cl ; X86-NEXT: shrl %cl, %eax -; X86-NEXT: testb $1, {{[0-9]+}}(%esp) -; X86-NEXT: jne .LBB26_2 -; X86-NEXT: # %bb.1: -; X86-NEXT: movl %esi, %eax -; X86-NEXT: .LBB26_2: ; X86-NEXT: addl $123456, %eax # imm = 0x1E240 ; X86-NEXT: movl %eax, (%edx) -; X86-NEXT: popl %esi -; X86-NEXT: .cfi_def_cfa_offset 4 ; X86-NEXT: retl %t0 = lshr i32 %x, %y %t1 = select i1 %cond, i32 %t0, i32 %x @@ -810,15 +802,15 @@ define i32 @ashr_signbit_select_add(i32 %x, i1 %cond, ptr %dst) { ; ; X86-LABEL: ashr_signbit_select_add: ; X86: # %bb.0: -; X86-NEXT: movl {{[0-9]+}}(%esp), %ecx +; X86-NEXT: movzbl {{[0-9]+}}(%esp), %ecx +; X86-NEXT: andb $1, %cl +; X86-NEXT: movl {{[0-9]+}}(%esp), %edx ; X86-NEXT: movl {{[0-9]+}}(%esp), %eax -; X86-NEXT: testb $1, {{[0-9]+}}(%esp) -; X86-NEXT: je .LBB27_2 -; X86-NEXT: # %bb.1: -; X86-NEXT: sarl $4, %eax -; X86-NEXT: .LBB27_2: +; X86-NEXT: negb %cl +; X86-NEXT: andb $4, %cl +; X86-NEXT: sarl %cl, %eax ; X86-NEXT: addl $123456, %eax # imm = 0x1E240 -; X86-NEXT: movl %eax, (%ecx) +; X86-NEXT: movl %eax, (%edx) ; X86-NEXT: retl %t0 = ashr i32 %x, 4 %t1 = select i1 %cond, i32 %t0, i32 %x @@ -841,12 +833,11 @@ define i32 @and_signbit_select_add(i32 %x, i1 %cond, ptr %dst, i32 %y) { ; X86-LABEL: and_signbit_select_add: ; X86: # %bb.0: ; X86-NEXT: movl {{[0-9]+}}(%esp), %ecx -; X86-NEXT: movl {{[0-9]+}}(%esp), %eax -; X86-NEXT: testb $1, {{[0-9]+}}(%esp) -; X86-NEXT: jne .LBB28_2 -; X86-NEXT: # %bb.1: +; X86-NEXT: movzbl {{[0-9]+}}(%esp), %eax +; X86-NEXT: andl $1, %eax +; X86-NEXT: negl %eax +; X86-NEXT: orl {{[0-9]+}}(%esp), %eax ; X86-NEXT: andl {{[0-9]+}}(%esp), %eax -; X86-NEXT: .LBB28_2: ; X86-NEXT: addl $123456, %eax # imm = 0x1E240 ; X86-NEXT: movl %eax, (%ecx) ; X86-NEXT: retl