Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Add deinterleave operation to vector dialect #92409

Merged
merged 3 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,86 @@ def Vector_InterleaveOp :
}];
}

class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
"the trailing dimension of the results is half the width of source trailing dimension",
"source", result,
[{
[&]() -> ::mlir::VectorType {
auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
::mlir::VectorType::Builder builder(vectorType);
auto lastDim = vectorType.getRank() - 1;
auto newDimSize = vectorType.getDimSize(lastDim) / 2;;
if (newDimSize <= 0)
return vectorType; // (invalid input type)
return builder.setDim(lastDim, newDimSize);
}()
}]
>;

def SourceVectorEvenElementCount : PredOpTrait<
"the trailing dimension of the source vector has an even number of elements",
CPred<[{
[&](){
auto srcVec = getSourceVectorType();
return srcVec.getDimSize(srcVec.getRank() - 1) % 2 == 0;
}()
}]>
>;

def Vector_DeinterleaveOp :
Vector_Op<"deinterleave", [Pure,
SourceVectorEvenElementCount,
ResultIsHalfSourceVectorType<"res1">,
AllTypesMatch<["res1", "res2"]>
]> {
let summary = "constructs two vectors by deinterleaving an input vector";
let description = [{
The deinterleave operation constructs two vectors from a single input
vector. The first result vector contains the elements from even indexes
of the input, and the second contains elements from odd indexes. This is
the inverse of a `vector.interleave` operation.

Each output's trailing dimension is half of the size of the input
vector's trailing dimension. This operation requires the input vector
to have a rank > 0 and an even number of elements in its trailing
dimension.

The operation supports scalable vectors.

Example:
```mlir
%0, %1 = vector.deinterleave %a
: vector<8xi8> -> vector<4xi8>
%2, %3 = vector.deinterleave %b
: vector<2x8xi8> -> vector<2x4xi8>
%4, %5 = vector.deinterleave %c
: vector<2x8x4xi8> -> vector<2x8x2xi8>
%6, %7 = vector.deinterleave %d
: vector<[8]xf32> -> vector<[4]xf32>
%8, %9 = vector.deinterleave %e
: vector<2x[6]xf64> -> vector<2x[3]xf64>
%10, %11 = vector.deinterleave %f
: vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>
```
}];

let arguments = (ins AnyVector:$source);
let results = (outs AnyVector:$res1, AnyVector:$res2);

let assemblyFormat = [{
$source attr-dict `:` type($source) `->` type($res1)
}];

let extraClassDeclaration = [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getSource().getType());
}
VectorType getResultVectorType() {
return ::llvm::cast<VectorType>(getRes1().getType());
}
}];
}

def Vector_ExtractElementOp :
Vector_Op<"extractelement", [Pure,
TypesMatchWith<"result type matches element type of vector operand",
Expand Down
56 changes: 56 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1798,3 +1798,59 @@ func.func @invalid_outerproduct1(%src : memref<?xf32>) {
// expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
%op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
}

// -----

func.func @deinterleave_zero_dim_fail(%vec : vector<f32>) {
// expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any type values, but got 'vector<f32>}}
%0, %1 = vector.deinterleave %vec : vector<f32> -> vector<f32>
return
}

// -----

func.func @deinterleave_one_dim_fail(%vec : vector<1xf32>) {
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the source vector has an even number of elements}}
%0, %1 = vector.deinterleave %vec : vector<1xf32> -> vector<1xf32>
return
}

// -----

func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) {
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
%0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<8xf32>, vector<8xf32>)
return
}

// -----

func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) {
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
%0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<4xf32>, vector<2xf32>)
return
}

// -----

func.func @deinterleave_n_dim_rank_fail(%vec : vector<2x3x4xf32>) {
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x3x4xf32>) -> (vector<2x3x4xf32>, vector<2x3x2xf32>)
return
}

// -----

func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) {
// expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}}
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<2x[1]xf32>)
return
}

// -----

func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
// expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}}
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
return
}
42 changes: 42 additions & 0 deletions mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1116,3 +1116,45 @@ func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>)
%0 = vector.interleave %a, %b : vector<2x[2]xf64>
return %0 : vector<2x[4]xf64>
}

// CHECK-LABEL: @deinterleave_1d
func.func @deinterleave_1d(%arg: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
// CHECK: vector.deinterleave %{{.*}} : vector<4xf32> -> vector<2xf32>
%0, %1 = vector.deinterleave %arg : vector<4xf32> -> vector<2xf32>
return %0, %1 : vector<2xf32>, vector<2xf32>
}

// CHECK-LABEL: @deinterleave_1d_scalable
func.func @deinterleave_1d_scalable(%arg: vector<[4]xf32>) -> (vector<[2]xf32>, vector<[2]xf32>) {
// CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32> -> vector<[2]xf32>
%0, %1 = vector.deinterleave %arg : vector<[4]xf32> -> vector<[2]xf32>
return %0, %1 : vector<[2]xf32>, vector<[2]xf32>
}

// CHECK-LABEL: @deinterleave_2d
func.func @deinterleave_2d(%arg: vector<3x4xf32>) -> (vector<3x2xf32>, vector<3x2xf32>) {
// CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32> -> vector<3x2xf32>
%0, %1 = vector.deinterleave %arg : vector<3x4xf32> -> vector<3x2xf32>
return %0, %1 : vector<3x2xf32>, vector<3x2xf32>
}

// CHECK-LABEL: @deinterleave_2d_scalable
func.func @deinterleave_2d_scalable(%arg: vector<3x[4]xf32>) -> (vector<3x[2]xf32>, vector<3x[2]xf32>) {
// CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32> -> vector<3x[2]xf32>
%0, %1 = vector.deinterleave %arg : vector<3x[4]xf32> -> vector<3x[2]xf32>
return %0, %1 : vector<3x[2]xf32>, vector<3x[2]xf32>
}

// CHECK-LABEL: @deinterleave_nd
func.func @deinterleave_nd(%arg: vector<2x3x4x6xf32>) -> (vector<2x3x4x3xf32>, vector<2x3x4x3xf32>) {
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>
%0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>
return %0, %1 : vector<2x3x4x3xf32>, vector<2x3x4x3xf32>
}

// CHECK-LABEL: @deinterleave_nd_scalable
func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>) {
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
%0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
}
Loading