Skip to content

Commit

Permalink
[instcombine] Extend logical reduction canonicalization to scalable v…
Browse files Browse the repository at this point in the history
…ectors (#99366)

Summary:
These transformations do not depend on the type being fixed in size, so
enable them for scalable vectors too. Unlike for fixed vectors, these
are only a canonicalization - the bitcast lowering for and/or/add is not
legal on a scalable vector type.

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60250914
  • Loading branch information
preames authored and yuxuanchen1997 committed Jul 25, 2024
1 parent 740cfd1 commit 9c50606
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
16 changes: 8 additions & 8 deletions llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3430,8 +3430,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}

if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
if (auto *VTy = dyn_cast<VectorType>(Vect->getType()))
if (VTy->getElementType() == Builder.getInt1Ty()) {
Value *Res = Builder.CreateAddReduce(Vect);
if (Arg != Vect)
Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res,
Expand Down Expand Up @@ -3460,8 +3460,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}

if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
if (auto *VTy = dyn_cast<VectorType>(Vect->getType()))
if (VTy->getElementType() == Builder.getInt1Ty()) {
Value *Res = Builder.CreateAndReduce(Vect);
if (Res->getType() != II->getType())
Res = Builder.CreateZExt(Res, II->getType());
Expand Down Expand Up @@ -3491,8 +3491,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}

if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
if (auto *VTy = dyn_cast<VectorType>(Vect->getType()))
if (VTy->getElementType() == Builder.getInt1Ty()) {
Value *Res = IID == Intrinsic::vector_reduce_umin
? Builder.CreateAndReduce(Vect)
: Builder.CreateOrReduce(Vect);
Expand Down Expand Up @@ -3533,8 +3533,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}

if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
if (FTy->getElementType() == Builder.getInt1Ty()) {
if (auto *VTy = dyn_cast<VectorType>(Vect->getType()))
if (VTy->getElementType() == Builder.getInt1Ty()) {
Instruction::CastOps ExtOpc = Instruction::CastOps::CastOpsEnd;
if (Arg != Vect)
ExtOpc = cast<CastInst>(Arg)->getOpcode();
Expand Down
14 changes: 7 additions & 7 deletions llvm/test/Transforms/InstCombine/vector-logical-reductions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ define i1 @reduction_logical_mul(<2 x i1> %x) {

define i1 @reduction_logical_mul_nxv2i1(<vscale x 2 x i1> %x) {
; CHECK-LABEL: @reduction_logical_mul_nxv2i1(
; CHECK-NEXT: [[R:%.*]] = call i1 @llvm.vector.reduce.mul.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
; CHECK-NEXT: [[R:%.*]] = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
; CHECK-NEXT: ret i1 [[R]]
;
%r = call i1 @llvm.vector.reduce.mul.nxv2i1(<vscale x 2 x i1> %x)
Expand All @@ -71,7 +71,7 @@ define i1 @reduction_logical_xor(<2 x i1> %x) {

define i1 @reduction_logical_xor_nxv2i1(<vscale x 2 x i1> %x) {
; CHECK-LABEL: @reduction_logical_xor_nxv2i1(
; CHECK-NEXT: [[R:%.*]] = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
; CHECK-NEXT: [[R:%.*]] = call i1 @llvm.vector.reduce.add.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
; CHECK-NEXT: ret i1 [[R]]
;
%r = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> %x)
Expand All @@ -90,7 +90,7 @@ define i1 @reduction_logical_smin(<2 x i1> %x) {

define i1 @reduction_logical_smin_nxv2i1(<vscale x 2 x i1> %x) {
; CHECK-LABEL: @reduction_logical_smin_nxv2i1(
; CHECK-NEXT: [[R:%.*]] = call i1 @llvm.vector.reduce.smin.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
; CHECK-NEXT: [[R:%.*]] = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
; CHECK-NEXT: ret i1 [[R]]
;
%r = call i1 @llvm.vector.reduce.smin.nxv2i1(<vscale x 2 x i1> %x)
Expand All @@ -109,7 +109,7 @@ define i1 @reduction_logical_smax(<2 x i1> %x) {

define i1 @reduction_logical_smax_nxv2i1(<vscale x 2 x i1> %x) {
; CHECK-LABEL: @reduction_logical_smax_nxv2i1(
; CHECK-NEXT: [[R:%.*]] = call i1 @llvm.vector.reduce.smax.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
; CHECK-NEXT: [[R:%.*]] = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
; CHECK-NEXT: ret i1 [[R]]
;
%r = call i1 @llvm.vector.reduce.smax.nxv2i1(<vscale x 2 x i1> %x)
Expand All @@ -128,7 +128,7 @@ define i1 @reduction_logical_umin(<2 x i1> %x) {

define i1 @reduction_logical_umin_nxv2i1(<vscale x 2 x i1> %x) {
; CHECK-LABEL: @reduction_logical_umin_nxv2i1(
; CHECK-NEXT: [[R:%.*]] = call i1 @llvm.vector.reduce.umin.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
; CHECK-NEXT: [[R:%.*]] = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
; CHECK-NEXT: ret i1 [[R]]
;
%r = call i1 @llvm.vector.reduce.umin.nxv2i1(<vscale x 2 x i1> %x)
Expand All @@ -147,7 +147,7 @@ define i1 @reduction_logical_umax(<2 x i1> %x) {

define i1 @reduction_logical_umax_nxv2i1(<vscale x 2 x i1> %x) {
; CHECK-LABEL: @reduction_logical_umax_nxv2i1(
; CHECK-NEXT: [[R:%.*]] = call i1 @llvm.vector.reduce.umax.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
; CHECK-NEXT: [[R:%.*]] = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> [[X:%.*]])
; CHECK-NEXT: ret i1 [[R]]
;
%r = call i1 @llvm.vector.reduce.umax.nxv2i1(<vscale x 2 x i1> %x)
Expand Down Expand Up @@ -199,7 +199,7 @@ define i1 @reduction_logical_and_reverse_v2i1(<2 x i1> %p) {

define i1 @reduction_logical_xor_reverse_nxv2i1(<vscale x 2 x i1> %p) {
; CHECK-LABEL: @reduction_logical_xor_reverse_nxv2i1(
; CHECK-NEXT: [[RED:%.*]] = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
; CHECK-NEXT: [[RED:%.*]] = call i1 @llvm.vector.reduce.add.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
; CHECK-NEXT: ret i1 [[RED]]
;
%rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
Expand Down

0 comments on commit 9c50606

Please sign in to comment.