[mlir][Tensor] Add pattern to fold concats of empty. #98994
Conversation
A concatenation of empty tensors can be replaced by a single empty tensor of the concatenated shape. Add this pattern to `populateFoldTensorEmptyPatterns`.
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes: A concatenation of empty tensors can be replaced by a single empty tensor of the concatenated shape. Add this pattern to `populateFoldTensorEmptyPatterns`.

Full diff: https://github.com/llvm/llvm-project/pull/98994.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
index 43ad0acaf7420..60b0c3e759b6c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -136,6 +136,38 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
}
};
+// Fold concat operation where all the operands are empty.
+struct FoldConcatsOfEmpty : public OpRewritePattern<ConcatOp> {
+ using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
+ PatternRewriter &rewriter) const override {
+ auto concatOperands = concatOp.getInputs();
+ if (concatOperands.empty()) {
+ return failure();
+ }
+ auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
+ if (!firstEmptyOp) {
+ return failure();
+ }
+ auto isDefinedByEmptyOp = [](Value v) -> bool {
+ return v.getDefiningOp<tensor::EmptyOp>();
+ };
+ if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
+ return rewriter.notifyMatchFailure(
+ concatOp, "not all operands are defined by an empty op");
+ }
+ SmallVector<SmallVector<OpFoldResult>> resultShape;
+ if (failed(concatOp.reifyResultShapes(rewriter, resultShape))) {
+ return rewriter.notifyMatchFailure(concatOp,
+ "failed to get result shape");
+ }
+ rewriter.replaceOpWithNewOp<tensor::EmptyOp>(
+ concatOp, resultShape[0], concatOp.getResultType().getElementType());
+ return success();
+ }
+};
+
} // namespace
void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
@@ -144,6 +176,7 @@ void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
patterns.getContext(), /*benefit=*/1, foldSingleUseOnly);
- patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
- patterns.getContext(), /*benefit=*/1);
+ patterns.add<FoldConcatsOfEmpty, FoldEmptyTensorWithPackOp,
+ FoldEmptyTensorWithUnPackOp>(patterns.getContext(),
+ /*benefit=*/1);
}
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index e94f6ec7ec56e..5beb8c250aa10 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -164,3 +164,41 @@ func.func @double_use_of_tensor_empty(%arg0: index, %arg1: index)
// CHECK: tensor.empty{{.*}} : tensor<?x10x40xf32>
// CHECK: tensor.extract_slice
// CHECK: tensor.extract_slice
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.tensor.fold_tensor_empty
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
+
+func.func @concats_of_empty(
+ %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index)
+ -> tensor<5x?x?xf32>
+{
+ %0 = tensor.empty(%arg0, %arg1) : tensor<5x?x?xf32>
+ %1 = tensor.empty(%arg2, %arg3) : tensor<5x?x?xf32>
+ %2 = tensor.concat dim(1) %0, %1 : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
+ return %2 : tensor<5x?x?xf32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @concats_of_empty(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index)
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[EMPTY0:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
+// CHECK-DAG: %[[EMPTY1:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]])
+// CHECK: %[[D2:.+]] = tensor.dim %[[EMPTY0]], %[[C2]]
+// CHECK-DAG: %[[D0_1:.+]] = tensor.dim %[[EMPTY0]], %[[C1]]
+// CHECK-DAG: %[[D1_1:.+]] = tensor.dim %[[EMPTY1]], %[[C1]]
+// CHECK-DAG: %[[SUM:.+]] = affine.apply #[[MAP]]()[%[[D0_1]], %[[D1_1]]]
+// CHECK: %[[NEW_EMPTY:.+]] = tensor.empty(%[[SUM]], %[[D2]])
+// CHECK: return %[[NEW_EMPTY]]
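
For context, a hedged sketch of how the patterns added here are typically driven outside of the transform-dialect test above: the pattern set populated by `populateFoldTensorEmptyPatterns` is handed to the greedy pattern rewrite driver. The helper function below is an illustrative assumption, not code from this PR.

```cpp
// Illustrative sketch (hypothetical helper, not part of this PR): drive the
// tensor.empty folding patterns, including the new FoldConcatsOfEmpty, with
// MLIR's greedy pattern rewrite driver.
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult foldTensorEmptyOps(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // foldSingleUseOnly=false also folds tensor.empty ops with multiple uses.
  tensor::populateFoldTensorEmptyPatterns(patterns,
                                          /*foldSingleUseOnly=*/false);
  // Apply the patterns to a fixed point over the given root operation.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```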
LGTM with one question
auto firstEmptyOp = concatOperands.front().getDefiningOp<tensor::EmptyOp>();
if (!firstEmptyOp) {
  return failure();
}
auto isDefinedByEmptyOp = [](Value v) -> bool {
  return v.getDefiningOp<tensor::EmptyOp>();
};
if (!llvm::all_of(concatOperands.drop_front(), isDefinedByEmptyOp)) {
  return rewriter.notifyMatchFailure(
      concatOp, "not all operands are defined by an empty op");
}
Why is the first operand special? Can't we just write something like below?
if (!llvm::all_of(concatOperands, isDefinedByEmptyOp)) {
....
}
Just adding an early exit if the first operand itself is not an empty op.
I thought that llvm::all_of should already cover that case? I imagine it also early-exits if the first operand is not an empty op. It could be different with multi-threading though, I don't know.
+1 I don't quite understand the code here: llvm::all_of should already have the early exit.
I was just trying to avoid doing the all_of at all if even the first operand is not empty. But ok, I'll change it.
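
For reference on the point being discussed: `llvm::all_of` forwards to `std::all_of`, which evaluates the predicate left to right and returns as soon as one element fails, so the separate first-operand check is only a shortcut that avoids entering the loop at all. A small standalone sketch (not from this PR) demonstrating the short-circuit:

```cpp
#include "llvm/ADT/STLExtras.h"
#include <cassert>
#include <vector>

int main() {
  std::vector<int> values = {1, 2, 3, 4};
  int calls = 0;
  bool allEven = llvm::all_of(values, [&](int v) {
    ++calls;
    return v % 2 == 0;
  });
  // The first element (1) already fails the predicate, so the lambda runs
  // exactly once and the remaining elements are never inspected.
  assert(!allEven && calls == 1);
  (void)allEven;
  (void)calls;
  return 0;
}
```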