Skip to content

Commit

Permalink
[mlir] Reuse pack dest in tensor.pack decomposition (llvm#108025)
Browse files Browse the repository at this point in the history
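The `lowerPack` transform has a special case in which a `tensor.pack` that
is just a plain pad is lowered to a simple `tensor.pad` +
`tensor.insert_slice`. In that case, however, the destination of the
`tensor.insert_slice` was a newly created `tensor.empty` rather than the
destination operand of the `tensor.pack`. This PR fixes the transform to
reuse the original destination of the `tensor.pack`.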
  • Loading branch information
Max191 authored and VitaNuo committed Sep 12, 2024
1 parent fe69476 commit 5b04947
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
7 changes: 2 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
if (rankReduces == SliceVerificationResult::Success) {
// This pack is just a plain pad.
// Just insert the pad in the higher ranked tensor.
auto emptyOp =
rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
// Offsets.
SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
rewriter.getIndexAttr(0));
Expand All @@ -317,9 +315,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
tensor::getMixedSizes(rewriter, loc, packOp.getDest());

auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
loc, /*source=*/padOp, /*dest=*/emptyOp,
/*offsets=*/zeros, sizes,
/*strides=*/ones);
loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
/*offsets=*/zeros, sizes, /*strides=*/ones);

LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););

Expand Down
14 changes: 8 additions & 6 deletions mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,15 @@ module attributes {transform.with_named_sequence} {
// -----

// CHECK-LABEL: func.func @pack_as_pad(
// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32

// tensor.pack is lowered to tensor.pad + tensor.insert_slice
// CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
// CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
// offsets.
// CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
// sizes.
Expand Down Expand Up @@ -387,14 +388,15 @@ module attributes {transform.with_named_sequence} {
// -----

// CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
// CHECK: %[[SRC:.+]]: tensor<129x47x16x16xf32>,
// CHECK: %[[OUT:.+]]: tensor<1x1x1x1x136x64x16x16xf32>)
func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
%cst_0 = arith.constant 0.0 : f32

// tensor.pack is lowered to tensor.pad + tensor.insert_slice
// CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0, 0, 0] high[7, 17, 0, 0]
// CHECK: : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[OUT]]
// offsets.
// CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
// sizes.
Expand Down

0 comments on commit 5b04947

Please sign in to comment.