
[mlir][Tensor] Add pattern to fold concats of empty. #98994

Merged
merged 1 commit into llvm:main on Jul 17, 2024

Conversation

MaheshRavishankar (Contributor)

A concatenation of empty tensors can be replaced by a single empty tensor of the concatenated shape. Add this pattern to populateFoldTensorEmptyPatterns.

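For illustration, a minimal sketch of the fold with made-up static shapes (not taken from the patch):

// Before: concatenating two empty tensors along dim 1.
%0 = tensor.empty() : tensor<2x3xf32>
%1 = tensor.empty() : tensor<2x5xf32>
%2 = tensor.concat dim(1) %0, %1 : (tensor<2x3xf32>, tensor<2x5xf32>) -> tensor<2x8xf32>

// After the fold: a single empty tensor of the concatenated shape.
%2 = tensor.empty() : tensor<2x8xf32>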
@llvmbot (Collaborator) commented Jul 16, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

A concatenation of empty tensors can be replaced by a single empty tensor of the concatenated shape. Add this pattern to populateFoldTensorEmptyPatterns.


Full diff: https://github.com/llvm/llvm-project/pull/98994.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp (+35-2)
  • (modified) mlir/test/Dialect/Tensor/fold-empty-op.mlir (+38)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
index 43ad0acaf7420..60b0c3e759b6c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -136,6 +136,38 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
   }
 };
 
+// Fold concat operation where all the operands are empty.
+struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto concatOperands = concatOp.getInputs();
+    if (concatOperands.empty()) {
+      return failure();
+    }
+    auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
+    if (!firstEmptyOp) {
+      return failure();
+    }
+    auto isDefinedByEmptyOp = [](Value v) -> bool {
+      return v.getDefiningOp<tensor::EmptyOp>();
+    };
+    if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "not all operands are defined by an empty op");
+    }
+    SmallVector<SmallVector<OpFoldResult>> resultShape;
+    if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
+      return rewriter.notifyMatchFailure(concatOp,
+                                         "failed to get result shape");
+    }
+    rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
+        concatOp, resultShape[0], concatOp.getResultType().getElementType());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
@@ -144,6 +176,7 @@ void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
                FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
                FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
       patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
-  patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
-      patterns.getContext(), /*benefit=*/1);
+  patterns.add<FoldConcatsOfEmpty, FoldEmptyTensorWithPackOp,
+               FoldEmptyTensorWithUnPackOp>(patterns.getContext(),
+                                            /*benefit=*/1);
 }
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index e94f6ec7ec56e..5beb8c250aa10 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -164,3 +164,41 @@ func.func @double_use_of_tensor_empty(%arg0: index, %arg1: index)
 //       CHECK:   tensor.empty{{.*}} : tensor<?x10x40xf32>
 //       CHECK:   tensor.extract_slice
 //       CHECK:   tensor.extract_slice
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.fold_tensor_empty
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
+
+func.func @concats_of_empty(
+    %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index)
+    -> tensor<5x?x?xf32>
+{
+  %0 = tensor.empty(%arg0, %arg1) : tensor<5x?x?xf32>
+  %1 = tensor.empty(%arg2, %arg3) : tensor<5x?x?xf32>
+  %2 = tensor.concat dim(1) %0, %1 : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+  return %2 : tensor<5x?x?xf32>
+}
+//       CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @concats_of_empty(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index)
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[EMPTY0:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+//   CHECK-DAG:   %[[EMPTY1:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]])
+//       CHECK:   %[[D2:.+]] = tensor.dim %[[EMPTY0]], %[[C2]]
+//   CHECK-DAG:   %[[D0_1:.+]] = tensor.dim %[[EMPTY0]], %[[C1]]
+//   CHECK-DAG:   %[[D1_1:.+]] = tensor.dim %[[EMPTY1]], %[[C1]]
+//   CHECK-DAG:   %[[SUM:.+]] = affine.apply #[[MAP]]()[%[[D0_1]], %[[D1_1]]]
+//       CHECK:   %[[NEW_EMPTY:.+]] = tensor.empty(%[[SUM]], %[[D2]])
+//       CHECK:   return %[[NEW_EMPTY]]
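
As context, a minimal sketch (not part of this PR) of how these folding patterns are typically driven from C++; the helper name and exact includes below are assumptions:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: collect the tensor.empty folding patterns
// (now including the concat fold) and apply them greedily to a function.
static LogicalResult foldTensorEmpty(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  tensor::populateFoldTensorEmptyPatterns(patterns,
                                          /*foldSingleUseOnly=*/false);
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}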

@hanhanW (Contributor) left a comment

LGTM with one question

Comment on lines +149 to +159
    auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
    if (!firstEmptyOp) {
      return failure();
    }
    auto isDefinedByEmptyOp = [](Value v) -> bool {
      return v.getDefiningOp<tensor::EmptyOp>();
    };
    if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by an empty op");
    }
hanhanW (Contributor):

Why is the first operand special? Can't we just write something like below?

if (!llvm::all_of(concatOperands, isDefinedByEmptyOp)) {
  ....
}

MaheshRavishankar (Contributor, Author):

Just adding an early exit if the first operand itself is not an empty op.

Contributor:

I thought llvm::all_of should already cover that case? I imagine it also exits early if the first operand is not an empty op. It could be different with multi-threading though, I don't know.

Collaborator:

+1, I don't quite understand the code here: llvm::all_of should already have the early exit.

MaheshRavishankar (Contributor, Author):

I was just trying to avoid doing the all_of at all if even the first operand is not empty. But OK, I'll change it.
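
For reference, a minimal sketch of the simplified check discussed in this thread, assuming the same surrounding matchAndRewrite; this is my rendering of the suggestion, not necessarily the code that was merged:

    // Let llvm::all_of short-circuit on the first operand that is not
    // defined by a tensor.empty, instead of special-casing the front.
    auto isDefinedByEmptyOp = [](Value v) -> bool {
      return v.getDefiningOp<tensor::EmptyOp>();
    };
    if (concatOperands.empty() ||
        !llvm::all_of(concatOperands, isDefinedByEmptyOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by an empty op");
    }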

MaheshRavishankar merged commit c077a4f into llvm:main on Jul 17, 2024
10 checks passed
sgundapa pushed a commit to sgundapa/upstream_effort that referenced this pull request Jul 23, 2024
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Differential Revision: https://phabricator.intern.facebook.com/D60250835