diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0e5e563ed5450af..77f0ea9d2236ea6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -305,8 +305,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   if (rankReduces == SliceVerificationResult::Success) {
     // This pack is just a plain pad.
     // Just insert the pad in the higher ranked tensor.
-    auto emptyOp =
-        rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
     // Offsets.
     SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                     rewriter.getIndexAttr(0));
@@ -317,9 +315,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
         tensor::getMixedSizes(rewriter, loc, packOp.getDest());
 
     auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, /*source=*/padOp, /*dest=*/emptyOp,
-        /*offsets=*/zeros, sizes,
-        /*strides=*/ones);
+        loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
+        /*offsets=*/zeros, sizes, /*strides=*/ones);
 
     LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index f34ef4f961483d7..48bf1c151de8f5f 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -62,14 +62,15 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func.func @pack_as_pad(
+// CHECK:         %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK:         %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
 func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.pack is lowered to tensor.pad + tensor.insert_slice
-  // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
+  // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
   // CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
-  // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
-  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
   // offsets.
   // CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
   // sizes.
@@ -387,14 +388,15 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
+// CHECK:         %[[SRC:.+]]: tensor<129x47x16x16xf32>,
+// CHECK:         %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
 func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32
 
   // tensor.pack is lowered to tensor.pad + tensor.insert_slice
-  // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
+  // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
   // CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
-  // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
-  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
   // offsets.
   // CHECK-SAME:   [0, 0, 0, 0, 0, 0, 0, 0]
   // sizes.
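
For context (not part of the patch): a minimal MLIR sketch of the IR this lowering produces for a pad-only pack, mirroring the @pack_as_pad test above. Value names %src, %dest, and %cst are illustrative; the point is that tensor.pad now feeds tensor.insert_slice into the pack's original destination, with no intermediate tensor.empty.

  // Input (as in the test): a tensor.pack whose outer dims are all 1, i.e. a plain pad.
  //   %packed = tensor.pack %src padding_value(%cst : f32)
  //       inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16]
  //       into %dest : tensor<129x47x16x16xf32> -> tensor<1x1x1x1x136x64x16x16xf32>
  //
  // Lowered form after this change:
  %pad = tensor.pad %src low[0, 0, 0, 0] high[7, 17, 0, 0] {
  ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
    tensor.yield %cst : f32
  } : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
  %res = tensor.insert_slice %pad into %dest[0, 0, 0, 0, 0, 0, 0, 0]
           [1, 1, 1, 1, 136, 64, 16, 16] [1, 1, 1, 1, 1, 1, 1, 1]
         : tensor<136x64x16x16xf32> into tensor<1x1x1x1x136x64x16x16xf32>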