From a6a972576590bddd43b89e031d6bb5d127d9cc3e Mon Sep 17 00:00:00 2001 From: "Mubashar.Ahmad@arm.com" Date: Thu, 16 May 2024 12:28:34 +0000 Subject: [PATCH] [mlir][VectorOps] Add deinterleave operation to vector dialect The deinterleave operation constructs two vectors from a single input vector. Each new vector is the collection of even and odd elements from the input, respectively. This is essentially the inverse of an interleave operation. Each output's size is half of the input vector's trailing dimension for the n-D case and only dimension for 1-D cases. It is not possible to conduct the operation on 0-D inputs or vectors where the size of the (trailing) dimension is 1. The operation supports scalable vectors. Example: ```mlir %0 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>x2 %1 = vector.deinterleave %b : vector<8xi8> -> vector<4xi8>x2 %2 = vector.deinterleave %c : vector<2x8xf32> -> vector<2x4xf32>x2 %3 = vector.deinterleave %d : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>x2 ``` --- .../mlir/Dialect/Vector/IR/VectorOps.td | 56 +++++++++---------- mlir/test/Dialect/Vector/invalid.mlir | 20 +++---- mlir/test/Dialect/Vector/ops.mlir | 26 ++++----- 3 files changed, 51 insertions(+), 51 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 1e7e0a1715178d..0d6e99b4bb34d4 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -544,8 +544,8 @@ def Vector_InterleaveOp : } class ResultIsHalfSourceVectorType : TypesMatchWith< - "type of 'input' is double the width of results", - "input", result, + "the trailing dimension of the results is half the width of source trailing dimension", + "source", result, [{ [&]() -> ::mlir::VectorType { auto vectorType = ::llvm::cast($_self); @@ -559,15 +559,18 @@ class ResultIsHalfSourceVectorType : TypesMatchWith< }] >; -def Vector_DeinterleaveOp : - Vector_Op<"deinterleave", [Pure, - PredOpTrait<"trailing dimension of input vector must be an even number", +class SourceVectorEvenElementCount : PredOpTrait<"the trailing dimension of the source vector has an even number of elements", CPred<[{ [&](){ auto srcVec = getSourceVectorType(); return srcVec.getDimSize(srcVec.getRank() - 1) % 2 == 0; }() - }]>>, + }]> +>; + +def Vector_DeinterleaveOp : + Vector_Op<"deinterleave", [Pure, + SourceVectorEvenElementCount<>, ResultIsHalfSourceVectorType<"res1">, ResultIsHalfSourceVectorType<"res2">, AllTypesMatch<["res1", "res2"]> @@ -575,47 +578,44 @@ def Vector_DeinterleaveOp : let summary = "constructs two vectors by deinterleaving an input vector"; let description = [{ The deinterleave operation constructs two vectors from a single input - vector. Each new vector is the collection of even and odd elements - from the input, respectively. This is essentially the inverse of an - interleave operation. + vector. The first result vector contains the elements from even indexes + of the input, and the second contains elements from odd indexes. This is + the inverse of a 'vector.interleave' operation. - Each output's size is half of the input vector's trailing dimension - for the n-D case and only dimension for 1-D cases. It is not possible - to conduct the operation on 0-D inputs or vectors where the size of - the (trailing) dimension is 1. + Each output's trailing dimension is half of the size of the input + vector's trailing dimension. This operation requires the input vector + to have a rank > 0 and an even number of elements in its trailing + dimension. The operation supports scalable vectors. Example: ```mlir - %0 = vector.deinterleave %a - : vector<[4]xi32> ; yields vector<[2]xi32>, vector<[2]xi32> - %1 = vector.deinterleave %b - : vector<8xi8> ; yields vector<4xi8>, vector<4xi8> - %2 = vector.deinterleave %c - : vector<2x8xf32> ; yields vector<2x4xf32>, vector<2x4xf32> - %3 = vector.deinterleave %d - : vector<2x4x[6]xf64> ; yields vector<2x4x[3]xf64>, vector<2x4x[3]xf64> + %0, %1 = vector.deinterleave %a + : vector<[4]xi32> -> vector<[2]xi32>x2 + %2, %3 = vector.deinterleave %b + : vector<8xi8> -> vector<4xi8>x2 + %4, %5 = vector.deinterleave %c + : vector<2x8xf32> -> vector<2x4xf32>x2 + %6, %7 = vector.deinterleave %d + : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64>x2 ``` }]; - let arguments = (ins AnyVector:$input); + let arguments = (ins AnyVector:$source); let results = (outs AnyVector:$res1, AnyVector:$res2); let assemblyFormat = [{ - $input attr-dict `:` type($input) + $source attr-dict `:` type($source) `->` type($res1) `` `x2` }]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { - return ::llvm::cast(getInput().getType()); + return ::llvm::cast(getSource().getType()); } - VectorType getResultOneVectorType() { + VectorType getResultVectorType() { return ::llvm::cast(getRes1().getType()); } - VectorType getResultTwoVectorType() { - return ::llvm::cast(getRes2().getType()); - } }]; } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 25cacc6fdf93d8..079ad909640d69 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1802,23 +1802,23 @@ func.func @invalid_outerproduct1(%src : memref) { // ----- func.func @deinterleave_zero_dim_fail(%vec : vector) { - // expected-error @+1 {{'vector.deinterleave' 'input' must be vector of any type values, but got 'vector'}} - %0, %1 = vector.deinterleave %vec : vector + // expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any type values, but got 'vector}} + %0, %1 = vector.deinterleave %vec : vector -> vectorx2 return } // ----- func.func @deinterleave_one_dim_fail(%vec : vector<1xf32>) { - // expected-error @+1 {{'vector.deinterleave' op failed to verify that trailing dimension of input vector must be an even number}} - %0, %1 = vector.deinterleave %vec : vector<1xf32> + // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the source vector has an even number of elements}} + %0, %1 = vector.deinterleave %vec : vector<1xf32> -> vector<1xf32>x2 return } // ----- func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) { - // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}} + // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}} %0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<8xf32>, vector<8xf32>) return } @@ -1826,7 +1826,7 @@ func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) { // ----- func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) { - // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}} + // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}} %0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<4xf32>, vector<2xf32>) return } @@ -1834,7 +1834,7 @@ func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) { // ----- func.func @deinterleave_n_dim_rank_fail(%vec : vector<2x3x4xf32>) { - // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}} + // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}} %0, %1 = "vector.deinterleave" (%vec) : (vector<2x3x4xf32>) -> (vector<2x3x4xf32>, vector<2x3x2xf32>) return } @@ -1842,7 +1842,7 @@ func.func @deinterleave_n_dim_rank_fail(%vec : vector<2x3x4xf32>) { // ----- func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) { - // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}} + // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}} %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<2x[1]xf32>) return } @@ -1850,7 +1850,7 @@ func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) { // ----- func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) { - // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}} + // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}} %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>) return -} \ No newline at end of file +} diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index a6a992f23a4ba4..a22385e535f793 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -1119,42 +1119,42 @@ func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) // CHECK-LABEL: @deinterleave_1d func.func @deinterleave_1d(%arg: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) { - // CHECK: vector.deinterleave %{{.*}} : vector<4xf32> - %0, %1 = vector.deinterleave %arg : vector<4xf32> + // CHECK: vector.deinterleave %{{.*}} : vector<4xf32> -> vector<2xf32>x2 + %0, %1 = vector.deinterleave %arg : vector<4xf32> -> vector<2xf32>x2 return %0, %1 : vector<2xf32>, vector<2xf32> } // CHECK-LABEL: @deinterleave_1d_scalable func.func @deinterleave_1d_scalable(%arg: vector<[4]xf32>) -> (vector<[2]xf32>, vector<[2]xf32>) { - // CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32> - %0, %1 = vector.deinterleave %arg : vector<[4]xf32> + // CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32> -> vector<[2]xf32>x2 + %0, %1 = vector.deinterleave %arg : vector<[4]xf32> -> vector<[2]xf32>x2 return %0, %1 : vector<[2]xf32>, vector<[2]xf32> } // CHECK-LABEL: @deinterleave_2d func.func @deinterleave_2d(%arg: vector<3x4xf32>) -> (vector<3x2xf32>, vector<3x2xf32>) { - // CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32> - %0, %1 = vector.deinterleave %arg : vector<3x4xf32> + // CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32> -> vector<3x2xf32>x2 + %0, %1 = vector.deinterleave %arg : vector<3x4xf32> -> vector<3x2xf32>x2 return %0, %1 : vector<3x2xf32>, vector<3x2xf32> } // CHECK-LABEL: @deinterleave_2d_scalable func.func @deinterleave_2d_scalable(%arg: vector<3x[4]xf32>) -> (vector<3x[2]xf32>, vector<3x[2]xf32>) { - // CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32> - %0, %1 = vector.deinterleave %arg : vector<3x[4]xf32> + // CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32> -> vector<3x[2]xf32>x2 + %0, %1 = vector.deinterleave %arg : vector<3x[4]xf32> -> vector<3x[2]xf32>x2 return %0, %1 : vector<3x[2]xf32>, vector<3x[2]xf32> } // CHECK-LABEL: @deinterleave_nd func.func @deinterleave_nd(%arg: vector<2x3x4x6xf32>) -> (vector<2x3x4x3xf32>, vector<2x3x4x3xf32>) { - // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32> - %0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32> + // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>x2 + %0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32>x2 return %0, %1 : vector<2x3x4x3xf32>, vector<2x3x4x3xf32> } // CHECK-LABEL: @deinterleave_nd_scalable func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>) { - // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32> - %0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> + // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>x2 + %0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32>x2 return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32> -} \ No newline at end of file +}