diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 1e7e0a1715178d..bfbb40405c3c11 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -544,8 +544,8 @@ def Vector_InterleaveOp : } class ResultIsHalfSourceVectorType : TypesMatchWith< - "type of 'input' is double the width of results", - "input", result, + "the trailing dimension of the results is half the width of source trailing dimension", + "source", result, [{ [&]() -> ::mlir::VectorType { auto vectorType = ::llvm::cast($_self); @@ -559,63 +559,67 @@ class ResultIsHalfSourceVectorType : TypesMatchWith< }] >; -def Vector_DeinterleaveOp : - Vector_Op<"deinterleave", [Pure, - PredOpTrait<"trailing dimension of input vector must be an even number", +def SourceVectorEvenElementCount : PredOpTrait< + "the trailing dimension of the source vector has an even number of elements", CPred<[{ [&](){ auto srcVec = getSourceVectorType(); return srcVec.getDimSize(srcVec.getRank() - 1) % 2 == 0; }() - }]>>, + }]> +>; + +def Vector_DeinterleaveOp : + Vector_Op<"deinterleave", [Pure, + SourceVectorEvenElementCount, ResultIsHalfSourceVectorType<"res1">, - ResultIsHalfSourceVectorType<"res2">, AllTypesMatch<["res1", "res2"]> ]> { let summary = "constructs two vectors by deinterleaving an input vector"; let description = [{ The deinterleave operation constructs two vectors from a single input - vector. Each new vector is the collection of even and odd elements - from the input, respectively. This is essentially the inverse of an - interleave operation. + vector. The first result vector contains the elements from even indexes + of the input, and the second contains elements from odd indexes. This is + the inverse of a `vector.interleave` operation. - Each output's size is half of the input vector's trailing dimension - for the n-D case and only dimension for 1-D cases. It is not possible - to conduct the operation on 0-D inputs or vectors where the size of - the (trailing) dimension is 1. + Each output's trailing dimension is half of the size of the input + vector's trailing dimension. This operation requires the input vector + to have a rank > 0 and an even number of elements in its trailing + dimension. The operation supports scalable vectors. Example: ```mlir - %0 = vector.deinterleave %a - : vector<[4]xi32> ; yields vector<[2]xi32>, vector<[2]xi32> - %1 = vector.deinterleave %b - : vector<8xi8> ; yields vector<4xi8>, vector<4xi8> - %2 = vector.deinterleave %c - : vector<2x8xf32> ; yields vector<2x4xf32>, vector<2x4xf32> - %3 = vector.deinterleave %d - : vector<2x4x[6]xf64> ; yields vector<2x4x[3]xf64>, vector<2x4x[3]xf64> + %0, %1 = vector.deinterleave %a + :vector<8xi8> -> vector<4xi8> + %2, %3 = vector.deinterleave %b + : vector<2x8xi8> -> vector<2x4xi8> + %4, %5 = vector.deinterleave %b + : vector<2x8x4xi8> -> vector<2x8x2xi8> + %6, %7 = vector.deinterleave %c + : vector<[8]xf32> -> vector<[4]xf32> + %8, %9 = vector.deinterleave %d + : vector<2x[6]xf64> -> vector<2x[3]xf64> + %10, %11 = vector.deinterleave %d + : vector<2x4x[6]xf64> -> vector<2x4x[3]xf64> ``` }]; - let arguments = (ins AnyVector:$input); + let arguments = (ins AnyVector:$source); let results = (outs AnyVector:$res1, AnyVector:$res2); let assemblyFormat = [{ - $input attr-dict `:` type($input) + $source attr-dict `:` type($source) `->` type($res1) }]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { - return ::llvm::cast(getInput().getType()); + return ::llvm::cast(getSource().getType()); } - VectorType getResultOneVectorType() { + VectorType getResultVectorType() { return ::llvm::cast(getRes1().getType()); } - VectorType getResultTwoVectorType() { - return ::llvm::cast(getRes2().getType()); - } }]; } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 25cacc6fdf93d8..1516f51fe1458a 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1802,23 +1802,23 @@ func.func @invalid_outerproduct1(%src : memref) { // ----- func.func @deinterleave_zero_dim_fail(%vec : vector) { - // expected-error @+1 {{'vector.deinterleave' 'input' must be vector of any type values, but got 'vector'}} - %0, %1 = vector.deinterleave %vec : vector + // expected-error @+1 {{'vector.deinterleave' op operand #0 must be vector of any type values, but got 'vector}} + %0, %1 = vector.deinterleave %vec : vector -> vector return } // ----- func.func @deinterleave_one_dim_fail(%vec : vector<1xf32>) { - // expected-error @+1 {{'vector.deinterleave' op failed to verify that trailing dimension of input vector must be an even number}} - %0, %1 = vector.deinterleave %vec : vector<1xf32> + // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the source vector has an even number of elements}} + %0, %1 = vector.deinterleave %vec : vector<1xf32> -> vector<1xf32> return } // ----- func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) { - // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}} + // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}} %0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<8xf32>, vector<8xf32>) return } @@ -1826,7 +1826,7 @@ func.func @deinterleave_oversized_output_fail(%vec : vector<4xf32>) { // ----- func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) { - // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}} + // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}} %0, %1 = "vector.deinterleave" (%vec) : (vector<4xf32>) -> (vector<4xf32>, vector<2xf32>) return } @@ -1834,7 +1834,7 @@ func.func @deinterleave_output_dim_size_mismatch(%vec : vector<4xf32>) { // ----- func.func @deinterleave_n_dim_rank_fail(%vec : vector<2x3x4xf32>) { - // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}} + // expected-error @+1 {{'vector.deinterleave' op failed to verify that the trailing dimension of the results is half the width of source trailing dimension}} %0, %1 = "vector.deinterleave" (%vec) : (vector<2x3x4xf32>) -> (vector<2x3x4xf32>, vector<2x3x2xf32>) return } @@ -1842,7 +1842,7 @@ func.func @deinterleave_n_dim_rank_fail(%vec : vector<2x3x4xf32>) { // ----- func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) { - // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}} + // expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}} %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<2x[1]xf32>) return } @@ -1850,7 +1850,7 @@ func.func @deinterleave_scalable_dim_size_fail(%vec : vector<2x[4]xf32>) { // ----- func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) { - // expected-error @+1 {{'vector.deinterleave' op failed to verify that type of 'input' is double the width of results}} + // expected-error @+1 {{'vector.deinterleave' op failed to verify that all of {res1, res2} have same type}} %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>) return -} \ No newline at end of file +} diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index a6a992f23a4ba4..9d8101d3eee978 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -1119,42 +1119,42 @@ func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) // CHECK-LABEL: @deinterleave_1d func.func @deinterleave_1d(%arg: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) { - // CHECK: vector.deinterleave %{{.*}} : vector<4xf32> - %0, %1 = vector.deinterleave %arg : vector<4xf32> + // CHECK: vector.deinterleave %{{.*}} : vector<4xf32> -> vector<2xf32> + %0, %1 = vector.deinterleave %arg : vector<4xf32> -> vector<2xf32> return %0, %1 : vector<2xf32>, vector<2xf32> } // CHECK-LABEL: @deinterleave_1d_scalable func.func @deinterleave_1d_scalable(%arg: vector<[4]xf32>) -> (vector<[2]xf32>, vector<[2]xf32>) { - // CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32> - %0, %1 = vector.deinterleave %arg : vector<[4]xf32> + // CHECK: vector.deinterleave %{{.*}} : vector<[4]xf32> -> vector<[2]xf32> + %0, %1 = vector.deinterleave %arg : vector<[4]xf32> -> vector<[2]xf32> return %0, %1 : vector<[2]xf32>, vector<[2]xf32> } // CHECK-LABEL: @deinterleave_2d func.func @deinterleave_2d(%arg: vector<3x4xf32>) -> (vector<3x2xf32>, vector<3x2xf32>) { - // CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32> - %0, %1 = vector.deinterleave %arg : vector<3x4xf32> + // CHECK: vector.deinterleave %{{.*}} : vector<3x4xf32> -> vector<3x2xf32> + %0, %1 = vector.deinterleave %arg : vector<3x4xf32> -> vector<3x2xf32> return %0, %1 : vector<3x2xf32>, vector<3x2xf32> } // CHECK-LABEL: @deinterleave_2d_scalable func.func @deinterleave_2d_scalable(%arg: vector<3x[4]xf32>) -> (vector<3x[2]xf32>, vector<3x[2]xf32>) { - // CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32> - %0, %1 = vector.deinterleave %arg : vector<3x[4]xf32> + // CHECK: vector.deinterleave %{{.*}} : vector<3x[4]xf32> -> vector<3x[2]xf32> + %0, %1 = vector.deinterleave %arg : vector<3x[4]xf32> -> vector<3x[2]xf32> return %0, %1 : vector<3x[2]xf32>, vector<3x[2]xf32> } // CHECK-LABEL: @deinterleave_nd func.func @deinterleave_nd(%arg: vector<2x3x4x6xf32>) -> (vector<2x3x4x3xf32>, vector<2x3x4x3xf32>) { - // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32> - %0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32> + // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32> + %0, %1 = vector.deinterleave %arg : vector<2x3x4x6xf32> -> vector<2x3x4x3xf32> return %0, %1 : vector<2x3x4x3xf32>, vector<2x3x4x3xf32> } // CHECK-LABEL: @deinterleave_nd_scalable func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>) { - // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32> - %0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> + // CHECK: vector.deinterleave %{{.*}} : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32> + %0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32> return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32> -} \ No newline at end of file +}