Skip to content

Commit

Permalink
[mlir][spirv] Handle scalar shuffles in vector to spirv conversion (#…
Browse files Browse the repository at this point in the history
…98809)

These may not get canonicalized before conversion to spirv and need to
be handled during vector to spirv conversion. Because spirv does not
support 1-element vectors, we can't emit `spirv.VectorShuffle` and need
to lower this to `spirv.CompositeExtract`.
  • Loading branch information
kuhar authored Jul 14, 2024
1 parent 3ccda93 commit a72eed7
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 9 deletions.
25 changes: 16 additions & 9 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ struct VectorShuffleOpConvert final
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto oldResultType = shuffleOp.getResultVectorType();
VectorType oldResultType = shuffleOp.getResultVectorType();
Type newResultType = getTypeConverter()->convertType(oldResultType);
if (!newResultType)
return rewriter.notifyMatchFailure(shuffleOp,
Expand All @@ -532,20 +532,22 @@ struct VectorShuffleOpConvert final
return cast<IntegerAttr>(attr).getValue().getZExtValue();
});

auto oldV1Type = shuffleOp.getV1VectorType();
auto oldV2Type = shuffleOp.getV2VectorType();
VectorType oldV1Type = shuffleOp.getV1VectorType();
VectorType oldV2Type = shuffleOp.getV2VectorType();

// When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
// When both operands and the result are SPIR-V vectors, emit a SPIR-V
// shuffle.
if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
oldResultType.getNumElements() > 1) {
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
rewriter.getI32ArrayAttr(mask));
return success();
}

// When at least one of the operands becomes a scalar after type conversion
// for SPIR-V, extract all the required elements and construct the result
// vector.
// When at least one of the operands or the result becomes a scalar after
// type conversion for SPIR-V, extract all the required elements and
// construct the result vector.
auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
Value scalarOrVec, int32_t idx) -> Value {
if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
Expand All @@ -569,9 +571,14 @@ struct VectorShuffleOpConvert final
newOperand = getElementAtIdx(vec, elementIdx);
}

// Handle the scalar result corner case.
if (newOperands.size() == 1) {
rewriter.replaceOp(shuffleOp, newOperands.front());
return success();
}

rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
shuffleOp, newResultType, newOperands);

return success();
}
};
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,30 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {

// -----

// CHECK-LABEL: func @shuffle
// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<4xi32>
// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
// CHECK: return %[[RES]] : vector<1xi32>
func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
%shuffle = vector.shuffle %v0, %v1 [0] : vector<4xi32>, vector<4xi32>
return %shuffle : vector<1xi32>
}

// -----

// CHECK-LABEL: func @shuffle
// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<4xi32>
// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
// CHECK: return %[[RES]] : vector<1xi32>
func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
%shuffle = vector.shuffle %v0, %v1 [5] : vector<4xi32>, vector<4xi32>
return %shuffle : vector<1xi32>
}

// -----

// CHECK-LABEL: func @interleave
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
Expand Down

0 comments on commit a72eed7

Please sign in to comment.