Skip to content

Commit

Permalink
extend fuse producer to multi-level extractSliceOp
Browse files Browse the repository at this point in the history
  • Loading branch information
Yun-Fly committed Jul 6, 2024
1 parent 4762f3b commit 8c9a585
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 1 deletion.
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,19 @@ struct SCFFuseProducerOfSliceResult {
Value tiledAndFusedProducer; // Tile and fused producer value.
SmallVector<Operation *> tiledOps;
};
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSliceImpl(RewriterBase &rewriter,
tensor::ExtractSliceOp candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops);

std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
tensor::ExtractSliceOp candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops);

std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);

/// Reconstruct the fused producer from within the tiled-and-fused code. Based
/// on the slice of the producer computed in place it is possible that within
/// the loop nest same slice of the producer is computed multiple times. It is
Expand Down
131 changes: 130 additions & 1 deletion mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ getUntiledProducerFromSliceSource(OpOperand *source,
/// Implementation of fusing producer of a single slice by computing the
/// slice of the producer in-place.
std::optional<scf::SCFFuseProducerOfSliceResult>
mlir::scf::tileAndFuseProducerOfSlice(
mlir::scf::tileAndFuseProducerOfSliceImpl(
RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops) {
// 1. Get the producer of the source (potentially walking through
Expand Down Expand Up @@ -949,6 +949,135 @@ mlir::scf::tileAndFuseProducerOfSlice(
tileAndFuseResult->tiledOps};
}

/// Get the Root source of target ExtractSliceOp
/// %0 =
/// %1 = scf.for(%arg1 = %0)
/// %2 = extract %arg1
/// %3 = scf.for(%arg2 = %2)
/// %4 = extract %args2
/// ...
/// @param targetSliceOp: %4 = extract %args2
/// @param extractSliceOpChain: chain of all related extract sliceOp
/// @return Value of Root Source : %0
static FailureOr<Value> getRootSourceOfExtractSliceOp(
Operation *targetSliceOp,
SmallVectorImpl<tensor::ExtractSliceOp> &extractSliceOpChain,
int curDepth = 0, int maxDepth = 5) {
assert(isa<tensor::ExtractSliceOp>(targetSliceOp));
// control recursive time in avoid of stack overflow
if (curDepth > maxDepth)
return failure();

auto extractOp = cast<tensor::ExtractSliceOp>(targetSliceOp);
extractSliceOpChain.push_back(extractOp);
Value rootSource = extractOp.getSourceMutable().get();

while (true) {
if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
iterArg.getOwner()->getParentOp())) {
rootSource = outerLoop.getTiedLoopInit(iterArg)->get();
continue;
}
return failure();
} else if (auto sliceOp =
rootSource.getDefiningOp<tensor::ExtractSliceOp>()) {
// walk up loop to find larger candidate extractSliceOp
return getRootSourceOfExtractSliceOp(sliceOp, extractSliceOpChain,
curDepth + 1);
}
break;
}
return rootSource;
}

/// Recursively find the outer nest loops of given loop(included) while the
/// predict function succeed, sorted from outer to inner.
///
/// @param loop: target loop, note that this loop will be also included. I.e.
/// if no other nest loops were found, just return itself.
/// @param pred: predict function, the termination condition of recursive
/// process.
/// @return Outer Nest Loops: nest loops outside given target loop(included).
///
/// E.g.
///
/// ```
/// %0 = scf.for()
/// %1 = scf.for()
/// %2 = scf.for()
/// ```
///
/// If `%2 = scf.for` is given without specific prediction function, this
/// function will return three nest loops: %0 + %1 + %2.
static SmallVector<LoopLikeOpInterface>
getOuterNestLoopsWhile(LoopLikeOpInterface loop,
std::function<LogicalResult(LoopLikeOpInterface)> pred) {
SmallVector<LoopLikeOpInterface> nestLoops = {loop};
auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
while (outerLoop && succeeded(pred(outerLoop))) {
nestLoops.push_back(outerLoop);
outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp());
}
// sorted from outer to inner
return {nestLoops.rbegin(), nestLoops.rend()};
}

/// Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with
/// multi-level `extractSliceOp`. E.g.
///
/// ```
/// %0 = untiled_producer
/// %1 = scf.for(%arg1 = %0)
/// %2 = extract %arg1
/// %3 = scf.for(%arg2 = %2)
/// %4 = extract %args2
/// %5 = tiled_consumer ins(%4)
/// ```
std::optional<scf::SCFFuseProducerOfSliceResult>
mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
Operation *candidateSliceOp) {
SmallVector<tensor::ExtractSliceOp> sliceOpChain;
if (failed(getRootSourceOfExtractSliceOp(candidateSliceOp, sliceOpChain))) {
return std::nullopt;
}

std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
// reverse from outer to inner
std::reverse(sliceOpChain.begin(), sliceOpChain.end());
// multiple application of `tileAndFuseProducerOfSliceImpl`
for (auto &&[index, sliceOp] : llvm::enumerate(sliceOpChain)) {
// get nest loops between next candidate sliceOp and tiled producer.
auto whileProducerOutOfBlock =
[&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
if (fuseProducerResult) {
Block &body = loop->getRegion(0).front();
if (fuseProducerResult->tiledAndFusedProducer.getDefiningOp()
->getBlock() == &body)
return failure();
}
return success();
};
SmallVector<LoopLikeOpInterface> outerLoops =
getOuterNestLoopsWhile(sliceOp->getParentOfType<LoopLikeOpInterface>(),
whileProducerOutOfBlock);
fuseProducerResult =
tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, outerLoops);
if (!fuseProducerResult) {
return std::nullopt;
}
}
return fuseProducerResult;
}

/// To be compatible with previous behavior
std::optional<scf::SCFFuseProducerOfSliceResult>
mlir::scf::tileAndFuseProducerOfSlice(
RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
MutableArrayRef<LoopLikeOpInterface> loops) {
return tileAndFuseProducerOfSliceImpl(rewriter, candidateSliceOp, loops);
}

/// Reconstruct the fused producer from within the tiled-and-fused code.
LogicalResult mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
Expand Down
86 changes: 86 additions & 0 deletions mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s

#map = affine_map<(d0) -> (d0 * 128)>
module {
func.func @gemm_fill_fusion_multi_level_extract_slice(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%cst = arith.constant 0.000000e+00 : f32
%dest0 = tensor.empty() : tensor<256x256xf32>
%dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
%1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
%iv0 = affine.apply #map(%arg3)
%iv1 = affine.apply #map(%arg4)
%extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
%extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
%extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
%2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
%3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
%extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
%extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
%extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
%4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
%insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
scf.yield %insert_slice : tensor<128x128xf32>
}
scf.yield %3 : tensor<128x128xf32>
}
scf.forall.in_parallel {
tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
}
}
return %1 : tensor<256x256xf32>
}
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%yield = transform.get_producer_of_operand %matmul[2]
: (!transform.any_op) -> !transform.any_op
%a, %b = transform.test.fuse_producer %yield
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}

// CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)>
// CHECK: func.func @gemm_fill_fusion_multi_level_extract_slice(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
// CHECK: %[[FORALL_RESULT:.*]] = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
// CHECK-SAME: shared_outs(%[[INIT_ARG0:.*]] = %[[dest0]])
// CHECK-SAME: {
// CHECK: %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
// CHECK: %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
// CHECK: %[[FILL_OUT_SLICE0:.*]] = tensor.extract_slice %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
// CHECK: %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
// CHECK: %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
// CHECK: %[[LOOP_RESULT1:.*]] = scf.for %[[IV3:.*]] = %[[C0]]
// CHECK-SAME: iter_args(%[[INIT_ARG1:.*]] = %[[FILL_OUT_SLICE0]])
// CHECK-SAME: {
// CHECK: %[[LOOP_RESULT2:.*]] = scf.for %[[IV4:.*]] = %[[C0]]
// CHECK-SAME: iter_args(%[[INIT_ARG2:.*]] = %[[INIT_ARG1]])
// CHECK-SAME: {
// CHECK: %[[FILL_OUT_SLICE1:.*]] = tensor.extract_slice %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
// CHECK: %[[TILED_FILL_OUT:.*]] = linalg.fill
// CHECK-SAME: outs(%[[FILL_OUT_SLICE1]] :
// CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
// CHECK: %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
// CHECK-SAME: outs(%[[TILED_FILL_OUT]] :
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
// CHECK: scf.yield %[[INSERT_MAT]] :
// CHECK: }
// CHECK: scf.yield %[[LOOP_RESULT2]] :
// CHECK: }
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]] into %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
// CHECK: }
// CHECK: }
// CHECK: return %[[FORALL_RESULT]] :
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,56 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
: DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TestFuseProducerOp
//===----------------------------------------------------------------------===//

/// Apply fusing of producer transformation to all payload ops and store both
/// the original producer operation as well as the fused producer operation.
template <typename Range>
static LogicalResult
applyFuseProducer(RewriterBase &rewriter, Operation *transformOp,
Range &&payloadOps, TransformResults &transformResults) {
SmallVector<Operation *> originalProducerOps;
SmallVector<Operation *> fusedProducerOps;

for (Operation *target : payloadOps) {
rewriter.setInsertionPoint(target);

std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResults =
scf::tileAndFuseProducerOfSlice(rewriter, target);

if (!fuseProducerResults)
return failure();

// Report back the relevant handles to the transform op.
originalProducerOps.push_back(fuseProducerResults->origProducer.getOwner());
fusedProducerOps.push_back(fuseProducerResults->tiledOps[0]);
}

transformResults.set(transformOp->getOpResult(0), originalProducerOps);
transformResults.set(transformOp->getOpResult(1), fusedProducerOps);
return success();
}

DiagnosedSilenceableFailure
transform::TestFuseProducerOp::apply(TransformRewriter &rewriter,
TransformResults &transformResults,
TransformState &state) {
LogicalResult result =
applyFuseProducer(rewriter, getOperation(),
state.getPayloadOps(getTarget()), transformResults);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}

void transform::TestFuseProducerOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTargetMutable(), effects);
producesHandle(getOperation()->getOpResults(), effects);
modifiesPayload(effects);
}

//===----------------------------------------------------------------------===//
// TestFuseConsumerOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,25 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
}];
}

def TestFuseProducerOp : Op<Transform_Dialect, "test.fuse_producer",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Fuses the producer of the operation pointed to by the target handle
using the options provided as attributes.
}];

let arguments =
(ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$producer,
TransformHandleTypeInterface:$fused_producer);

let assemblyFormat = [{
$target attr-dict `:` functional-type(operands, results)
}];
}

def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
Expand Down

0 comments on commit 8c9a585

Please sign in to comment.