Skip to content

Commit

Permalink
[mlir][VectorOps] Add deinterleave operation to vector dialect
Browse files Browse the repository at this point in the history
The deinterleave operation constructs two vectors from a single input
vector. Each new vector is the collection of even and odd elements
from the input, respectively. This is essentially the inverse of an
interleave operation.

Each output's size is half of the input vector's trailing dimension
for the n-D case and only dimension for 1-D cases. It is not possible
to conduct the operation on 0-D inputs or vectors where the size of
the (trailing) dimension is 1.

The operation supports scalable vectors.

Example:
```mlir
%0 = vector.deinterleave %a
           : vector<[4]xi32> -> vector<[2]xi32>
%1 = vector.deinterleave %b
           : vector<8xi8> -> vector<4xi8>
%2 = vector.deinterleave %c
           : vector<2x8xf32> -> vector<2x4xf32>
%3 = vector.deinterleave %d
           : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>
```
  • Loading branch information
mub-at-arm committed May 21, 2024
1 parent c396b74 commit ef12d14
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 52 deletions.
62 changes: 33 additions & 29 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,8 @@ def Vector_InterleaveOp :
}

class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
"type of 'input' is double the width of results",
"input", result,
"the trailing dimension of the results is half the width of source trailing dimension",
"source", result,
[{
[&]() -> ::mlir::VectorType {
auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
Expand All @@ -559,63 +559,67 @@ class ResultIsHalfSourceVectorType<string result> : TypesMatchWith<
}]
>;

def Vector_DeinterleaveOp :
Vector_Op<"deinterleave", [Pure,
PredOpTrait<"trailing dimension of input vector must be an even number",
def SourceVectorEvenElementCount : PredOpTrait<
"the trailing dimension of the source vector has an even number of elements",
CPred<[{
[&](){
auto srcVec = getSourceVectorType();
return srcVec.getDimSize(srcVec.getRank() - 1) % 2 == 0;
}()
}]>>,
}]>
>;

def Vector_DeinterleaveOp :
Vector_Op<"deinterleave", [Pure,
SourceVectorEvenElementCount,
ResultIsHalfSourceVectorType<"res1">,
ResultIsHalfSourceVectorType<"res2">,
AllTypesMatch<["res1", "res2"]>
]> {
let summary = "constructs two vectors by deinterleaving an input vector";
let description = [{
The deinterleave operation constructs two vectors from a single input
vector. Each new vector is the collection of even and odd elements
from the input, respectively. This is essentially the inverse of an
interleave operation.
vector. The first result vector contains the elements from even indexes
of the input, and the second contains elements from odd indexes. This is
the inverse of a `vector.interleave` operation.

Each output's size is half of the input vector's trailing dimension
for the n-D case and only dimension for 1-D cases. It is not possible
to conduct the operation on 0-D inputs or vectors where the size of
the (trailing) dimension is 1.
Each output's trailing dimension is half of the size of the input
vector's trailing dimension. This operation requires the input vector
to have a rank > 0 and an even number of elements in its trailing
dimension.

The operation supports scalable vectors.

Example:
```mlir
%0 = vector.deinterleave %a
: vector<[4]xi32> ; yields vector<[2]xi32>, vector<[2]xi32>
%1 = vector.deinterleave %b
: vector<8xi8> ; yields vector<4xi8>, vector<4xi8>
%2 = vector.deinterleave %c
: vector<2x8xf32> ; yields vector<2x4xf32>, vector<2x4xf32>
%3 = vector.deinterleave %d
: vector<2x4x[6]xf64> ; yields vector<2x4x[3]xf64>, vector<2x4x[3]xf64>
%0, %1 = vector.deinterleave %a
:vector<8xi8> -> vector<4xi8>
%2, %3 = vector.deinterleave %b
: vector<2x8xi8> -> vector<2x4xi8>
%4, %5 = vector.deinterleave %b
: vector<2x8x4xi8> -> vector<2x8x2xi8>
%6, %7 = vector.deinterleave %c
: vector<[8]xf32> -> vector<[4]xf32>
%8, %9 = vector.deinterleave %d
: vector<2x[6]xf64> -> vector<2x[3]xf64>
%10, %11 = vector.deinterleave %d
: vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>
```
}];

let arguments = (ins AnyVector:$input);
let arguments = (ins AnyVector:$source);
let results = (outs AnyVector:$res1, AnyVector:$res2);

let assemblyFormat = [{
$input attr-dict `:` type($input)
$source attr-dict `:` type($source) `->` type($res1)
}];

let extraClassDeclaration = [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getInput().getType());
return ::llvm::cast<VectorType>(getSource().getType());
}
VectorType getResultOneVectorType() {
VectorType getResultVectorType() {
return ::llvm::cast<VectorType>(getRes1().getType());
}
VectorType getResultTwoVectorType() {
return ::llvm::cast<VectorType>(getRes2().getType());
}
}];
}

Expand Down
20 changes: 10 additions & 10 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1802,55 +1802,55 @@ func.func @invalid_outerproduct1(%src : memref<?xf32>) {
// -----

func.func @deinterleave_zero_dim_fail(%vec : vector<f32>) {
// expected-error @+1 {{'vector.deinterleave' 'input' must be vector of any type values, but got 'vector<f32>'}}
%0, %1 = vector.deinterleave %vec : vector<f32>
// expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any type values, but got 'vector<f32>}}
%0, %1 = vector.deinterleave %vec : vector<f32> -> vector<f32>
return
}

// -----

func.func @deinterleave_one_dim_fail(%vec : vector<1xf32>) {
// expected-error @+1 {{'vector.deinterleave' op failed to verify that trailing dimension of input vector must be an even number}}
%0, %1 = vector.deinterleave %vec : vector<1xf32>
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the source vector has an even number of elements}}
%0, %1 = vector.deinterleave %vec : vector<1xf32> -> vector<1xf32>
return
}

// -----

func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) {
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
%0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<8xf32>, vector<8xf32>)
return
}

// -----

func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) {
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
%0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<4xf32>, vector<2xf32>)
return
}

// -----

func.func @deinterleave_n_dim_rank_fail(%vec : vector<2x3x4xf32>) {
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
// expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}}
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x3x4xf32>) -> (vector<2x3x4xf32>, vector<2x3x2xf32>)
return
}

// -----

func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) {
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
// expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}}
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<2x[1]xf32>)
return
}

// -----

func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {
// expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}}
// expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}}
%0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>)
return
}
}
26 changes: 13 additions & 13 deletions mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1119,42 +1119,42 @@ func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>)

// CHECK-LABEL: @deinterleave_1d
func.func @deinterleave_1d(%arg: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
// CHECK: vector.deinterleave %{{.*}} : vector<4xf32>
%0, %1 = vector.deinterleave %arg : vector<4xf32>
// CHECK: vector.deinterleave %{{.*}} : vector<4xf32> -> vector<2xf32>
%0, %1 = vector.deinterleave %arg : vector<4xf32> -> vector<2xf32>
return %0, %1 : vector<2xf32>, vector<2xf32>
}

// CHECK-LABEL: @deinterleave_1d_scalable
func.func @deinterleave_1d_scalable(%arg: vector<[4]xf32>) -> (vector<[2]xf32>, vector<[2]xf32>) {
// CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32>
%0, %1 = vector.deinterleave %arg : vector<[4]xf32>
// CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32> -> vector<[2]xf32>
%0, %1 = vector.deinterleave %arg : vector<[4]xf32> -> vector<[2]xf32>
return %0, %1 : vector<[2]xf32>, vector<[2]xf32>
}

// CHECK-LABEL: @deinterleave_2d
func.func @deinterleave_2d(%arg: vector<3x4xf32>) -> (vector<3x2xf32>, vector<3x2xf32>) {
// CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32>
%0, %1 = vector.deinterleave %arg : vector<3x4xf32>
// CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32> -> vector<3x2xf32>
%0, %1 = vector.deinterleave %arg : vector<3x4xf32> -> vector<3x2xf32>
return %0, %1 : vector<3x2xf32>, vector<3x2xf32>
}

// CHECK-LABEL: @deinterleave_2d_scalable
func.func @deinterleave_2d_scalable(%arg: vector<3x[4]xf32>) -> (vector<3x[2]xf32>, vector<3x[2]xf32>) {
// CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32>
%0, %1 = vector.deinterleave %arg : vector<3x[4]xf32>
// CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32> -> vector<3x[2]xf32>
%0, %1 = vector.deinterleave %arg : vector<3x[4]xf32> -> vector<3x[2]xf32>
return %0, %1 : vector<3x[2]xf32>, vector<3x[2]xf32>
}

// CHECK-LABEL: @deinterleave_nd
func.func @deinterleave_nd(%arg: vector<2x3x4x6xf32>) -> (vector<2x3x4x3xf32>, vector<2x3x4x3xf32>) {
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32>
%0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32>
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>
%0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>
return %0, %1 : vector<2x3x4x3xf32>, vector<2x3x4x3xf32>
}

// CHECK-LABEL: @deinterleave_nd_scalable
func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>) {
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32>
%0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32>
// CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
%0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
}
}

0 comments on commit ef12d14

Please sign in to comment.