[mlir][scf] Extend fuse producer to multi-level candidates case #97803

Open · wants to merge 2 commits into base: main
Conversation

@Yun-Fly (Contributor) commented Jul 5, 2024

Although a producer can already be fused within a nested loop structure at present, it does not work for the multi-level candidate extractSliceOp case. E.g.

func.func @gemm_fill_fusion(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %c128 = arith.constant 128 : index
    %cst = arith.constant 0.000000e+00 : f32
    %dest0 = tensor.empty() : tensor<256x256xf32>
    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
      %iv0 = affine.apply #map(%arg3)
      %iv1 = affine.apply #map(%arg4)
      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
          scf.yield %insert_slice : tensor<128x128xf32>
        }
        scf.yield %3 : tensor<128x128xf32>
      }
      scf.forall.in_parallel {
         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
      }
    }
    return %1 : tensor<256x256xf32>
  }

If we want to fuse the producer fill into %extracted_slice_4, the current implementation in master breaks in getUntiledProducerFromSliceSource due to the upper candidate slice %extracted_slice_1. This PR extends producer fusion to multi-level extractSliceOp cases.

This patch applies the existing tiling interface multiple times, in the same way as its counterpart PR for consumer fusion.
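
For illustration, here is a minimal sketch (not part of the patch) of how a caller might drive the new single-operation overload on the IR above; the helper name and the way the candidate slice is located are assumptions:

// Minimal sketch, assuming the gemm_fill_fusion IR above; `fuseFillProducer`
// is an illustrative helper, not an existing API.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;

static LogicalResult fuseFillProducer(RewriterBase &rewriter,
                                      func::FuncOp funcOp) {
  // Locate %extracted_slice_4: the extract_slice feeding the matmul's init.
  tensor::ExtractSliceOp candidate;
  funcOp.walk([&](linalg::MatmulOp matmul) {
    candidate = matmul.getDpsInitOperand(0)
                    ->get()
                    .getDefiningOp<tensor::ExtractSliceOp>();
  });
  if (!candidate)
    return failure();
  // The new overload walks the chain of extract_slices (and the loops in
  // between) by itself, so no loop nest needs to be passed in; the op
  // converts implicitly to `Operation *`, selecting the new overload.
  std::optional<scf::SCFFuseProducerOfSliceResult> fused =
      scf::tileAndFuseProducerOfSlice(rewriter, candidate);
  return success(fused.has_value());
}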

@llvmbot (Collaborator) commented Jul 5, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: None (Yun-Fly)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/97803.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+8)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+110-1)
  • (added) mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir (+86)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+50)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+19)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index d68ca11207376..36888c3d6d607 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -136,11 +136,19 @@ struct SCFFuseProducerOfSliceResult {
   Value tiledAndFusedProducer; // Tile and fused producer value.
   SmallVector<Operation *> tiledOps;
 };
+std::optional<SCFFuseProducerOfSliceResult>
+tileAndFuseProducerOfSliceImpl(RewriterBase &rewriter,
+                               tensor::ExtractSliceOp candidateSliceOp,
+                               MutableArrayRef<LoopLikeOpInterface> loops);
+
 std::optional<SCFFuseProducerOfSliceResult>
 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                            tensor::ExtractSliceOp candidateSliceOp,
                            MutableArrayRef<LoopLikeOpInterface> loops);
 
+std::optional<SCFFuseProducerOfSliceResult>
+tileAndFuseProducerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+
 /// Reconstruct the fused producer from within the tiled-and-fused code. Based
 /// on the slice of the producer computed in place it is possible that within
 /// the loop nest same slice of the producer is computed multiple times. It is
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a1392813d6de3..d6baa69618dae 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -835,7 +835,7 @@ getUntiledProducerFromSliceSource(OpOperand *source,
 /// Implementation of fusing producer of a single slice by computing the
 /// slice of the producer in-place.
 std::optional<scf::SCFFuseProducerOfSliceResult>
-mlir::scf::tileAndFuseProducerOfSlice(
+mlir::scf::tileAndFuseProducerOfSliceImpl(
     RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
     MutableArrayRef<LoopLikeOpInterface> loops) {
   // 1. Get the producer of the source (potentially walking through
@@ -949,6 +949,115 @@ mlir::scf::tileAndFuseProducerOfSlice(
                                            tileAndFuseResult->tiledOps};
 }
 
+/// Get the root source of the target ExtractSliceOp, e.g.
+/// %0 =
+/// %1 = scf.for(%arg1 = %0)
+///   %2 = extract %arg1
+///   %3 = scf.for(%arg2 = %2)
+///      %4 = extract %args2
+///      ...
+/// @param targetSliceOp: %4 = extract %args2
+/// @param extractSliceOpChain: chain of all related extract sliceOp
+/// @return Value of Root Source : %0
+static FailureOr<Value> getRootSourceOfExtractSliceOp(
+    Operation *targetSliceOp,
+    SmallVectorImpl<tensor::ExtractSliceOp> &extractSliceOpChain,
+    int curDepth = 0, int maxDepth = 5) {
+  assert(isa<tensor::ExtractSliceOp>(targetSliceOp));
+  // Limit the recursion depth to avoid stack overflow.
+  if (curDepth > maxDepth)
+    return failure();
+
+  auto extractOp = cast<tensor::ExtractSliceOp>(targetSliceOp);
+  extractSliceOpChain.push_back(extractOp);
+  Value rootSource = extractOp.getSourceMutable().get();
+
+  while (true) {
+    if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
+      if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
+              iterArg.getOwner()->getParentOp())) {
+        rootSource = outerLoop.getTiedLoopInit(iterArg)->get();
+        continue;
+      }
+      return failure();
+    } else if (auto sliceOp =
+                   rootSource.getDefiningOp<tensor::ExtractSliceOp>()) {
+      // Walk up the loop to find a larger candidate extractSliceOp.
+      return getRootSourceOfExtractSliceOp(sliceOp, extractSliceOpChain,
+                                           curDepth + 1);
+    }
+    break;
+  }
+  return rootSource;
+}
+
+/// Return the outer loops of the given loop (inclusive) up to the loop where
+/// the predicate function succeeds (exclusive), sorted from outer to inner.
+static SmallVector<LoopLikeOpInterface>
+getOuterLoopsUntil(LoopLikeOpInterface loop,
+                   std::function<LogicalResult(LoopLikeOpInterface)> pred) {
+  SmallVector<LoopLikeOpInterface> outerLoops = {loop};
+  auto forOp = loop->getParentOfType<LoopLikeOpInterface>();
+  while (forOp) {
+    if (succeeded(pred(forOp)))
+      break;
+    outerLoops.push_back(forOp);
+    forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+  }
+  return {outerLoops.rbegin(), outerLoops.rend()};
+}
+
+/// Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with
+/// multi-level `extractSliceOp`. E.g.
+///
+/// %0 = untiled_producer
+/// %1 = scf.for(%arg1 = %0)
+///   %2 = extract %arg1
+///   %3 = scf.for(%arg2 = %2)
+///      %4 = extract %args2
+///      %5 = tiled_consumer ins(%4)
+std::optional<scf::SCFFuseProducerOfSliceResult>
+mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  SmallVector<tensor::ExtractSliceOp> sliceOpChain;
+  if (failed(getRootSourceOfExtractSliceOp(candidateSliceOp, sliceOpChain))) {
+    return std::nullopt;
+  }
+
+  std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
+  // reverse from outer to inner
+  std::reverse(sliceOpChain.begin(), sliceOpChain.end());
+  // multiple application of `tileAndFuseProducerOfSliceImpl`
+  for (auto &&[index, sliceOp] : llvm::enumerate(sliceOpChain)) {
+    Operation *upperSliceOp = index ? sliceOpChain[index - 1] : nullptr;
+    auto untilUpperSliceFound =
+        [&upperSliceOp](LoopLikeOpInterface loop) -> LogicalResult {
+      if (upperSliceOp) {
+        Block &body = loop->getRegion(0).front();
+        if (upperSliceOp->getBlock() == &body)
+          return success();
+      }
+      return failure();
+    };
+    SmallVector<LoopLikeOpInterface> outerLoops = getOuterLoopsUntil(
+        sliceOp->getParentOfType<LoopLikeOpInterface>(), untilUpperSliceFound);
+    fuseProducerResult =
+        tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, outerLoops);
+    if (!fuseProducerResult) {
+      return std::nullopt;
+    }
+  }
+  return fuseProducerResult;
+}
+
+/// Overload kept for compatibility with the previous behavior.
+std::optional<scf::SCFFuseProducerOfSliceResult>
+mlir::scf::tileAndFuseProducerOfSlice(
+    RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
+  return tileAndFuseProducerOfSliceImpl(rewriter, candidateSliceOp, loops);
+}
+
 /// Reconstruct the fused producer from within the tiled-and-fused code.
 LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir
new file mode 100644
index 0000000000000..ef1c6952a55e1
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 * 128)>
+module {
+  func.func @gemm_fill_fusion_multi_level_extract_slice(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
+      %iv0 = affine.apply #map(%arg3)
+      %iv1 = affine.apply #map(%arg4)
+      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
+        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
+          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
+          scf.yield %insert_slice : tensor<128x128xf32>
+        }
+        scf.yield %3 : tensor<128x128xf32>
+      }
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+      }
+    }
+    return %1 : tensor<256x256xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.get_producer_of_operand %matmul[2]
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_producer %yield
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+//      CHECK: #[[MAP0:.*]] =  affine_map<(d0) -> (d0 * 128)>
+//      CHECK: func.func @gemm_fill_fusion_multi_level_extract_slice(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+//      CHECK:   %[[FORALL_RESULT:.*]] = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[INIT_ARG0:.*]] = %[[dest0]])
+// CHECK-SAME:   {
+//      CHECK:      %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
+//      CHECK:      %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
+//      CHECK:      %[[FILL_OUT_SLICE0:.*]] = tensor.extract_slice %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:      %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
+//      CHECK:      %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
+//      CHECK:      %[[LOOP_RESULT1:.*]] = scf.for %[[IV3:.*]] = %[[C0]]
+// CHECK-SAME:          iter_args(%[[INIT_ARG1:.*]] = %[[FILL_OUT_SLICE0]])
+// CHECK-SAME:      {
+//      CHECK:          %[[LOOP_RESULT2:.*]] = scf.for %[[IV4:.*]] = %[[C0]]
+// CHECK-SAME:            iter_args(%[[INIT_ARG2:.*]] = %[[INIT_ARG1]])
+// CHECK-SAME:          {
+//      CHECK:            %[[FILL_OUT_SLICE1:.*]] = tensor.extract_slice %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[TILED_FILL_OUT:.*]] = linalg.fill
+// CHECK-SAME:                  outs(%[[FILL_OUT_SLICE1]] :
+//      CHECK:            %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
+//      CHECK:            %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
+//      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:                  outs(%[[TILED_FILL_OUT]] :
+//      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            scf.yield %[[INSERT_MAT]] :
+//      CHECK:          }
+//      CHECK:          scf.yield %[[LOOP_RESULT2]] :
+//      CHECK:      }
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[LOOP_RESULT1]] into %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FORALL_RESULT]] :
\ No newline at end of file
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 8f206d9077272..8d3d5b23a53e9 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -160,6 +160,56 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
                         : DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestFuseProducerOp
+//===----------------------------------------------------------------------===//
+
+/// Apply fusing of producer transformation to all payload ops and store both
+/// the original producer operation as well as the fused producer operation.
+template <typename Range>
+static LogicalResult
+applyFuseProducer(RewriterBase &rewriter, Operation *transformOp,
+                  Range &&payloadOps, TransformResults &transformResults) {
+  SmallVector<Operation *> originalProducerOps;
+  SmallVector<Operation *> fusedProducerOps;
+
+  for (Operation *target : payloadOps) {
+    rewriter.setInsertionPoint(target);
+
+    std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResults =
+        scf::tileAndFuseProducerOfSlice(rewriter, target);
+
+    if (!fuseProducerResults)
+      return failure();
+
+    // Report back the relevant handles to the transform op.
+    originalProducerOps.push_back(fuseProducerResults->origProducer.getOwner());
+    fusedProducerOps.push_back(fuseProducerResults->tiledOps[0]);
+  }
+
+  transformResults.set(transformOp->getOpResult(0), originalProducerOps);
+  transformResults.set(transformOp->getOpResult(1), fusedProducerOps);
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::TestFuseProducerOp::apply(TransformRewriter &rewriter,
+                                     TransformResults &transformResults,
+                                     TransformState &state) {
+  LogicalResult result =
+      applyFuseProducer(rewriter, getOperation(),
+                        state.getPayloadOps(getTarget()), transformResults);
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestFuseProducerOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // TestFuseConsumerOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index d55d746bd6aa9..6e73478c35c4a 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -49,6 +49,25 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
   }];
 }
 
+def TestFuseProducerOp : Op<Transform_Dialect, "test.fuse_producer",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Fuses the producer of the operation pointed to by the target handle
+    using the options provided as attributes.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$producer,
+                      TransformHandleTypeInterface:$fused_producer);
+
+  let assemblyFormat = [{
+    $target attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

@Yun-Fly force-pushed the yunfei/fuse_producer_multilevel_slice branch from a4c2778 to 8c9a585 on July 6, 2024
@MaheshRavishankar (Contributor) left a comment:

Thanks! Please give me some time to review. Broadly my concern is how do you control which loop nest you want to fuse into when you are doing multiple tiling levels... I can see it is being controlled by API knobs, but those seem fragile. I need to look more in detail though.

Review thread on mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (outdated, resolved).
@Yun-Fly (Contributor Author) commented Jul 9, 2024

Thanks! Please give me some time to review. Broadly my concern is how do you control which loop nest you want to fuse into when you are doing multiple tiling levels... I can see it is being controlled by API knobs, but those seem fragile. I need to look more in detail though.

Good question! As planned, this PR focuses only on the functionality extension. In other words, the user is responsible for deciding which loop nest to fuse into by feeding a specific candidate to tileAndFuseProducerOfSlice, which can accept candidates at any tiling level.

BTW, as introduced in an earlier RFC, we intend to design a higher-level utility or API to control which loop nest to fuse into, based on a semantic check (validity) and a cost model (performance), both of which are expected to be exposed as callback options with default behavior.

I can give a more detailed example:

%0 = tensor.pack ins() outs() : tensor<128x256xf32> -> tensor<16x8 x 8x32 xf32>
%1 = linalg.fill ins() outs()
%2 = scf.for()
   %3 = extract_slice
   %4 = scf.for()
      %5 = extract_slice
  1. Validity: some ops have specific tiling requirements, e.g. pack, which has limitations on the inner_tile part. Thus, we need to first check them and filter out invalid candidates.
  2. Performance: most ops have no tiling requirement, e.g. fill; nevertheless, which loop to fuse into may still affect performance. This should be decided per target machine, assisted by a cost model. If nothing is given, the inner loop usually offers better cache locality.

NOTE: consumer fusion has the same concern as producer fusion. A typical use case is fusing linalg.reduce, which expects no tiling on reduction loops.
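
To make that intent concrete, here is a purely hypothetical sketch of the callback-based selection utility mentioned above; none of these names exist upstream, and the defaults described are only the ones suggested in this comment:

// Hypothetical sketch only: a higher-level driver that picks which candidate
// slice to fuse into via validity and cost callbacks. `CandidateFilterFn`,
// `CandidateCostFn`, and `fuseProducerWithSelection` are not existing APIs.
#include <functional>
#include <limits>

// Semantic validity check, e.g. reject imperfect tiling for pack/unpack.
using CandidateFilterFn = std::function<bool(tensor::ExtractSliceOp)>;
// Target-specific cost model; lower is better. A default could prefer the
// inner-most slice for cache locality.
using CandidateCostFn = std::function<int64_t(tensor::ExtractSliceOp)>;

static std::optional<scf::SCFFuseProducerOfSliceResult>
fuseProducerWithSelection(RewriterBase &rewriter,
                          ArrayRef<tensor::ExtractSliceOp> candidates,
                          const CandidateFilterFn &isValid,
                          const CandidateCostFn &cost) {
  tensor::ExtractSliceOp best;
  int64_t bestCost = std::numeric_limits<int64_t>::max();
  for (tensor::ExtractSliceOp slice : candidates) {
    if (!isValid(slice))
      continue; // filtered out by the semantic (validity) check
    int64_t c = cost(slice);
    if (c < bestCost) {
      bestCost = c;
      best = slice;
    }
  }
  if (!best)
    return std::nullopt;
  return scf::tileAndFuseProducerOfSlice(rewriter, best);
}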

As a result, I would prefer to address which loops to fuse into under multi-level tiling in a separate PR, once both this patch and #94190 are merged.

Looking forward to hearing your advice!

@Yun-Fly force-pushed the yunfei/fuse_producer_multilevel_slice branch 2 times, most recently from 6d38e2d to 9feadc4 on July 16, 2024
@MaheshRavishankar (Contributor) left a comment:

This seems related to the other PR (#94190). Let's come back to this after we have some idea of that. Or if you want to try to land this first, we can switch the order around.

@Yun-Fly (Contributor Author) commented Jul 17, 2024

This seems related to the other PR (#94190). Let's come back to this after we have some idea of that. Or if you want to try to land this first, we can switch the order around.

I have no specific request about the priority order so far. It's up to you, or whichever one is closer to landing in your mind. Hopefully both of them can be merged ASAP, because an upcoming PR may depend on them.

@MaheshRavishankar (Contributor) commented:

This seems related to the other PR (#94190). Let's come back to this after we have some idea of that. Or if you want to try to land this first, we can switch the order around.

I have no specific request about the priority order so far. It's up to you, or whichever one is closer to landing in your mind. Hopefully both of them can be merged ASAP, because an upcoming PR may depend on them.

OK, let's do the producer fusion first, since that is easier to navigate than the consumer fusion.

(I know you say ASAP, but this is involved code and reviews take time. I'll do my best to keep it timely, but I am juggling this with other work that is a bit higher priority for my employer.)

@MaheshRavishankar (Contributor) left a comment:

This PR is much easier to review... I'll ask a few more questions to help navigate this.

Trying to put all of this logic upstream is hard. This seems like a very special case where you already have someone partially tiling the code, and you are trying to reuse some logic from tile-and-fuse for this case. But there is too much of an assumption being made that the loop structure you create is consistent with what tiling creates. In fact, that assumption is already made by the implementation. That is why the current mlir::tileAndFuseProducerOfSlice takes the loop nest generated by tiling as its MutableArrayRef<LoopLikeOpInterface> loops argument.
I'd like to understand if there is a way you can keep the logic of finding loops closer to your context and then consider if we need to change the implementation of how the producer is found by walking the loop nest and extract slice source definition chain. Does that question make sense?

@Yun-Fly (Contributor Author) commented Jul 18, 2024

But there is too much of an assumption being made that the loop structure you create is consistent with what tiling creates. In fact, that assumption is already made by the implementation.

The loop structure created in this PR is exactly what multi-level tiling creates. So it is consistent with what tiling creates, but differs in having multiple levels of nested loops and candidate slices, for better utilization of parallelism and the memory hierarchy. I will detail our context at the end.

That is why the current mlir::tileAndFuseProducerOfSlice takes the loop nest generated by tiling as its MutableArrayRef<LoopLikeOpInterface> loops argument.

Yes, I know, and that is one reason why this PR is easier to review than the consumer-fusion patch. IIRC, the current mlir::tileAndFuseProducerOfSlice is only used in tileConsumerAndFuseProducersUsingSCF, where the consumer is tiled by tileUsingSCF into a single tiling level with all extract slices located in the inner-most loop. So it is good enough for that use case.

I'd like to understand if there is a way you can keep the logic of finding loops closer to your context and then consider if we need to change the implementation of how the producer is found by walking the loop nest and extract slice source definition chain. Does that question make sense?

Our context is that some developers want to customize how a certain op is tiled instead of directly using the general tileUsingSCF API, especially for matmul or convolution on different target devices. These developers only care about improving kernel performance through multi-level tiling, not about how subsequent fusion is done, because the kinds of preceding and succeeding ops are uncertain to them. In other words, this use case decouples the tiling stage of the contraction op from the fusion of the surrounding ops, which enables more flexible tiling but requires robust fusion interfaces, namely tileAndFuseProducerOfSlice and tileAndFuseConsumerOfSlice.

@Yun-Fly (Contributor Author) commented Jul 18, 2024

I know you say ASAP, but this is involved code and reviews take time. I'll do my best to keep it timely.

I would like to express my sincere gratitude for your involvement!

/// @param candidateSliceOp: %4 = extract %args2
/// @param backwardSlice: in-out parameter populated by backward extractSliceOps
/// @return OpResult Producer : %0 = producer
static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
Contributor:

Instead of doing this, would it help if you just folded the extract_slices using something like this pattern?

Contributor Author:

I know that we can merge consecutive extractSliceOps. But I am afraid we have to keep these consecutive extract slices for candidate selection. The most intuitive example is below:

%0 = producer
%1 = scf.for(%arg1 = %0)
   %2 = extract %arg1
   %3 = scf.for(%arg2 = %2)
      %4 = extract %args2

In some cases, the producer can only be fused into the outer candidate slice due to the producer's semantics; i.e., ops like reduce/pack/unpack have specific demands on tiling sizes. If we merge the two consecutive extract slices, it may break this fusion. That is why we emphasize multi-level candidates here.

Contributor:

That is OK, but the point here is that to fuse with %2 you don't need to look through extract_slices. I am just not sure why there would be a case where you need to look through extract_slices to get to the producer, when you can't collapse all the extract slices to begin with.

Contributor Author:

That is just a simple test. If we have one more candidate slice, say:

%0 = producer
%1 = scf.for(%arg1 = %0)
   %2 = extract %arg1
   %3 = scf.for(%arg2 = %2)
      %4 = extract %args2
      %5 = scf.for(%arg3 = %4)
          %6 = extract %args3
          tiled_consumer ins(%6)

where %2 and %4 are both valid candidates for fusing the producer. And the user may just want to manually pass %4 to tileAndFuseProducerOfSlice, because %4 usually has a smaller tile size in most cases. Then the current getUntiledProducerFromSliceSource will fail to find the real producer without looking through extract_slices.

Another use case is when someone wants to enable automatic fusion starting from tiled_consumer, just like tileConsumerAndFuseProducersUsingSCF does. The only extract slice we can find via the operand of tiled_consumer is the nearest candidate, %6. Then we also need to look through all the extract_slices to get the real producer.

Meanwhile, we don't know which candidate slice can be fused until we know what the exact producer is. Thus, it is unreasonable to collapse all the extract slices before fusion.

Contributor:

I am not sure this reasoning holds. If you want to fuse %4 and not %2, you are effectively "fusing" the slices and then doing the producer fusion. It would be better to stage it so that you fuse the slices first, so that you get the producer directly, and then do the fusion. So effectively this change is trying to do both of these in code:

  1. Combine extract slices
  2. Do producer fusion

That is just added complexity. It would be easier to just combine the extract slices first and then fuse the producer.

Contributor Author:

%0 = producerOp
scf.for(%1=%0)
   %2 = extract_slice %1
   scf.for(%3=%2)
      %4 = extract_slice %3

Besides, I am worried that merging two candidates interleaved by an scf.for, with its inits involved, is not trivial...

Contributor Author:

Hi, @MaheshRavishankar. Since PR #108318 has been broken down into three parts:

  1. support single nested scf.for (merged, great thanks for your review!)
  2. multi-level candidates
  3. multiple users

I think the second part is similar to this patch, for both producer and consumer fusion. In general, this patch focuses on how to deal with consecutive candidates interleaved by an scf.for.

Let's look back at the major arguments that remain so far, to speed up the next stage of review:

  1. Why do we need to look through the whole chain of candidate slices?
  2. Why not merge two consecutive candidates?

The answers are listed in the threads above. The root cause behind them is that there are two different solutions to the multi-level candidate case:

  1. Merge consecutive candidates in advance.
  2. Iteratively fuse the producer/consumer with one candidate at a time, step by step (from outer to inner).

IMO, the main concern about the first solution is that it may introduce too many additional transformations on scf.for. E.g.:

%0 = producerOp
scf.for(%1=%0)
   %2 = extract_slice %1
   scf.for(%3=%2)
      %4 = extract_slice %3
      ... 
      yield %x

As I explained above, this scenario is a little different from what MergeConsecutiveExtractSlice expects, due to the scf.for.

If we just merge the consecutive extract slices (BTW, here we also need to look through the scf.for), the possible resulting IR looks like below:

%0 = producerOp
scf.for(%1=%0)
   %2 = extract_slice %1
   scf.for(%3=%2)
      %4 = extract_slice %1 [newOffsets] [newSizes]
      ... 
      yield %x

Then the problem is: how do we update the new inits of the scf.for with the latest dpsInit of the fused producerOp? (getUntiledProducerFromSliceSource will fail to find the destinationIterArg.)

On the other hand, if we modify %3 = %2 at the same time, the resulting IR would become:

%0 = producerOp
scf.for(%1=%0)
   %2 = extract_slice %1
   scf.for(%3=%1)
      %4 = extract_slice %3 [newOffsets] [newSizes]
      ... 
      yield %x_new

First of all, this transformation already introduces more code changes, and there is an assumption that %3 has only a single user. Moreover, considering the semantics of scf.for, it seems we also need to modify the yielded value %x (and even its producer, such as the insert_slice) to make the result of the scf.for match its inits.

In summary, I am afraid that this kind of complexity is much greater than that of the second solution, which is based on a simple iterative process reusing the existing function.

Contributor:

Thanks for the explanation. I see the issues a bit more clearly now. I haven't reviewed the code yet. I don't know if this has been rebased on top of what was submitted.
But here is the complexity I am concerned about. In your example, you need to keep track somewhere of the sequence of extract slices that you need to walk to get the producer, because the actual offset and size you need are obtained by "combining" the offsets and sizes of all the slices into the "real" offset and size. Also, I am not convinced that fusing the first extract slice + producer and then the second extract slice + producer is not feasible. That should always be possible.

So if you start with this

%0 = producerOp
scf.for(%1=%0)
   %2 = extract_slice %1
   scf.for(%3=%2)
      %4 = extract_slice %3
      ... 
      yield %x

you should always be able to do

scf.for(%1=%0)
   %2 = tiled producer 1
   scf.for(%3=%2)
      %4 = extract_slice %3
      ... 
      yield %x

and then you do

scf.for(%1=%0)
   scf.for(%3=%2)
      %4 = tiled producer 2
      ... 
      yield %x

The amount of state you need to carry during the transformation to fuse with one extract slice is prohibitively high, IMO. That will make it hard to change/fix the transformation.

@Yun-Fly (Contributor Author) commented Sep 19, 2024

I don't know if this has been rebased on top of what was submitted.

Not rebased yet, because I think it is more important to reach agreement on the solution before pushing...

you need to keep track somewhere of the sequence of extract slices that you need to walk to get the producer, because the actual offset and size you need are obtained by "combining" the offsets and sizes of all the slices into the "real" offset and size.

Yes, looking at the end result, the real offset is indeed combined. However, there is something else we need to address, as I explained above regarding the loop transformation: each time two extract_slices (interleaved by an scf.for) are merged, the corresponding scf.for (and its yielded value, etc.) needs to be transformed. Isn't the amount of state to carry similar to what I do in the iterative fashion?

Also, I am not convinced that fusing the first extract slice + producer and then the second extract slice + producer is not feasible. That should always be possible.

It is decided by the concrete semantics of the tilable op, such as reduce/pack/unpack. Let's say we fuse a producer unpack into the following loop:

// unpack tensor from ABab to AB
%1 = tensor.unpack ... inner_tiles = [32, 32] ... : tensor<2x2x32x32xbf16> -> tensor<64x64xbf16>
scf.for(0, 2) { // tile A dimension
 extract_slice1 // tileSize<32, 64>
 scf.for(0, 4) { // tile B dimension
    extract_slice2 // tileSize<32, 16>
    ...
 }
}

As you can see, the tile size coming from extract_slice2 is <32, 16>, but tensor.unpack prefers the perfect-tiling case, i.e., the tile size should be an exact multiple of inner_tiles. So fusing there may not be feasible in this case.
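
For concreteness, a minimal sketch of the divisibility test such a validity check might apply for a tensor.unpack producer; the helper is hypothetical and assumes tile sizes are given only for the packed dims:

// Hypothetical validity check: accept a candidate tile for fusing a
// tensor.unpack producer only if every tiled dim is a multiple of the
// corresponding inner tile. For the example above (inner_tiles = [32, 32]),
// the outer candidate <32, 64> passes, while the inner candidate <32, 16>
// fails because 16 % 32 != 0.
#include "llvm/ADT/STLExtras.h" // for llvm::zip_equal

static bool isPerfectUnpackTile(ArrayRef<int64_t> tileSizes,
                                ArrayRef<int64_t> innerTiles) {
  for (auto [tile, inner] : llvm::zip_equal(tileSizes, innerTiles))
    if (tile % inner != 0)
      return false;
  return true;
}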

BTW, fusing a consumer reduce may be more intuitive:

%0 = scf.for(0, 2) { // tile the 128 dimension
 scf.for(0, 4) { // tile the 256 dimension
    ...
   insert_slice1 // tileSize<64, 64>
 }
 insert_slice2 // tileSize<64, 256>
}
%2 = linalg.reduce { arith.addf } ins(%0 : tensor<128x256xf32>) outs(%1: tensor<128xf32>) dimensions = [1]

We cannot fuse the reduce further into the inner-most insert_slice; otherwise it will lead to a partial reduction (which is another topic).

So if you start with this

you should always be able to do

and then you do

Yes, that is exactly what I do now. IIRC, that is also what you suggested before in the nested consumer fusion thread... (In fact, my first version of the nested consumer fusion implementation just combined the offsets and sizes of several candidate slices; however, it was rejected because of too many code changes...)

That will make it hard to change/fix the transformation.

The current iterative method breaks the whole transformation down into several repetitions of the existing tileAndFuseProducerOfSlice logic, which should be much easier to debug, at least from my view.

Contributor Author:

Ping

@MaheshRavishankar (Contributor) left a comment:

Left a few more comments. Thanks for being patient. I am a bit busy this week, but I am more free next week. If you are on Discord (IREE Discord or MLIR Discord), we can connect there, and I am happy to work with you at a faster cadence to help you land what you need upstream, and maybe suggest things you can keep downstream since they are valid only for your context.

@Yun-Fly (Contributor Author) commented Jul 24, 2024

Thanks for being patient. I am a bit busy this week, but I am more free next week. If you are on Discord (IREE Discord or MLIR Discord), we can connect there, and I am happy to work with you at a faster cadence...

Yes, surely. I just pinged you there. Feel free to have a talk when you have time.

@MaheshRavishankar (Contributor) commented:

Thanks for being patient. I am a bit busy this week, but I am more free next week. If you are on Discord (IREE Discord or MLIR Discord), we can connect there, and I am happy to work with you at a faster cadence...

Yes, surely. I just pinged you there. Feel free to have a talk when you have time.

I don't think I got the ping.

@Yun-Fly force-pushed the yunfei/fuse_producer_multilevel_slice branch from 9feadc4 to c52cf40 on August 7, 2024
@Yun-Fly force-pushed the yunfei/fuse_producer_multilevel_slice branch from c52cf40 to 23796bf on September 3, 2024
@MaheshRavishankar (Contributor) left a comment:

I don't know if this commit is still relevant. I was looking for a commit where things were broken up a bit more. If this commit is still relevant, I responded to the outstanding comment here.

