[mlir][scf] Extend fuse producer to multi-level candidates case #97803
base: main
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf

Author: None (Yun-Fly)

Changes: Although a producer can already be fused within a nested loop structure at present, it did not work for the multi-level candidate extractSliceOp case. If we want to fuse the producer fill into %extracted_slice_4, the current status in master breaks at getUntiledProducerFromSliceSource due to the upper candidate slice %extracted_slice_1. This patch extends producer fusion to multi-level extractSliceOp cases, using multiple applications of the existing tiling interface, the same as the counterpart consumer-fusion PR.

Full diff: https://github.com/llvm/llvm-project/pull/97803.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index d68ca11207376..36888c3d6d607 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -136,11 +136,19 @@ struct SCFFuseProducerOfSliceResult {
Value tiledAndFusedProducer; // Tile and fused producer value.
SmallVector<Operation *> tiledOps;
};
+std::optional<SCFFuseProducerOfSliceResult>
+tileAndFuseProducerOfSliceImpl(RewriterBase &rewriter,
+ tensor::ExtractSliceOp candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops);
+
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
tensor::ExtractSliceOp candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops);
+std::optional<SCFFuseProducerOfSliceResult>
+tileAndFuseProducerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+
/// Reconstruct the fused producer from within the tiled-and-fused code. Based
/// on the slice of the producer computed in place it is possible that within
/// the loop nest same slice of the producer is computed multiple times. It is
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a1392813d6de3..d6baa69618dae 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -835,7 +835,7 @@ getUntiledProducerFromSliceSource(OpOperand *source,
/// Implementation of fusing producer of a single slice by computing the
/// slice of the producer in-place.
std::optional<scf::SCFFuseProducerOfSliceResult>
-mlir::scf::tileAndFuseProducerOfSlice(
+mlir::scf::tileAndFuseProducerOfSliceImpl(
RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops) {
// 1. Get the producer of the source (potentially walking through
@@ -949,6 +949,115 @@ mlir::scf::tileAndFuseProducerOfSlice(
tileAndFuseResult->tiledOps};
}
+/// Get the root source of the target ExtractSliceOp:
+/// %0 =
+/// %1 = scf.for(%arg1 = %0)
+/// %2 = extract %arg1
+/// %3 = scf.for(%arg2 = %2)
+/// %4 = extract %arg2
+/// ...
+/// @param targetSliceOp: %4 = extract %arg2
+/// @param extractSliceOpChain: chain of all related extract_slice ops
+/// @return Value of the root source: %0
+static FailureOr<Value> getRootSourceOfExtractSliceOp(
+ Operation *targetSliceOp,
+ SmallVectorImpl<tensor::ExtractSliceOp> &extractSliceOpChain,
+ int curDepth = 0, int maxDepth = 5) {
+ assert(isa<tensor::ExtractSliceOp>(targetSliceOp));
+ // Limit recursion depth to avoid stack overflow.
+ if (curDepth > maxDepth)
+ return failure();
+
+ auto extractOp = cast<tensor::ExtractSliceOp>(targetSliceOp);
+ extractSliceOpChain.push_back(extractOp);
+ Value rootSource = extractOp.getSourceMutable().get();
+
+ while (true) {
+ if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
+ if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
+ iterArg.getOwner()->getParentOp())) {
+ rootSource = outerLoop.getTiedLoopInit(iterArg)->get();
+ continue;
+ }
+ return failure();
+ } else if (auto sliceOp =
+ rootSource.getDefiningOp<tensor::ExtractSliceOp>()) {
+ // walk up loop to find larger candidate extractSliceOp
+ return getRootSourceOfExtractSliceOp(sliceOp, extractSliceOpChain,
+ curDepth + 1);
+ }
+ break;
+ }
+ return rootSource;
+}
+
+/// Return the outer loops of the given loop (inclusive) up to the loop at
+/// which the predicate succeeds (exclusive), sorted from outer to inner.
+static SmallVector<LoopLikeOpInterface>
+getOuterLoopsUntil(LoopLikeOpInterface loop,
+ std::function<LogicalResult(LoopLikeOpInterface)> pred) {
+ SmallVector<LoopLikeOpInterface> outerLoops = {loop};
+ auto forOp = loop->getParentOfType<LoopLikeOpInterface>();
+ while (forOp) {
+ if (succeeded(pred(forOp)))
+ break;
+ outerLoops.push_back(forOp);
+ forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+ }
+ return {outerLoops.rbegin(), outerLoops.rend()};
+}
+
+/// Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with
+/// multi-level `extractSliceOp`. E.g.
+///
+/// %0 = untiled_producer
+/// %1 = scf.for(%arg1 = %0)
+/// %2 = extract %arg1
+/// %3 = scf.for(%arg2 = %2)
+/// %4 = extract %arg2
+/// %5 = tiled_consumer ins(%4)
+std::optional<scf::SCFFuseProducerOfSliceResult>
+mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+ Operation *candidateSliceOp) {
+ SmallVector<tensor::ExtractSliceOp> sliceOpChain;
+ if (failed(getRootSourceOfExtractSliceOp(candidateSliceOp, sliceOpChain))) {
+ return std::nullopt;
+ }
+
+ std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
+ // reverse from outer to inner
+ std::reverse(sliceOpChain.begin(), sliceOpChain.end());
+ // multiple application of `tileAndFuseProducerOfSliceImpl`
+ for (auto &&[index, sliceOp] : llvm::enumerate(sliceOpChain)) {
+ Operation *upperSliceOp = index ? sliceOpChain[index - 1] : nullptr;
+ auto untilUpperSliceFound =
+ [&upperSliceOp](LoopLikeOpInterface loop) -> LogicalResult {
+ if (upperSliceOp) {
+ Block &body = loop->getRegion(0).front();
+ if (upperSliceOp->getBlock() == &body)
+ return success();
+ }
+ return failure();
+ };
+ SmallVector<LoopLikeOpInterface> outerLoops = getOuterLoopsUntil(
+ sliceOp->getParentOfType<LoopLikeOpInterface>(), untilUpperSliceFound);
+ fuseProducerResult =
+ tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, outerLoops);
+ if (!fuseProducerResult) {
+ return std::nullopt;
+ }
+ }
+ return fuseProducerResult;
+}
+
+/// To be compatible with previous behavior
+std::optional<scf::SCFFuseProducerOfSliceResult>
+mlir::scf::tileAndFuseProducerOfSlice(
+ RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops) {
+ return tileAndFuseProducerOfSliceImpl(rewriter, candidateSliceOp, loops);
+}
+
/// Reconstruct the fused producer from within the tiled-and-fused code.
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir
new file mode 100644
index 0000000000000..ef1c6952a55e1
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 * 128)>
+module {
+ func.func @gemm_fill_fusion_multi_level_extract_slice(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+ %c0 = arith.constant 0 : index
+ %c64 = arith.constant 64 : index
+ %c128 = arith.constant 128 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dest0 = tensor.empty() : tensor<256x256xf32>
+ %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+ %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
+ %iv0 = affine.apply #map(%arg3)
+ %iv1 = affine.apply #map(%arg4)
+ %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+ %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+ %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+ %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
+ %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
+ %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+ %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+ %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+ %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+ %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
+ scf.yield %insert_slice : tensor<128x128xf32>
+ }
+ scf.yield %3 : tensor<128x128xf32>
+ }
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+ }
+ }
+ return %1 : tensor<256x256xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %yield = transform.get_producer_of_operand %matmul[2]
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_producer %yield
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)>
+// CHECK: func.func @gemm_fill_fusion_multi_level_extract_slice(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+// CHECK: %[[FORALL_RESULT:.*]] = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME: shared_outs(%[[INIT_ARG0:.*]] = %[[dest0]])
+// CHECK-SAME: {
+// CHECK: %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
+// CHECK: %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
+// CHECK: %[[FILL_OUT_SLICE0:.*]] = tensor.extract_slice %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+// CHECK: %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
+// CHECK: %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
+// CHECK: %[[LOOP_RESULT1:.*]] = scf.for %[[IV3:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[INIT_ARG1:.*]] = %[[FILL_OUT_SLICE0]])
+// CHECK-SAME: {
+// CHECK: %[[LOOP_RESULT2:.*]] = scf.for %[[IV4:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[INIT_ARG2:.*]] = %[[INIT_ARG1]])
+// CHECK-SAME: {
+// CHECK: %[[FILL_OUT_SLICE1:.*]] = tensor.extract_slice %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+// CHECK: %[[TILED_FILL_OUT:.*]] = linalg.fill
+// CHECK-SAME: outs(%[[FILL_OUT_SLICE1]] :
+// CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
+// CHECK: %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
+// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME: outs(%[[TILED_FILL_OUT]] :
+// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+// CHECK: scf.yield %[[INSERT_MAT]] :
+// CHECK: }
+// CHECK: scf.yield %[[LOOP_RESULT2]] :
+// CHECK: }
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]] into %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FORALL_RESULT]] :
\ No newline at end of file
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 8f206d9077272..8d3d5b23a53e9 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -160,6 +160,56 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
: DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// TestFuseProducerOp
+//===----------------------------------------------------------------------===//
+
+/// Apply fusing of producer transformation to all payload ops and store both
+/// the original producer operation as well as the fused producer operation.
+template <typename Range>
+static LogicalResult
+applyFuseProducer(RewriterBase &rewriter, Operation *transformOp,
+ Range &&payloadOps, TransformResults &transformResults) {
+ SmallVector<Operation *> originalProducerOps;
+ SmallVector<Operation *> fusedProducerOps;
+
+ for (Operation *target : payloadOps) {
+ rewriter.setInsertionPoint(target);
+
+ std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResults =
+ scf::tileAndFuseProducerOfSlice(rewriter, target);
+
+ if (!fuseProducerResults)
+ return failure();
+
+ // Report back the relevant handles to the transform op.
+ originalProducerOps.push_back(fuseProducerResults->origProducer.getOwner());
+ fusedProducerOps.push_back(fuseProducerResults->tiledOps[0]);
+ }
+
+ transformResults.set(transformOp->getOpResult(0), originalProducerOps);
+ transformResults.set(transformOp->getOpResult(1), fusedProducerOps);
+ return success();
+}
+
+DiagnosedSilenceableFailure
+transform::TestFuseProducerOp::apply(TransformRewriter &rewriter,
+ TransformResults &transformResults,
+ TransformState &state) {
+ LogicalResult result =
+ applyFuseProducer(rewriter, getOperation(),
+ state.getPayloadOps(getTarget()), transformResults);
+ return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+ : DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestFuseProducerOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
//===----------------------------------------------------------------------===//
// TestFuseConsumerOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index d55d746bd6aa9..6e73478c35c4a 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -49,6 +49,25 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
}];
}
+def TestFuseProducerOp : Op<Transform_Dialect, "test.fuse_producer",
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ Fuses the producer of the operation pointed to by the target handle
+ using the options provided as attributes.
+ }];
+
+ let arguments =
+ (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$producer,
+ TransformHandleTypeInterface:$fused_producer);
+
+ let assemblyFormat = [{
+ $target attr-dict `:` functional-type(operands, results)
+ }];
+}
+
def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
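The driver logic added above in TileUsingInterface.cpp can be summarized as: walk the chain of candidate extract_slice ops from the innermost one up to the root source, then apply the existing single-level fusion once per chain element, from outermost to innermost. Below is a minimal Python sketch of that control flow over a toy model; Slice, collect_slice_chain, and fuse_producer_multilevel are illustrative stand-ins, not MLIR APIs.

```python
# Toy model of the slice chain: each Slice reads either another Slice
# (possibly through a loop iter_arg, which we abstract away here) or the
# root value "root", i.e. the untiled producer's result.
class Slice:
    def __init__(self, source, depth):
        self.source = source  # parent Slice, or "root"
        self.depth = depth    # loop-nest level, 1 = outermost

def collect_slice_chain(candidate, max_depth=5):
    """Analogue of getRootSourceOfExtractSliceOp: gather the chain from
    the innermost candidate up to the root, bailing out past max_depth."""
    chain = []
    cur = candidate
    while isinstance(cur, Slice):
        if len(chain) > max_depth:
            return None  # limit depth instead of walking forever
        chain.append(cur)
        cur = cur.source
    return chain  # ordered innermost-first

def fuse_producer_multilevel(candidate):
    """Analogue of the new tileAndFuseProducerOfSlice overload: apply
    single-level fusion once per candidate, outermost first."""
    chain = collect_slice_chain(candidate)
    if chain is None:
        return None
    fused = []
    for slice_op in reversed(chain):  # reverse: outer to inner
        # stand-in for tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, loops)
        fused.append(slice_op.depth)
    return fused

# Two-level example: %2 = extract %arg1 (level 1), %4 = extract %arg2 (level 2).
outer = Slice("root", depth=1)
inner = Slice(outer, depth=2)
print(fuse_producer_multilevel(inner))  # [1, 2]: fuse at level 1, then level 2
```

If any level fails to fuse, the real implementation returns std::nullopt, mirrored here by returning None when the chain walk exceeds its depth bound.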
Thanks! Please give me some time to review. Broadly, my concern is how you control which loop nest to fuse into when you are doing multiple tiling levels... I can see it is being controlled by API knobs, but those seem fragile. I need to look at it in more detail, though.
Good question! As planned, this PR just focuses on the functionality extension. In other words, the user is responsible for deciding which loop nest to fuse into by feeding a specific candidate slice to the API. BTW, as introduced in this early RFC, we intend to design a higher-level utility or API to control which loop nest to fuse into, based on a semantic check (validity) and a cost model (performance), both of which are expected to be designed as callback options with default behavior. I can give you more examples with details.

NOTE: fusing a consumer has the same concern as fusing a producer. As a result, I prefer to use another PR to deal specifically with which loops to fuse in multi-level tiling, once both this patch and #94190 are merged. Looking forward to hearing your advice!
This seems related to the other PR (#94190). Let's come back to this after we have some idea of that. Or, if you want to try to land this first, we can switch the order around.
I have no specific request about the priority order so far. It's up to you, or whichever one is closer to landing in your mind. Hopefully both of them will be merged ASAP, because another coming PR may depend on them.
Ok, let's do the producer fusion first, because that is easier to navigate than the consumer fusion. (I know you say ASAP, but this is involved code and reviews take time. I'll do my best to keep it timely, but I am juggling this with other work that is a bit higher priority for my employer.)
This PR is much easier to review... I'll ask a few more questions to help navigate this.
Trying to put all of this logic upstream is hard. This seems like a very special case where you already have someone partially tiling the code, and you are trying to reuse some logic from tile-and-fuse for this case. But there is too much of an assumption being made that the loop structure you create is consistent with what tiling creates. In fact, that assumption is already made by the implementation; that is why the current mlir::tileAndFuseProducerOfSlice takes the loop nest generated by tiling as its MutableArrayRef<LoopLikeOpInterface> loops argument.

I'd like to understand if there is a way you can keep the logic of finding loops closer to your context, and then consider whether we need to change how the producer is found by walking the loop nest and the extract_slice source definition chain. Does that question make sense?
The loop structure created in this PR is exactly what multi-level tiling creates. So it looks quite consistent with what tiling creates, but it differs in having multiple levels of nested loops and candidate slices, for better utilization of parallelism and the memory hierarchy. I will detail more of our context at the end.

Yes, I know, and that is one of the reasons why this PR is easier to review than the other patch for fusing a consumer.

Our context is that some developers want to customize the way a certain op is tiled, instead of directly using the general utilities.
I would like to express my sincere gratitude for your involvement!
/// @param candidateSliceOp: %4 = extract %args2
/// @param backwardSlice: in-out parameter populated by backward extractSliceOps
/// @return OpResult Producer : %0 = producer
static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
Instead of doing this, would it help if you just folded the extract_slices using something like this pattern?
I know that we can merge consecutive extractSlice ops. But, I am sorry, we have to keep these consecutive extractSlice ops for selection. The most intuitive example is as below:
%0 = producer
%1 = scf.for(%arg1 = %0)
%2 = extract %arg1
%3 = scf.for(%arg2 = %2)
%4 = extract %args2
In some cases, the producer can only be fused into the outer candidate slice due to the semantics of the producer; i.e., ops like reduce/pack/unpack have specific demands on tiling sizes. If we merge the two consecutive extractSlice ops, that may break this fusion. And that is why we emphasize multi-level candidates here.
That is OK, but the point here is that to fuse with %2 you don't need to look through extract_slices. I'm just not sure why there would be a case where you need to look through extract_slices to get to the producer when you can't collapse all the extract slices to begin with.
That is just a simple test. If we have one more candidate slice, say:
%0 = producer
%1 = scf.for(%arg1 = %0)
%2 = extract %arg1
%3 = scf.for(%arg2 = %2)
%4 = extract %args2
%5 = scf.for(%arg3 = %4)
%6 = extract %args3
tiled_consumer ins(%6)
where %2 and %4 are both valid candidates for fusing the producer. And the user may just want to manually pass %4 to tileAndFuseProducerOfSlice, because %4 usually has a smaller tile size in most cases. Then the current getUntiledProducerFromSliceSource will fail to get the real producer without looking through the extract_slices.
Another use case is when someone wants to enable automatic fusion starting from the tiled_consumer, just like tileConsumerAndFuseProducersUsingSCF does. The only extract slice we can find via the operand of the tiled_consumer is the nearest candidate %6. Then we also need to look through all the extract_slices to get the real producer.

Meanwhile, we don't know which candidate slice can be fused until we know what the exact producer is. Thus, it is unreasonable to collapse all the extract slices before fusion.
I am not sure this reasoning holds. If you want to fuse %4 and not %2, you are effectively "fusing" the slices and then doing the producer fusion. It would be better to stage it so that you first fuse the slices, so that you get to the producer directly, and then do the fusion. So effectively this change is trying to do both of these in code:
- combine extract slices
- do producer fusion
That is just added complexity. It would be easier to just combine the extract slices first and then fuse the producer.
%0 = producerOp
scf.for(%1=%0)
%2 = extract_slice %1
scf.for(%3=%2)
%4 = extract_slice %3
Besides, I am worried that merging two candidates interleaved by an scf.for, with its inits involved, is not trivial...
Hi, @MaheshRavishankar. Since PR #108318 has been broken down into three parts:
- support single nested scf.for (merged, great thanks for your review!)
- multi-level candidates
- multiple users

I think the second part is similar to this patch, whether for producer or consumer fusion. In general, this patch focuses on how to deal with consecutive candidates interleaved by an scf.for.
Let's look back at what major arguments remain so far, to speed up the next stage of review:
- why do we need to look through the whole chain of candidate slices?
- why not merge two consecutive candidates?
The answers are listed in the several threads above. The root cause behind them is two different solutions for the multi-level candidates case:
- merge consecutive candidates in advance.
- iteratively fuse the producer/consumer with one candidate at a time, step by step (from outer to inner).

IMO, the biggest concern about the first solution is that it may introduce too many additional transforms regarding the scf.for. E.g.:
%0 = producerOp
scf.for(%1=%0)
%2 = extract_slice %1
scf.for(%3=%2)
%4 = extract_slice %3
...
yield %x
As I explained above, this scenario is a little different from what MergeConsecutiveExtractSlice expects, due to the scf.for. If we just merge the consecutive extract slices (BTW, here we also need to look through and penetrate the scf.for), the possible resulting IR looks like below:
%0 = producerOp
scf.for(%1=%0)
%2 = extract_slice %1
scf.for(%3=%2)
%4 = extract_slice %1 [newOffsets] [newSizes]
...
yield %x
Then the problem is how we modify the new inits of the scf.for with the latest dpsInit of the fused producerOp. (getUntiledProducerFromSliceSource will fail to find the destinationIterArg.)

On the other hand, if we modify %3=%2 at the same time, the resulting IR would become:
%0 = producerOp
scf.for(%1=%0)
%2 = extract_slice %1
scf.for(%3=%1)
%4 = extract_slice %3 [newOffsets] [newSizes]
...
yield %x_new
First of all, this transform has already introduced more code changes, and there is an assumption that %3 has only a single user. Moreover, considering the semantics of scf.for, it seems that we also need to modify the yield value %x (and even its producer, like insert_slice) to make the result of the scf.for match its inits.

In summary, I am afraid that this kind of complexity is much more than that of the second solution, which is based on a simple iterative process reusing the existing function.
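For intuition, the "combining" of offsets that the merge-first solution needs is itself simple (for unit strides, offsets add and the inner sizes win); the complexity the author points to is in the surrounding loop surgery, not in this arithmetic. A hypothetical sketch for rank-2 slices:

```python
def compose_slices(outer, inner):
    """Compose two rank-2 slices applied in sequence, assuming unit strides:
    the inner offsets are relative to the outer slice, so offsets add,
    and the composed slice keeps the inner sizes."""
    (outer_offsets, _outer_sizes), (inner_offsets, inner_sizes) = outer, inner
    offsets = tuple(o + i for o, i in zip(outer_offsets, inner_offsets))
    return offsets, inner_sizes

# extract %1[128, 0][128, 128] followed by extract %3[64, 64][64, 64]
# reads the same data as one slice of %1's source at [192, 64][64, 64].
offsets, sizes = compose_slices(((128, 0), (128, 128)), ((64, 64), (64, 64)))
print(offsets, sizes)  # (192, 64) (64, 64)
```

What this sketch deliberately leaves out is exactly the hard part discussed above: once the inner extract_slice reads %1 directly, the inner scf.for's inits and yields no longer match, and that loop rewriting is the state the merge-first approach would have to carry.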
Thanks for the explanation. I see the issues a bit more clearly now. I haven't reviewed the code yet, and I don't know if this has been rebased on top of what was submitted.

But here is the complexity that I am concerned about. In your example, you need to keep track, somewhere, of the sequence of extract slices that you need to walk to get the producer, because the actual offset and size you need are obtained by "combining" the offsets and sizes of all the slices to get the "real" offset and size. Also, I am not convinced that fusing the first extract slice + producer and then doing the second extract slice + producer is not feasible. That should always be possible.
So if you start with this
%0 = producerOp
scf.for(%1=%0)
%2 = extract_slice %1
scf.for(%3=%2)
%4 = extract_slice %3
...
yield %x
you should always be able to do
scf.for(%1=%0)
%2 = tiled producer 1
scf.for(%3=%2)
%4 = extract_slice %3
...
yield %x
and then you do
scf.for(%1=%0)
scf.for(%3=%2)
%4 = tiled producer 2
...
yield %x
The amount of state you need to carry during the transformation to fuse with one extract slice is prohibitively high, IMO. That will make it hard to change/fix the transformation.
> I don't know if this has been rebased on top of what was submitted.

Not rebased yet, because I think it is more important to reach agreement on the solution before pushing...

> you need to somewhere keep track of the sequence of extract slices that you need to walk to get the producer because the actual offset and size you need is obtained by "combining" the offsets and sizes of all the slices to get the "real offset" and size.

Yes, viewed from the end result, the real offset is indeed combined. However, there is something else we need to address, as I explained above regarding the loop transform: each time two extract_slices (interleaved by an scf.for) are merged, that scf.for (and its yield value, etc.) needs to be transformed. Is the amount of state needed to carry that out similar to what I do in the iterative fashion?
Also I am not convinced that fusing the first extract slice + producer and then doing the second extract slice + producer is not feasible. That should be always possible.
It is decided by concrete semantic of tilable op, such as reduce
/pack
/unpack
op. Lets say fusing producer unpack
into following loop:
// unpack tensor from ABab to AB
%1 = tensor.unpack ... inner_tiles = [32, 32] ... : tensor<2x2x32x32xbf16> -> tensor<64x64xbf16>
scf.for(0, 2) { // tile A dimension
extract_slice1 // tileSize<32, 64>
scf.for(0, 4) { // tile B dimension
extract_slice2 // tileSize<32, 16>
...
}
}
As you can see, the tile size that comes from extract_slice2 is <32,16>, but tensor.unpack prefers the perfect-tiling case, i.e. the tile size should be exactly divisible by inner_tiles. So it may not be feasible in this case.
BTW, fusing the consumer reduce may be more intuitive:
%0 = scf.for(0, 2) { // tile 128 dimension
scf.for(0, 4) { // tile B dimension
...
insert_slice1 // tileSize<64, 64>
}
insert_slice1 // tileSize<64, 256>
}
%2 = linalg.reduce { arith.addf } ins(%0 : tensor<128x256xf32>) outs(%1: tensor<128xf32>) dimensions = [1]
We could not further fuse reduce into the inner-most insert slice; otherwise, it would lead to a partial reduce (that's another topic).
> So if you start with this ... you should always be able to do ... and then you do

Yes, that is exactly what I do now. IIRC, that is also what you suggested before in the nested consumer fusion thread... (In fact, my first version of the nested consumer fusion implementation just combined the offsets and sizes of several candidate slices; however, it was rejected because of too many code changes...)
> That will make it hard to change/fix the transformation.

The current iterative method breaks the whole transformation down into several repeated applications of the existing tileAndFuseProducerOfSlice logic, which should be much easier to debug, at least from my view.
Ping
Left a few more comments. Thanks for being patient. I am a bit busy this week, but I am freer next week. If you are on Discord (the IREE Discord or the MLIR Discord), we can connect there, and I am happy to work with you at a faster cadence to help you land what you need upstream, and maybe suggest things you can keep downstream since they are valid only for your context.
Yes, surely. I just pinged you there. Feel free to have a talk when you have time.
I don't think I got the ping.
I don't know if this commit is still relevant. I was looking for a commit where things were broken up a bit more. If this commit is still relevant, I responded to the outstanding comment here.
Although a producer can already be fused within a nested loop structure at present, it did not work for the multi-level candidate extractSliceOp case. E.g., if we want to fuse the producer fill into %extracted_slice_4, the current status in master breaks at getUntiledProducerFromSliceSource due to the upper candidate slice %extracted_slice_1. This PR extends producer fusion to multi-level extractSliceOp cases.

This patch uses multiple applications of the existing tiling interface, the same as the counterpart consumer-fusion PR.