Skip to content

Commit

Permalink
[SelectionDAG][RISCV] Fix break of vnsrl pattern in issue #94265 (#95563)
Browse files Browse the repository at this point in the history

Added a RISCV overload of `isTruncateFree` to fix the break of vnsrl described in issue #94265.

Fixes #94265
  • Loading branch information
Fros1er authored Jul 14, 2024
1 parent c28ddf9 commit c8dc21d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 1 deletion.
12 changes: 11 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2586,6 +2586,17 @@ bool TargetLowering::SimplifyDemandedBits(
break;

if (Src.getNode()->hasOneUse()) {
if (isTruncateFree(Src, VT) &&
!isTruncateFree(Src.getValueType(), VT)) {
// If the truncate is only free in the trunc(srl) form, do not turn it
// into srl(trunc). The check first tests that the truncate is free for
// Src's opcode (srl), then that the truncate is not free for the types
// alone, i.e. it is not done by just referencing a sub-register. In
// testing, when the trunc is free in both trunc(srl) and srl(trunc),
// srl(trunc) performs better; when only trunc(srl)'s trunc is free,
// trunc(srl) is better.
break;
}

std::optional<uint64_t> ShAmtC =
TLO.DAG.getValidShiftAmount(Src, DemandedElts, Depth + 2);
if (!ShAmtC || *ShAmtC >= BitWidth)
Expand All @@ -2596,7 +2607,6 @@ bool TargetLowering::SimplifyDemandedBits(
APInt::getHighBitsSet(OperandBitWidth, OperandBitWidth - BitWidth);
HighBits.lshrInPlace(ShVal);
HighBits = HighBits.trunc(BitWidth);

if (!(HighBits & DemandedBits)) {
// None of the shifted in bits are needed. Add a truncate of the
// shift input, then shift it.
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1894,6 +1894,21 @@ bool RISCVTargetLowering::isTruncateFree(EVT SrcVT, EVT DstVT) const {
return (SrcBits == 64 && DestBits == 32);
}

bool RISCVTargetLowering::isTruncateFree(SDValue Val, EVT VT2) const {
  // A vector right shift whose result is truncated to half the element
  // width is reported as free: the truncate comes from vnsrl / vnsra.
  const EVT ValVT = Val.getValueType();
  const unsigned Opc = Val.getOpcode();
  const bool IsRightShift = Opc == ISD::SRL || Opc == ISD::SRA;
  const bool IsNarrowingCandidate = Subtarget.hasStdExtV() && IsRightShift &&
                                    ValVT.isVector() && VT2.isVector();
  if (IsNarrowingCandidate) {
    const unsigned WideEltBits = ValVT.getVectorElementType().getSizeInBits();
    const unsigned NarrowEltBits = VT2.getVectorElementType().getSizeInBits();
    // Only an exact 2x element-width reduction qualifies.
    if (WideEltBits == 2 * NarrowEltBits)
      return true;
  }
  // Everything else is decided by the generic implementation.
  return TargetLowering::isTruncateFree(Val, VT2);
}

bool RISCVTargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
// Zexts are free if they can be combined with a load.
// Don't advertise i32->i64 zextload as being free for RV64. It interacts
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ class RISCVTargetLowering : public TargetLowering {
bool isLegalAddImmediate(int64_t Imm) const override;
bool isTruncateFree(Type *SrcTy, Type *DstTy) const override;
bool isTruncateFree(EVT SrcVT, EVT DstVT) const override;
bool isTruncateFree(SDValue Val, EVT VT2) const override;
bool isZExtFree(SDValue Val, EVT VT2) const override;
bool isSExtCheaperThanZExt(EVT SrcVT, EVT DstVT) const override;
bool signExtendConstant(const ConstantInt *CI) const override;
Expand Down
31 changes: 31 additions & 0 deletions llvm/test/CodeGen/RISCV/pr94265.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=riscv32-- -mattr=+v | FileCheck -check-prefix=RV32I %s
; RUN: llc < %s -mtriple=riscv64-- -mattr=+v | FileCheck -check-prefix=RV64I %s

; Regression test for issue #94265.
; The IR divides each i32 lane by 64 (sdiv), truncates the result to i16 and
; then shifts left by 10. The expected assembly below performs the final
; right shift by 6 together with the i32 -> i16 narrowing in a single
; vnsrl.wi instruction on both the riscv32 and riscv64 RUN lines.
define <8 x i16> @PR94265(<8 x i32> %a0) #0 {
; RV32I-LABEL: PR94265:
; RV32I: # %bb.0:
; RV32I-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; RV32I-NEXT: vsra.vi v10, v8, 31
; RV32I-NEXT: vsrl.vi v10, v10, 26
; RV32I-NEXT: vadd.vv v8, v8, v10
; RV32I-NEXT: vsetvli zero, zero, e16, m1, ta, ma
; RV32I-NEXT: vnsrl.wi v10, v8, 6
; RV32I-NEXT: vsll.vi v8, v10, 10
; RV32I-NEXT: ret
;
; RV64I-LABEL: PR94265:
; RV64I: # %bb.0:
; RV64I-NEXT: vsetivli zero, 8, e32, m2, ta, ma
; RV64I-NEXT: vsra.vi v10, v8, 31
; RV64I-NEXT: vsrl.vi v10, v10, 26
; RV64I-NEXT: vadd.vv v8, v8, v10
; RV64I-NEXT: vsetvli zero, zero, e16, m1, ta, ma
; RV64I-NEXT: vnsrl.wi v10, v8, 6
; RV64I-NEXT: vsll.vi v8, v10, 10
; RV64I-NEXT: ret
%t1 = sdiv <8 x i32> %a0, <i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64, i32 64>
%t2 = trunc <8 x i32> %t1 to <8 x i16>
%t3 = shl <8 x i16> %t2, <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>
ret <8 x i16> %t3
}

0 comments on commit c8dc21d

Please sign in to comment.