[mlir][scf] Extend fuse producer to multi-level candidates case #97803

Open · wants to merge 2 commits into main
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -157,11 +157,15 @@ struct SCFFuseProducerOfSliceResult {
Value tiledAndFusedProducer; // Tiled and fused producer value.
SmallVector<Operation *> tiledOps;
};

std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
tensor::ExtractSliceOp candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops);

std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);

/// Reconstruct the fused producer from within the tiled-and-fused code. Based
/// on the slice of the producer computed in place it is possible that within
/// the loop nest same slice of the producer is computed multiple times. It is
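A minimal caller sketch of the new single-argument overload (illustrative, not part of the patch); it assumes the usual MLIR namespaces and a previously located candidate slice:

```
// Sketch: given a located candidate tensor.extract_slice, fuse its untiled
// producer through the (possibly multi-level) chain of candidate slices.
static void fuseProducerAtCandidate(RewriterBase &rewriter,
                                    Operation *candidateSliceOp) {
  std::optional<scf::SCFFuseProducerOfSliceResult> fused =
      scf::tileAndFuseProducerOfSlice(rewriter, candidateSliceOp);
  if (!fused)
    return;
  // As in the test harness below, tiledOps[0] is the producer tiled at the
  // last (innermost) fused candidate.
  Operation *tiledProducer = fused->tiledOps[0];
  (void)tiledProducer;
}
```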
168 changes: 163 additions & 5 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1068,12 +1068,12 @@ getUntiledProducerFromSliceSource(OpOperand *source,
return {dyn_cast<OpResult>(source->get()), destinationIterArg};
}

/// Implementation of fusing producer of a single slice by computing the
/// Basic implementation of fusing producer of a single slice by computing the
/// slice of the producer in-place.
std::optional<scf::SCFFuseProducerOfSliceResult>
mlir::scf::tileAndFuseProducerOfSlice(
RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops) {
static std::optional<scf::SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSliceImpl(RewriterBase &rewriter,
tensor::ExtractSliceOp candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops) {
// 1. Get the producer of the source (potentially walking through
// `iter_args` of nested `scf.for`)
auto [fusableProducer, destinationInitArg] =
@@ -1185,6 +1185,164 @@ mlir::scf::tileAndFuseProducerOfSlice(
tileAndFuseResult->tiledOps};
}

/// Get the real producer from a candidate ExtractSliceOp.
///
/// ```
/// %0 = producer
/// %1 = scf.for(%arg1 = %0)
///   %2 = extract %arg1
///   %3 = scf.for(%arg2 = %2)
///     %4 = extract %arg2
/// ...
/// ```
///
/// @param candidateSliceOp: %4 = extract %arg2
/// @param backwardSlice: in-out parameter populated with the backward chain of
/// extractSliceOps
/// @return the OpResult of the real producer: %0 = producer
static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
Contributor:

Instead of doing this, would it help if you just folded the extract_slices using something like this pattern?

Contributor Author:

I know that we can merge consecutive extract_slices. However, we have to keep these consecutive extract_slices around for selection. The most intuitive example is below:

%0 = producer
%1 = scf.for(%arg1 = %0)
   %2 = extract %arg1
   %3 = scf.for(%arg2 = %2)
      %4 = extract %arg2

In some cases, the producer can only be fused into the outer candidate slice due to the producer's semantics; i.e., ops like reduce/pack/unpack have specific demands on tile sizes. If we merge the two consecutive extract_slices, it may break this fusion. That is why we emphasize multi-level candidates here.

Contributor:

That is OK, but the point here is that to fuse with %2 you don't need to look through extract_slices... I am just not sure why there would be a case where you need to look through extract_slices to get to the producer, rather than collapsing all the extract slices to begin with.

Contributor Author:

That is just a simple test. If we have one more candidate slice, say:

%0 = producer
%1 = scf.for(%arg1 = %0)
   %2 = extract %arg1
   %3 = scf.for(%arg2 = %2)
      %4 = extract %arg2
      %5 = scf.for(%arg3 = %4)
          %6 = extract %arg3
          tiled_consumer ins(%6)

where %2 and %4 are both valid candidates for fusing the producer. The user may want to manually pass %4 to tileAndFuseProducerOfSlice because %4 usually has the smaller tile size. Then the current getUntiledProducerFromSliceSource will fail to find the real producer without looking through extract_slices.

Another use case is enabling automatic fusion starting from the tiled_consumer, just like tileConsumerAndFuseProducersUsingSCF does. The only extract slice we can find via the operand of tiled_consumer is the nearest candidate, %6. So we also need to look through all the extract_slices to get the real producer.

Meanwhile, we don't know which candidate slice can be fused until we know what the exact producer is. Thus, it is unreasonable to collapse all the extract slices before fusion.

Contributor:

I am not sure this reasoning holds. If you want to fuse %4 and not %2, you are effectively "fusing" the slices and then doing the producer fusion. It would be better to stage it: first fuse the slices so that you get the producer directly, then do the fusion. So effectively this change is trying to do both of these in code:

  1. Combine extract slices.
  2. Do producer fusion.

That is just added complexity. It would be easier to just combine the extract slices first and then fuse the producer.

Contributor Author:

%0 = producerOp
scf.for(%1=%0)
   %2 = extract_slice %1
   scf.for(%3=%2)
      %4 = extract_slice %3

Besides, I am worried that merging two candidates interleaved by an scf.for, with its inits involved, is not trivial...

Contributor Author:

Hi, @MaheshRavishankar. Since PR #108318 has been broken down into three parts:

  1. support single nested scf.for (merged, great thanks for your review!)
  2. multi-level candidates
  3. multiple users

I think the second part is similar to this patch, whether for producer or consumer fusion. In general, this patch focuses on how to deal with consecutive candidates interleaved by an scf.for.

Let's look back at the major arguments remaining so far, to speed up our next stage of review:

  1. Why do we need to look through the whole chain of candidate slices?
  2. Why not merge two consecutive candidates?

The answers are listed in the several threads above. The root cause behind them is that there are two different solutions to the multi-level candidates case:

  1. Merge consecutive candidates in advance.
  2. Iteratively fuse the producer/consumer with one candidate at a time, step by step (from outer to inner).

IMO, the main concern about the first solution is that it may introduce too many additional transforms on the scf.for. E.g.:

%0 = producerOp
scf.for(%1=%0)
   %2 = extract_slice %1
   scf.for(%3=%2)
      %4 = extract_slice %3
      ... 
      yield %x

As I explained above, this scenario is a little different from what MergeConsecutiveExtractSlice expects, due to the scf.for.

If we just merge the consecutive extract slices (BTW, here we would also need to look through the scf.for), the possible resulting IR looks like this:

%0 = producerOp
scf.for(%1=%0)
   %2 = extract_slice %1
   scf.for(%3=%2)
      %4 = extract_slice %1 [newOffsets] [newSizes]
      ... 
      yield %x

Then the problem is how we update the new inits of the scf.for with the latest dpsInit of the fused producerOp (getUntiledProducerFromSliceSource will fail to find destinationIterArg).

On the other hand, if we modify %3 = %2 at the same time, the resulting IR would become:

%0 = producerOp
scf.for(%1=%0)
   %2 = extract_slice %1
   scf.for(%3=%1)
      %4 = extract_slice %3 [newOffsets] [newSizes]
      ... 
      yield %x_new

First of all, this transform already introduces more code changes, and it assumes that %3 has only a single user. Moreover, considering the semantics of scf.for, it seems we would also need to modify the yielded value %x (and even its producer, like an insert_slice) so that the result of the scf.for matches its inits.

In summary, I am afraid that this kind of complexity is much greater than that of the second solution, which is a simple iterative process reusing the existing function.

Contributor:

Thanks for the explanation. I see the issues a bit more clearly now. I haven't reviewed the code yet, and I don't know if this has been rebased on top of what was submitted.
But here is the complexity that I am concerned about. In your example, you need to keep track somewhere of the sequence of extract slices that you need to walk to get to the producer, because the actual offset and size you need are obtained by "combining" the offsets and sizes of all the slices to get the "real" offset and size. Also, I am not convinced that fusing the first extract slice + producer and then the second extract slice + producer is not feasible. That should always be possible.

So if you start with this

%0 = producerOp
scf.for(%1=%0)
   %2 = extract_slice %1
   scf.for(%3=%2)
      %4 = extract_slice %3
      ... 
      yield %x

you should always be able to do

scf.for(%1=%0)
   %2 = tiled producer 1
   scf.for(%3=%2)
      %4 = extract_slice %3
      ... 
      yield %x

and then you do

scf.for(%1=%0)
   scf.for(%3=%2)
      %4 = tiled producer 2
      ... 
      yield %x

The amount of state you need to carry during the transformation to fuse with one extract slice is prohibitively high, IMO. That will make it hard to change/fix the transformation.

Contributor Author (@Yun-Fly, Sep 19, 2024):

> I don't know if this has been rebased on top of what was submitted.

Not rebased yet, because I think it is more important to reach agreement on the solution before pushing...

> You need to keep track somewhere of the sequence of extract slices that you need to walk to get to the producer, because the actual offset and size you need are obtained by "combining" the offsets and sizes of all the slices to get the "real" offset and size.

Yes, judging purely by the end result, the real offset is indeed combined. However, there is something else we need to address, as I explained above regarding the loop transform: each time two extract_slices (interleaved by an scf.for) are merged, that scf.for (and its yielded value, etc.) needs to be transformed. Isn't the amount of state to carry similar to what I do in the iterative fashion?

> Also, I am not convinced that fusing the first extract slice + producer and then the second extract slice + producer is not feasible. That should always be possible.

That is decided by the concrete semantics of the tileable op, such as reduce/pack/unpack ops. Let's say we fuse a producer unpack into the following loop:

// unpack tensor from ABab to AB
%1 = tensor.unpack ... inner_tiles = [32, 32] ... : tensor<2x2x32x32xbf16> -> tensor<64x64xbf16>
scf.for(0, 2) { // tile A dimension
 extract_slice1 // tileSize<32, 64>
 scf.for(0, 4) { // tile B dimension
    extract_slice2 // tileSize<32, 16>
    ...
 }
}

As you can see, the tile size coming from extract_slice2 is <32, 16>, but tensor.unpack prefers the perfect-tiling case, i.e. the tile size should be exactly divisible by inner_tiles, and 16 is not divisible by 32. So fusion may not be feasible in this case.

BTW, fusing a consumer reduce may be more intuitive:

%0 = scf.for(0, 2) { // tile 128 dimension
 scf.for(0, 4) { // tile B dimension
    ...
   insert_slice1 // tileSize<64, 64>
 }
 insert_slice2 // tileSize<64, 256>
}
%2 = linalg.reduce { arith.addf } ins(%0 : tensor<128x256xf32>) outs(%1: tensor<128xf32>) dimensions = [1]

We could not further fuse the reduce into the innermost insert slice; otherwise it would lead to a partial reduce (that is another topic).

> So if you start with this

> you should always be able to do

> and then you do

Yes, that is exactly what I do now. IIRC, that is also what you suggested before in the nested consumer fusion thread... (In fact, my first version of the nested consumer fusion implementation just combined the offsets and sizes of several candidate slices; however, it was rejected because of too many code changes...)

> That will make it hard to change/fix the transformation.

The current iterative method breaks the whole transformation down into several repeated applications of the existing tileAndFuseProducerOfSlice, which should be much easier to debug, at least from my view.

Contributor Author:

Ping

Operation *candidateSliceOp,
SmallVector<tensor::ExtractSliceOp> &backwardSlice, int curDepth = 0,
int maxDepth = 5) {
if (!isa<tensor::ExtractSliceOp>(candidateSliceOp))
return failure();
// Control recursion depth to avoid stack overflow.
if (curDepth > maxDepth)
return failure();

auto extractOp = cast<tensor::ExtractSliceOp>(candidateSliceOp);
backwardSlice.push_back(extractOp);
Value rootSource = extractOp.getSourceMutable().get();

while (true) {
if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
iterArg.getOwner()->getParentOp())) {
rootSource = outerLoop.getTiedLoopInit(iterArg)->get();
continue;
}
return failure();
} else if (auto sliceOp =
rootSource.getDefiningOp<tensor::ExtractSliceOp>()) {
// Walk up the loops to find larger candidate extractSliceOps.
return getRealProducerFromExtractSliceOp(sliceOp, backwardSlice,
curDepth + 1);
}
break;
}
return dyn_cast<OpResult>(rootSource);
}

/// Recursively find the outer nested loops of the given loop (inclusive) while
/// the predicate function succeeds; the result is sorted from outer to inner.
///
/// @param loop: target loop; note that this loop is also included, i.e. if no
/// other nested loops are found, the result is just the loop itself.
/// @param pred: predicate function, the termination condition of the recursive
/// process.
/// @return the nested loops surrounding the given target loop (inclusive).
///
/// E.g.
///
/// ```
/// %0 = scf.for()
///   %1 = scf.for()
///     %2 = scf.for()
/// ```
///
/// If `%2 = scf.for` is given without a specific predicate function, this
/// function will return three nested loops: %0 + %1 + %2.
static SmallVector<LoopLikeOpInterface>
getOuterNestLoopsWhile(LoopLikeOpInterface loop,
function_ref<LogicalResult(LoopLikeOpInterface)> pred) {
SmallVector<LoopLikeOpInterface> nestLoops = {loop};
auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
while (outerLoop && succeeded(pred(outerLoop))) {
nestLoops.push_back(outerLoop);
outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp());
}
// sorted from outer to inner
return {nestLoops.rbegin(), nestLoops.rend()};
}

/// Check whether the given ForOp yields the result of its inner loop.
static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) {
if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
Block::OpListType &opsInLoopBody = forOp.getBody()->getOperations();
for (auto &&[index, op] : llvm::enumerate(opsInLoopBody)) {
// If the inner loop is the second-to-last operation in the body, i.e. just
// before the yield of the ForOp, then the given loop must yield the result
// of the inner loop.
if (isa<LoopLikeOpInterface>(op)) {
return success((index + 2) == opsInLoopBody.size());
}
}
}
return failure();
}

/// Enhanced version of the basic implementation of producer fusion, which can
/// deal with multi-level candidates. E.g.
///
/// ```
/// %0 = untiled_producer
/// %1 = scf.for(%arg1 = %0)
///   %2 = tensor.extract_slice %arg1
///   %3 = scf.for(%arg2 = %2)
///     %4 = tensor.extract_slice %arg2
///     %5 = tiled_consumer ins(%4)
/// ```
///
/// This utility can fuse the untiled producer at `%4 = tensor.extract_slice`
/// within the inner loop `%3 = scf.for`.
std::optional<scf::SCFFuseProducerOfSliceResult>
mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
Operation *candidateSliceOp) {
SmallVector<tensor::ExtractSliceOp> backwardSlice;
FailureOr<OpResult> realProducer =
getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice);
if (failed(realProducer))
return std::nullopt;

std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
// Reverse so the candidate slices are ordered from outer to inner.
std::reverse(backwardSlice.begin(), backwardSlice.end());
// Apply `tileAndFuseProducerOfSliceImpl` once per candidate slice.
for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) {
// Get the nested loops between the next candidate sliceOp and the tiled
// producer.
auto whileProducerOutOfLoopBlock =
[&fuseProducerResult,
&realProducer](LoopLikeOpInterface loop) -> LogicalResult {
// Ensure that all surrounding outer loops just yield the result of the
// inner loops.
if (failed(isForOpYieldResultOfInnerLoop(loop)))
return failure();
Operation *originalOp =
fuseProducerResult
? fuseProducerResult->tiledAndFusedProducer.getDefiningOp()
: realProducer->getDefiningOp();
Block &body = loop->getRegion(0).front();
return success(originalOp->getBlock() != &body);
};
SmallVector<LoopLikeOpInterface> outerLoops =
getOuterNestLoopsWhile(sliceOp->getParentOfType<LoopLikeOpInterface>(),
whileProducerOutOfLoopBlock);
fuseProducerResult =
tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, outerLoops);
if (!fuseProducerResult) {
return std::nullopt;
}
}
return fuseProducerResult;
}

/// Implementation of fusing producer of a single slice by computing the
/// slice of the producer in-place.
std::optional<scf::SCFFuseProducerOfSliceResult>
mlir::scf::tileAndFuseProducerOfSlice(
RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops) {
return tileAndFuseProducerOfSliceImpl(rewriter, candidateSliceOp, loops);
}

/// Reconstruct the fused producer from within the tiled-and-fused code.
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
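For the doc-comment example above, a hand-written sketch (not compiler output) of the expected IR after fusion: the untiled producer is gone, the outer loop's init becomes the producer's destination, and a tiled producer is recomputed in place on the innermost slice. This is the shape the new test below checks for linalg.fill:

```
%1 = scf.for(%arg1 = %producer_dest)
  %2 = tensor.extract_slice %arg1
  %3 = scf.for(%arg2 = %2)
    %4 = tensor.extract_slice %arg2
    %t = tiled_producer outs(%4)
    %5 = tiled_consumer ins(%t)
```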
86 changes: 86 additions & 0 deletions mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir
@@ -0,0 +1,86 @@
// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s

#map = affine_map<(d0) -> (d0 * 128)>
module {
func.func @gemm_fill_fusion_multi_level_extract_slice(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%cst = arith.constant 0.000000e+00 : f32
%dest0 = tensor.empty() : tensor<256x256xf32>
%dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
%1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
%iv0 = affine.apply #map(%arg3)
%iv1 = affine.apply #map(%arg4)
%extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
%extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
%extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
%2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
%3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
%extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
%extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
%extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
%4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
%insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
scf.yield %insert_slice : tensor<128x128xf32>
}
scf.yield %3 : tensor<128x128xf32>
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
}
}
return %1 : tensor<256x256xf32>
}
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield = transform.get_producer_of_operand %matmul[2]
: (!transform.any_op) -> !transform.any_op
%a, %b = transform.test.fuse_producer %yield
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)>
// CHECK: func.func @gemm_fill_fusion_multi_level_extract_slice(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
// CHECK: %[[FORALL_RESULT:.*]] = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
// CHECK-SAME: shared_outs(%[[INIT_ARG0:.*]] = %[[dest0]])
// CHECK-SAME: {
// CHECK: %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
// CHECK: %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
// CHECK: %[[FILL_OUT_SLICE0:.*]] = tensor.extract_slice %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
// CHECK: %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
// CHECK: %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
// CHECK: %[[LOOP_RESULT1:.*]] = scf.for %[[IV3:.*]] = %[[C0]]
// CHECK-SAME: iter_args(%[[INIT_ARG1:.*]] = %[[FILL_OUT_SLICE0]])
// CHECK-SAME: {
// CHECK: %[[LOOP_RESULT2:.*]] = scf.for %[[IV4:.*]] = %[[C0]]
// CHECK-SAME: iter_args(%[[INIT_ARG2:.*]] = %[[INIT_ARG1]])
// CHECK-SAME: {
// CHECK: %[[FILL_OUT_SLICE1:.*]] = tensor.extract_slice %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
// CHECK: %[[TILED_FILL_OUT:.*]] = linalg.fill
// CHECK-SAME: outs(%[[FILL_OUT_SLICE1]] :
// CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
// CHECK: %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
// CHECK-SAME: outs(%[[TILED_FILL_OUT]] :
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
// CHECK: scf.yield %[[INSERT_MAT]] :
// CHECK: }
// CHECK: scf.yield %[[LOOP_RESULT2]] :
// CHECK: }
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]] into %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
// CHECK: }
// CHECK: }
// CHECK: return %[[FORALL_RESULT]] :
@@ -160,6 +160,56 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
: DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TestFuseProducerOp
//===----------------------------------------------------------------------===//

/// Apply the producer fusion transformation to all payload ops and store both
/// the original producer operation as well as the fused producer operation.
template <typename Range>
static LogicalResult
applyFuseProducer(RewriterBase &rewriter, Operation *transformOp,
Range &&payloadOps, TransformResults &transformResults) {
SmallVector<Operation *> originalProducerOps;
SmallVector<Operation *> fusedProducerOps;

for (Operation *target : payloadOps) {
rewriter.setInsertionPoint(target);

std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResults =
scf::tileAndFuseProducerOfSlice(rewriter, target);

if (!fuseProducerResults)
return failure();

// Report back the relevant handles to the transform op.
originalProducerOps.push_back(fuseProducerResults->origProducer.getOwner());
fusedProducerOps.push_back(fuseProducerResults->tiledOps[0]);
}

transformResults.set(transformOp->getOpResult(0), originalProducerOps);
transformResults.set(transformOp->getOpResult(1), fusedProducerOps);
return success();
}

DiagnosedSilenceableFailure
transform::TestFuseProducerOp::apply(TransformRewriter &rewriter,
TransformResults &transformResults,
TransformState &state) {
LogicalResult result =
applyFuseProducer(rewriter, getOperation(),
state.getPayloadOps(getTarget()), transformResults);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}

void transform::TestFuseProducerOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTargetMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// TestFuseConsumerOp
//===----------------------------------------------------------------------===//
Expand Down
@@ -49,6 +49,25 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
}];
}

def TestFuseProducerOp : Op<Transform_Dialect, "test.fuse_producer",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Fuses the producer of the operation pointed to by the target handle and
returns handles to both the original producer and the fused producer.
}];

let arguments =
(ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$producer,
TransformHandleTypeInterface:$fused_producer);

let assemblyFormat = [{
$target attr-dict `:` functional-type(operands, results)
}];
}

def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,