From a72eed7a238b0087789229bf635d3c517f8e7ff1 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Sun, 14 Jul 2024 12:29:47 -0400 Subject: [PATCH] [mlir][spirv] Handle scalar shuffles in vector to spirv conversion (#98809) These may not get canonicalized before conversion to spirv and need to be handled during vector to spirv conversion. Because spirv does not support 1-element vectors, we can't emit `spirv.VectorShuffle` and need to lower this to `spirv.CompositeExtract`. --- .../VectorToSPIRV/VectorToSPIRV.cpp | 25 ++++++++++++------- .../VectorToSPIRV/vector-to-spirv.mlir | 24 ++++++++++++++++++ 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index c9363295ec32f5..a4390447532a50 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -521,7 +521,7 @@ struct VectorShuffleOpConvert final LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto oldResultType = shuffleOp.getResultVectorType(); + VectorType oldResultType = shuffleOp.getResultVectorType(); Type newResultType = getTypeConverter()->convertType(oldResultType); if (!newResultType) return rewriter.notifyMatchFailure(shuffleOp, @@ -532,20 +532,22 @@ struct VectorShuffleOpConvert final return cast(attr).getValue().getZExtValue(); }); - auto oldV1Type = shuffleOp.getV1VectorType(); - auto oldV2Type = shuffleOp.getV2VectorType(); + VectorType oldV1Type = shuffleOp.getV1VectorType(); + VectorType oldV2Type = shuffleOp.getV2VectorType(); - // When both operands are SPIR-V vectors, emit a SPIR-V shuffle. - if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) { + // When both operands and the result are SPIR-V vectors, emit a SPIR-V + // shuffle. + if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 && + oldResultType.getNumElements() > 1) { rewriter.replaceOpWithNewOp( shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(), rewriter.getI32ArrayAttr(mask)); return success(); } - // When at least one of the operands becomes a scalar after type conversion - // for SPIR-V, extract all the required elements and construct the result - // vector. + // When at least one of the operands or the result becomes a scalar after + // type conversion for SPIR-V, extract all the required elements and + // construct the result vector. auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()]( Value scalarOrVec, int32_t idx) -> Value { if (auto vecTy = dyn_cast(scalarOrVec.getType())) @@ -569,9 +571,14 @@ struct VectorShuffleOpConvert final newOperand = getElementAtIdx(vec, elementIdx); } + // Handle the scalar result corner case. + if (newOperands.size() == 1) { + rewriter.replaceOp(shuffleOp, newOperands.front()); + return success(); + } + rewriter.replaceOpWithNewOp( shuffleOp, newResultType, newOperands); - return success(); } }; diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 0d67851dfe41de..667aad7645c51c 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -483,6 +483,30 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> { // ----- +// CHECK-LABEL: func @shuffle +// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32> +// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<4xi32> +// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32> +// CHECK: return %[[RES]] : vector<1xi32> +func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> { + %shuffle = vector.shuffle %v0, %v1 [0] : vector<4xi32>, vector<4xi32> + return %shuffle : vector<1xi32> +} + +// ----- + +// CHECK-LABEL: func @shuffle +// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32> +// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<4xi32> +// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32> +// CHECK: return %[[RES]] : vector<1xi32> +func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> { + %shuffle = vector.shuffle %v0, %v1 [5] : vector<4xi32>, vector<4xi32> + return %shuffle : vector<1xi32> +} + +// ----- + // CHECK-LABEL: func @interleave // CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>) // CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>