Skip to content

Commit

Permalink
[mlir][spirv] Implement SPIR-V lowering for vector.deinterleave (ll…
Browse files Browse the repository at this point in the history
…vm#95313)

1. Added a conversion for `vector.deinterleave` to the `VectorToSPIRV`
pass.
2. Added LIT tests for the new conversion.

---------

Co-authored-by: Jakub Kuderski <[email protected]>
  • Loading branch information
angelz913 and kuhar authored Jun 13, 2024
1 parent b6688a0 commit 597cde1
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 3 deletions.
66 changes: 63 additions & 3 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,66 @@ struct VectorInterleaveOpConvert final
}
};

struct VectorDeinterleaveOpConvert final
: public OpConversionPattern<vector::DeinterleaveOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// Check the result vector type.
VectorType oldResultType = deinterleaveOp.getResultVectorType();
Type newResultType = getTypeConverter()->convertType(oldResultType);
if (!newResultType)
return rewriter.notifyMatchFailure(deinterleaveOp,
"unsupported result vector type");

Location loc = deinterleaveOp->getLoc();

// Deinterleave the indices.
Value sourceVector = adaptor.getSource();
VectorType sourceType = deinterleaveOp.getSourceVectorType();
int n = sourceType.getNumElements();

// Output vectors of size 1 are converted to scalars by the type converter.
// We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
// use `spirv::CompositeExtractOp`.
if (n == 2) {
auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));

auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));

rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
return success();
}

// Indices for `shuffleEven` (result 0).
auto seqEven = llvm::seq<int64_t>(n / 2);
auto indicesEven =
llvm::map_to_vector(seqEven, [](int i) { return i * 2; });

// Indices for `shuffleOdd` (result 1).
auto seqOdd = llvm::seq<int64_t>(n / 2);
auto indicesOdd =
llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });

// Create two SPIR-V shuffles.
auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
loc, newResultType, sourceVector, sourceVector,
rewriter.getI32ArrayAttr(indicesEven));

auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
loc, newResultType, sourceVector, sourceVector,
rewriter.getI32ArrayAttr(indicesOdd));

rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
return success();
}
};

struct VectorLoadOpConverter final
: public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -862,9 +922,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
VectorStoreOpConverter>(typeConverter, patterns.getContext(),
PatternBenefit(1));
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
typeConverter, patterns.getContext(), PatternBenefit(1));

// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,32 @@ func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf3

// -----

// CHECK-LABEL: func @deinterleave
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
// CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
// CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
// CHECK: return %[[SHUFFLE0]], %[[SHUFFLE1]]
func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
%0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
return %0, %1 : vector<2xf32>, vector<2xf32>
}

// -----

// CHECK-LABEL: func @deinterleave_scalar
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
// CHECK: return %[[CAST0]], %[[CAST1]]
func.func @deinterleave_scalar(%a: vector<2xf32>) -> (vector<1xf32>, vector<1xf32>) {
%0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
return %0, %1 : vector<1xf32>, vector<1xf32>
}

// -----

// CHECK-LABEL: func @reduction_add
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>
Expand Down

0 comments on commit 597cde1

Please sign in to comment.