From 33af112f99fe956fb93fb2b797a141ee93956283 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timm=20B=C3=A4der?= Date: Sun, 14 Jul 2024 10:40:29 +0200 Subject: [PATCH 1/5] [clang][Interp] Fix modifying const objects in functions calls in ctors The current frame might not be a constructor for the object we're initializing, but a parent frame might. --- clang/lib/AST/Interp/Interp.cpp | 14 ++++++++++---- clang/test/AST/Interp/records.cpp | 25 +++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/clang/lib/AST/Interp/Interp.cpp b/clang/lib/AST/Interp/Interp.cpp index 0411fcad88ad0a..70a470021e7f22 100644 --- a/clang/lib/AST/Interp/Interp.cpp +++ b/clang/lib/AST/Interp/Interp.cpp @@ -405,10 +405,16 @@ bool CheckConst(InterpState &S, CodePtr OpPC, const Pointer &Ptr) { // The This pointer is writable in constructors and destructors, // even if isConst() returns true. - if (const Function *Func = S.Current->getFunction(); - Func && (Func->isConstructor() || Func->isDestructor()) && - Ptr.block() == S.Current->getThis().block()) { - return true; + // TODO(perf): We could be hitting this code path quite a lot in complex + // constructors. Is there a better way to do this? + if (S.Current->getFunction()) { + for (const InterpFrame *Frame = S.Current; Frame; Frame = Frame->Caller) { + if (const Function *Func = Frame->getFunction(); + Func && (Func->isConstructor() || Func->isDestructor()) && + Ptr.block() == Frame->getThis().block()) { + return true; + } + } } if (!Ptr.isBlockPointer()) diff --git a/clang/test/AST/Interp/records.cpp b/clang/test/AST/Interp/records.cpp index 4b06fc7522d45c..2fc88a0b1df6a0 100644 --- a/clang/test/AST/Interp/records.cpp +++ b/clang/test/AST/Interp/records.cpp @@ -1512,3 +1512,28 @@ namespace OnePastEndAndBack { constexpr const Base *d = c - 1; static_assert(d == &a, ""); } + +namespace BitSet { + class Bitset { + unsigned Bit = 0; + + public: + constexpr Bitset() { + int Init[2] = {1,2}; + for (auto I : Init) + set(I); + } + constexpr void set(unsigned I) { + this->Bit++; + this->Bit = 1u << 1; + } + }; + + struct ArchInfo { + Bitset DefaultExts; + }; + + constexpr ArchInfo ARMV8A = { + Bitset() + }; +} From 61a4e1e70f07c89bd890ef2bc61a818e6a321d2d Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Sun, 14 Jul 2024 17:18:43 +0100 Subject: [PATCH 2/5] [DAG] Add SDPatternMatch::m_SetCC and update some combines to use it (#98646) The plan is to add more TernaryOp in the future (SELECT/VSELECT and FMA in particular) --- llvm/include/llvm/CodeGen/SDPatternMatch.h | 43 ++++++++++++++++++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 45 ++++++------------- .../CodeGen/SelectionDAGPatternMatchTest.cpp | 35 +++++++++++++++ 3 files changed, 92 insertions(+), 31 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index f39fbd95b3beb7..07204d1f48c242 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -447,6 +447,49 @@ template <> struct EffectiveOperands { explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {} }; +// === Ternary operations === +template +struct TernaryOpc_match { + unsigned Opcode; + T0_P Op0; + T1_P Op1; + T2_P Op2; + + TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1, + const T2_P &Op2) + : Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {} + + template + bool match(const MatchContext &Ctx, SDValue N) { + if (sd_context_match(N, Ctx, m_Opc(Opcode))) { + EffectiveOperands EO(N); + assert(EO.Size == 3); + return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) && + Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) || + (Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) && + Op1.match(Ctx, N->getOperand(EO.FirstIndex)))) && + Op2.match(Ctx, N->getOperand(EO.FirstIndex + 2)); + } + + return false; + } +}; + +template +inline TernaryOpc_match +m_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) { + return TernaryOpc_match(ISD::SETCC, Op0, Op1, + Op2); +} + +template +inline TernaryOpc_match +m_c_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) { + return TernaryOpc_match(ISD::SETCC, Op0, Op1, + Op2); +} + // === Binary operations === template diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index cece76f6583077..2f1bcc9bed88b7 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -2300,24 +2300,12 @@ static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op, return true; } - if (N.getOpcode() != ISD::SETCC || - N.getValueType().getScalarType() != MVT::i1 || - cast(N.getOperand(2))->get() != ISD::SETNE) - return false; - - SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); - assert(Op0.getValueType() == Op1.getValueType()); - - if (isNullOrNullSplat(Op0)) - Op = Op1; - else if (isNullOrNullSplat(Op1)) - Op = Op0; - else + if (N.getValueType().getScalarType() != MVT::i1 || + !sd_match( + N, m_c_SetCC(m_Value(Op), m_Zero(), m_SpecificCondCode(ISD::SETNE)))) return false; Known = DAG.computeKnownBits(Op); - return (Known.Zero | 1).isAllOnes(); } @@ -2544,16 +2532,12 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL, return SDValue(); // Match the zext operand as a setcc of a boolean. - if (Z.getOperand(0).getOpcode() != ISD::SETCC || - Z.getOperand(0).getValueType() != MVT::i1) + if (Z.getOperand(0).getValueType() != MVT::i1) return SDValue(); // Match the compare as: setcc (X & 1), 0, eq. - SDValue SetCC = Z.getOperand(0); - ISD::CondCode CC = cast(SetCC->getOperand(2))->get(); - if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) || - SetCC.getOperand(0).getOpcode() != ISD::AND || - !isOneConstant(SetCC.getOperand(0).getOperand(1))) + if (!sd_match(Z.getOperand(0), m_SetCC(m_And(m_Value(), m_One()), m_Zero(), + m_SpecificCondCode(ISD::SETEQ)))) return SDValue(); // We are adding/subtracting a constant and an inverted low bit. Turn that @@ -2561,9 +2545,9 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL, // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1)) // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1)) EVT VT = C.getValueType(); - SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT); - SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) : - DAG.getConstant(CN->getAPIntValue() - 1, DL, VT); + SDValue LowBit = DAG.getZExtOrTrunc(Z.getOperand(0).getOperand(0), DL, VT); + SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) + : DAG.getConstant(CN->getAPIntValue() - 1, DL, VT); return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit); } @@ -11554,13 +11538,12 @@ static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) { SDValue N1 = N->getOperand(1); SDValue N2 = N->getOperand(2); EVT VT = N->getValueType(0); - if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse()) - return SDValue(); - SDValue Cond0 = N0.getOperand(0); - SDValue Cond1 = N0.getOperand(1); - ISD::CondCode CC = cast(N0.getOperand(2))->get(); - if (VT != Cond0.getValueType()) + SDValue Cond0, Cond1; + ISD::CondCode CC; + if (!sd_match(N0, m_OneUse(m_SetCC(m_Value(Cond0), m_Value(Cond1), + m_CondCode(CC)))) || + VT != Cond0.getValueType()) return SDValue(); // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index 46c385a0bc050e..a3d5e5f94b6109 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -119,6 +119,41 @@ TEST_F(SelectionDAGPatternMatchTest, matchValueType) { EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT())); } +TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) { + SDLoc DL; + auto Int32VT = EVT::getIntegerVT(Context, 32); + + SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); + SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT); + + SDValue ICMP_UGT = DAG->getSetCC(DL, MVT::i1, Op0, Op1, ISD::SETUGT); + SDValue ICMP_EQ01 = DAG->getSetCC(DL, MVT::i1, Op0, Op1, ISD::SETEQ); + SDValue ICMP_EQ10 = DAG->getSetCC(DL, MVT::i1, Op1, Op0, ISD::SETEQ); + + using namespace SDPatternMatch; + ISD::CondCode CC; + EXPECT_TRUE(sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(), + m_SpecificCondCode(ISD::SETUGT)))); + EXPECT_TRUE( + sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(), m_CondCode(CC)))); + EXPECT_TRUE(CC == ISD::SETUGT); + EXPECT_FALSE(sd_match( + ICMP_UGT, m_SetCC(m_Value(), m_Value(), m_SpecificCondCode(ISD::SETLE)))); + + EXPECT_TRUE(sd_match(ICMP_EQ01, m_SetCC(m_Specific(Op0), m_Specific(Op1), + m_SpecificCondCode(ISD::SETEQ)))); + EXPECT_TRUE(sd_match(ICMP_EQ10, m_SetCC(m_Specific(Op1), m_Specific(Op0), + m_SpecificCondCode(ISD::SETEQ)))); + EXPECT_FALSE(sd_match(ICMP_EQ01, m_SetCC(m_Specific(Op1), m_Specific(Op0), + m_SpecificCondCode(ISD::SETEQ)))); + EXPECT_FALSE(sd_match(ICMP_EQ10, m_SetCC(m_Specific(Op0), m_Specific(Op1), + m_SpecificCondCode(ISD::SETEQ)))); + EXPECT_TRUE(sd_match(ICMP_EQ01, m_c_SetCC(m_Specific(Op1), m_Specific(Op0), + m_SpecificCondCode(ISD::SETEQ)))); + EXPECT_TRUE(sd_match(ICMP_EQ10, m_c_SetCC(m_Specific(Op0), m_Specific(Op1), + m_SpecificCondCode(ISD::SETEQ)))); +} + TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) { SDLoc DL; auto Int32VT = EVT::getIntegerVT(Context, 32); From 3aae4caffa3134d4edd1811fd2c35cbc95eb7441 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timm=20B=C3=A4der?= Date: Sun, 14 Jul 2024 17:45:57 +0200 Subject: [PATCH 3/5] [clang][Interp] Improve InterpFrame::describe() Use getNameForDiagnostic(), like the CallStackFrame of the current interpreter. --- clang/lib/AST/Interp/InterpFrame.cpp | 5 ++++- clang/test/AST/Interp/literals.cpp | 30 ++++++++++------------------ 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/clang/lib/AST/Interp/InterpFrame.cpp b/clang/lib/AST/Interp/InterpFrame.cpp index b33f74dfe99f1c..383380f485e03b 100644 --- a/clang/lib/AST/Interp/InterpFrame.cpp +++ b/clang/lib/AST/Interp/InterpFrame.cpp @@ -167,7 +167,10 @@ void InterpFrame::describe(llvm::raw_ostream &OS) const { print(OS, This, S.getCtx(), S.getCtx().getRecordType(M->getParent())); OS << "->"; } - OS << *F << "("; + + F->getNameForDiagnostic(OS, S.getCtx().getPrintingPolicy(), + /*Qualified=*/false); + OS << '('; unsigned Off = 0; Off += Func->hasRVO() ? primSize(PT_Ptr) : 0; diff --git a/clang/test/AST/Interp/literals.cpp b/clang/test/AST/Interp/literals.cpp index 93e7e8b52a4540..af5bcb6d48ae7e 100644 --- a/clang/test/AST/Interp/literals.cpp +++ b/clang/test/AST/Interp/literals.cpp @@ -568,37 +568,27 @@ namespace IncDec { return 1; } static_assert(uninit(), ""); // both-error {{not an integral constant expression}} \ - // ref-note {{in call to 'uninit()'}} \ - // expected-note {{in call to 'uninit()'}} + // both-note {{in call to 'uninit()'}} static_assert(uninit(), ""); // both-error {{not an integral constant expression}} \ - // ref-note {{in call to 'uninit()'}} \ - // expected-note {{in call to 'uninit()'}} + // both-note {{in call to 'uninit()'}} static_assert(uninit(), ""); // both-error {{not an integral constant expression}} \ - // ref-note {{in call to 'uninit()'}} \ - // expected-note {{in call to 'uninit()'}} + // both-note {{in call to 'uninit()'}} static_assert(uninit(), ""); // both-error {{not an integral constant expression}} \ - // ref-note {{in call to 'uninit()'}} \ - // expected-note {{in call to 'uninit()'}} + // both-note {{in call to 'uninit()'}} static_assert(uninit(), ""); // both-error {{not an integral constant expression}} \ - // ref-note {{in call to 'uninit()'}} \ - // expected-note {{in call to 'uninit()'}} + // both-note {{in call to 'uninit()'}} static_assert(uninit(), ""); // both-error {{not an integral constant expression}} \ - // ref-note {{in call to 'uninit()'}} \ - // expected-note {{in call to 'uninit()'}} + // both-note {{in call to 'uninit()'}} static_assert(uninit(), ""); // both-error {{not an integral constant expression}} \ - // ref-note {{in call to 'uninit()'}} \ - // expected-note {{in call to 'uninit()'}} + // both-note {{in call to 'uninit()'}} static_assert(uninit(), ""); // both-error {{not an integral constant expression}} \ - // ref-note {{in call to 'uninit()'}} \ - // expected-note {{in call to 'uninit()'}} + // both-note {{in call to 'uninit()'}} static_assert(uninit(), ""); // both-error {{not an integral constant expression}} \ - // ref-note {{in call to 'uninit()'}} \ - // expected-note {{in call to 'uninit()'}} + // both-note {{in call to 'uninit()'}} static_assert(uninit(), ""); // both-error {{not an integral constant expression}} \ - // ref-note {{in call to 'uninit()'}} \ - // expected-note {{in call to 'uninit()'}} + // both-note {{in call to 'uninit()'}} constexpr int OverFlow() { // both-error {{never produces a constant expression}} int a = INT_MAX; From 3ccda936710d55d819c56cf4f2cf307c2d632b63 Mon Sep 17 00:00:00 2001 From: Florian Hahn Date: Sun, 14 Jul 2024 17:23:31 +0100 Subject: [PATCH 4/5] [LAA] Update pointer-bounds cache to also consider access type. The same pointer may be accessed with different types and the bound includes the size of the accessed type to compute the end. Update the cache to correctly disambiguate between different accessed types. --- llvm/include/llvm/Analysis/LoopAccessAnalysis.h | 7 +++++-- llvm/lib/Analysis/LoopAccessAnalysis.cpp | 7 ++++--- .../different-access-types-rt-checks.ll | 8 ++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h index f6bb044392938e..afafb74bdcb0ac 100644 --- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h +++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h @@ -269,7 +269,8 @@ class MemoryDepChecker { const Loop *getInnermostLoop() const { return InnermostLoop; } - DenseMap> & + DenseMap, + std::pair> & getPointerBounds() { return PointerBounds; } @@ -334,7 +335,9 @@ class MemoryDepChecker { /// Mapping of SCEV expressions to their expanded pointer bounds (pair of /// start and end pointer expressions). - DenseMap> PointerBounds; + DenseMap, + std::pair> + PointerBounds; /// Check whether there is a plausible dependence between the two /// accesses. diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 018861a665c4cd..91994f33f30463 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -206,12 +206,13 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup( static std::pair getStartAndEndForAccess( const Loop *Lp, const SCEV *PtrExpr, Type *AccessTy, PredicatedScalarEvolution &PSE, - DenseMap> - &PointerBounds) { + DenseMap, + std::pair> &PointerBounds) { ScalarEvolution *SE = PSE.getSE(); auto [Iter, Ins] = PointerBounds.insert( - {PtrExpr, {SE->getCouldNotCompute(), SE->getCouldNotCompute()}}); + {{PtrExpr, AccessTy}, + {SE->getCouldNotCompute(), SE->getCouldNotCompute()}}); if (!Ins) return Iter->second; diff --git a/llvm/test/Analysis/LoopAccessAnalysis/different-access-types-rt-checks.ll b/llvm/test/Analysis/LoopAccessAnalysis/different-access-types-rt-checks.ll index 147119289b9ec4..58844c10cdcb95 100644 --- a/llvm/test/Analysis/LoopAccessAnalysis/different-access-types-rt-checks.ll +++ b/llvm/test/Analysis/LoopAccessAnalysis/different-access-types-rt-checks.ll @@ -3,8 +3,6 @@ target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" -; FIXME: The runtime checks for A are based on i8 accesses, but should be based -; on i32. define void @loads_of_same_pointer_with_different_sizes1(ptr %A, ptr %B, i64 %N) { ; CHECK-LABEL: 'loads_of_same_pointer_with_different_sizes1' ; CHECK-NEXT: loop: @@ -22,7 +20,7 @@ define void @loads_of_same_pointer_with_different_sizes1(ptr %A, ptr %B, i64 %N) ; CHECK-NEXT: (Low: %B High: ((4 * %N) + %B)) ; CHECK-NEXT: Member: {%B,+,4}<%loop> ; CHECK-NEXT: Group [[GRP2]]: -; CHECK-NEXT: (Low: %A High: (%N + %A)) +; CHECK-NEXT: (Low: %A High: (3 + %N + %A)) ; CHECK-NEXT: Member: {%A,+,1}<%loop> ; CHECK-NEXT: Member: {%A,+,1}<%loop> ; CHECK-EMPTY: @@ -101,8 +99,6 @@ exit: ret void } -; FIXME: The both runtime checks for A are based on i8 accesses, but one should -; be based on i32. define void @loads_of_same_pointer_with_different_sizes_retry_with_runtime_checks(ptr %A, ptr %B, i64 %N, i64 %off) { ; CHECK-LABEL: 'loads_of_same_pointer_with_different_sizes_retry_with_runtime_checks' ; CHECK-NEXT: loop: @@ -145,7 +141,7 @@ define void @loads_of_same_pointer_with_different_sizes_retry_with_runtime_check ; CHECK-NEXT: (Low: %A High: (%N + %A)) ; CHECK-NEXT: Member: {%A,+,1}<%loop> ; CHECK-NEXT: Group [[GRP8]]: -; CHECK-NEXT: (Low: %A High: (%N + %A)) +; CHECK-NEXT: (Low: %A High: (3 + %N + %A)) ; CHECK-NEXT: Member: {%A,+,1}<%loop> ; CHECK-EMPTY: ; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop. From a72eed7a238b0087789229bf635d3c517f8e7ff1 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Sun, 14 Jul 2024 12:29:47 -0400 Subject: [PATCH 5/5] [mlir][spirv] Handle scalar shuffles in vector to spirv conversion (#98809) These may not get canonicalized before conversion to spirv and need to be handled during vector to spirv conversion. Because spirv does not support 1-element vectors, we can't emit `spirv.VectorShuffle` and need to lower this to `spirv.CompositeExtract`. --- .../VectorToSPIRV/VectorToSPIRV.cpp | 25 ++++++++++++------- .../VectorToSPIRV/vector-to-spirv.mlir | 24 ++++++++++++++++++ 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index c9363295ec32f5..a4390447532a50 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -521,7 +521,7 @@ struct VectorShuffleOpConvert final LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto oldResultType = shuffleOp.getResultVectorType(); + VectorType oldResultType = shuffleOp.getResultVectorType(); Type newResultType = getTypeConverter()->convertType(oldResultType); if (!newResultType) return rewriter.notifyMatchFailure(shuffleOp, @@ -532,20 +532,22 @@ struct VectorShuffleOpConvert final return cast(attr).getValue().getZExtValue(); }); - auto oldV1Type = shuffleOp.getV1VectorType(); - auto oldV2Type = shuffleOp.getV2VectorType(); + VectorType oldV1Type = shuffleOp.getV1VectorType(); + VectorType oldV2Type = shuffleOp.getV2VectorType(); - // When both operands are SPIR-V vectors, emit a SPIR-V shuffle. - if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) { + // When both operands and the result are SPIR-V vectors, emit a SPIR-V + // shuffle. + if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 && + oldResultType.getNumElements() > 1) { rewriter.replaceOpWithNewOp( shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(), rewriter.getI32ArrayAttr(mask)); return success(); } - // When at least one of the operands becomes a scalar after type conversion - // for SPIR-V, extract all the required elements and construct the result - // vector. + // When at least one of the operands or the result becomes a scalar after + // type conversion for SPIR-V, extract all the required elements and + // construct the result vector. auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()]( Value scalarOrVec, int32_t idx) -> Value { if (auto vecTy = dyn_cast(scalarOrVec.getType())) @@ -569,9 +571,14 @@ struct VectorShuffleOpConvert final newOperand = getElementAtIdx(vec, elementIdx); } + // Handle the scalar result corner case. + if (newOperands.size() == 1) { + rewriter.replaceOp(shuffleOp, newOperands.front()); + return success(); + } + rewriter.replaceOpWithNewOp( shuffleOp, newResultType, newOperands); - return success(); } }; diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 0d67851dfe41de..667aad7645c51c 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -483,6 +483,30 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> { // ----- +// CHECK-LABEL: func @shuffle +// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32> +// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<4xi32> +// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32> +// CHECK: return %[[RES]] : vector<1xi32> +func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> { + %shuffle = vector.shuffle %v0, %v1 [0] : vector<4xi32>, vector<4xi32> + return %shuffle : vector<1xi32> +} + +// ----- + +// CHECK-LABEL: func @shuffle +// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32> +// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<4xi32> +// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32> +// CHECK: return %[[RES]] : vector<1xi32> +func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> { + %shuffle = vector.shuffle %v0, %v1 [5] : vector<4xi32>, vector<4xi32> + return %shuffle : vector<1xi32> +} + +// ----- + // CHECK-LABEL: func @interleave // CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>) // CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>