Skip to content

Commit

Permalink
Add pass to bubble-up extract_slice operations. (#18332)
Browse files Browse the repository at this point in the history
This adds pass to replace a `tensor.extract_slice` operation with a
slice of the producer. In general there might be more opportunities to
use this pass more aggressively (like when an operation has a single
use which is a slice), but for now this is being done only for
bit-extend operations.

Co-authored-by: Ian Wood <[email protected]>
  • Loading branch information
MaheshRavishankar and IanWood1 committed Aug 28, 2024
1 parent 6e3be28 commit d6762d4
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/DispatchCreation/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ iree_compiler_cc_library(
name = "DispatchCreation",
srcs = [
"BubbleUpExpandShapes.cpp",
"BubbleUpExtractSlices.cpp",
"CloneProducersIntoDispatchRegions.cpp",
"CollapseDimensions.cpp",
"CollapseReductionDimensions.cpp",
Expand Down
134 changes: 134 additions & 0 deletions compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::DispatchCreation {

#define GEN_PASS_DEF_BUBBLEUPEXTRACTSLICESPASS
#include "iree/compiler/DispatchCreation/Passes.h.inc"

namespace {

// Convert extract_slice(dequant) to dequant(extract_slice)
//
// Because `extract_slice` ops and dequantize-like ops get cloned into regions
// later, it's okay to bubble up through multi-use dequant ops.
struct BubbleUpExtract : OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const final {
Value source = sliceOp.getSource();
auto genericOp = source.getDefiningOp<linalg::GenericOp>();
if (!genericOp || genericOp->getNumResults() != 1) {
return rewriter.notifyMatchFailure(
sliceOp, "expected source to implement `linalg::LinalgOp` and have a "
"single result");
}

if (!IREE::LinalgExt::isBitExtendOp(genericOp)) {
return rewriter.notifyMatchFailure(
sliceOp, "expected source to be dequantize-like");
}

if (!sliceOp.hasUnitStride()) {
return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
}

if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
return map.isProjectedPermutation();
})) {
return rewriter.notifyMatchFailure(
genericOp,
"expected generic op to have all projected permutation maps");
}

if (genericOp.hasIndexSemantics()) {
return rewriter.notifyMatchFailure(
genericOp, "pattern doesn't support index semantics");
}

Value replacement;
linalg::GenericOp swappedOp;
{
FailureOr<TilingResult> tilingResult =
tensor::replaceExtractSliceWithTiledProducer(rewriter, sliceOp,
genericOp->getResult(0));
assert(succeeded(tilingResult) && "failed to swap extract_slice with op");
assert(tilingResult->tiledOps.size() == 1);
replacement = tilingResult->tiledValues[0];
swappedOp = cast<linalg::GenericOp>(tilingResult->tiledOps[0]);
}

// Check if this is a rank-reducing slice, if so we need to fold the unit
// dimensions of the op.
// This is necessary because `replaceExtractSliceWithTiledProducer` does not
// take into account the `extract_slice`'s implicit rank reduction. The
// operations generated by that function will have any unit dims that were
// removed by the original `extract_slice`. Folding them away ensures that
// the types match.
if (sliceOp.getSourceType().getRank() !=
sliceOp.getResultType().getRank()) {

llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
// Get the indexing map for the result.
AffineMap resultMap =
swappedOp.getIndexingMapMatchingResult(swappedOp->getResult(0));
linalg::ControlDropUnitDims options;
options.rankReductionStrategy = linalg::ControlDropUnitDims::
RankReductionStrategy::ExtractInsertSlice;
options.controlFn = [&](Operation *op) -> SmallVector<unsigned> {
SmallVector<unsigned> droppedDimsVec;
for (auto [index, expr] : llvm::enumerate(resultMap.getResults())) {
if (!droppedDims.test(index)) {
continue;
}
auto dimExpr = cast<AffineDimExpr>(expr);
droppedDimsVec.push_back(dimExpr.getPosition());
}
return droppedDimsVec;
};
FailureOr<linalg::DropUnitDimsResult> dropUnitDims =
linalg::dropUnitDims(rewriter, swappedOp, options);
assert(succeeded(dropUnitDims) &&
"failed to drop unit dims of produced operation");
swappedOp = dropUnitDims->resultOp;
replacement = swappedOp->getResult(0);
}
rewriter.replaceOp(sliceOp, replacement);
return success();
}
};

struct BubbleUpExtractSlicesPass
: impl::BubbleUpExtractSlicesPassBase<BubbleUpExtractSlicesPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
{
RewritePatternSet patterns(context);
patterns.insert<BubbleUpExtract>(context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}
}
};
} // namespace

} // namespace mlir::iree_compiler::DispatchCreation
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ iree_cc_library(
"Passes.h"
SRCS
"BubbleUpExpandShapes.cpp"
"BubbleUpExtractSlices.cpp"
"CloneProducersIntoDispatchRegions.cpp"
"CollapseDimensions.cpp"
"CollapseReductionDimensions.cpp"
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/DispatchCreation/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
// elementwise operation into higher dimensions for more fusion
// opportunities.
.addPass(DispatchCreation::createBubbleUpExpandShapesPass)
.addPass(DispatchCreation::createBubbleUpExtractSlicesPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)

Expand Down
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/DispatchCreation/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ def TransposeGenericOpsPass :
];
}

def BubbleUpExtractSlicesPass : Pass<"iree-dispatch-creation-bubble-up-extract-slices"> {
let summary = "Bubble up `extract_slice` ops through Linalg-like ops";
let dependentDialects = [
"mlir::affine::AffineDialect",
];
}

//===---------------------------------------------------------------------===//
// Dispatch region creation passes.
//===---------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ iree_lit_test_suite(
"form_dispatch_regions.mlir",
"dispatch_linalg_on_tensors.mlir",
"convert_region_to_workgroups.mlir",
"bubble_up_extract_slice.mlir",
"form_dispatch_workgroups.mlir",
"dispatch_linalg_ext_fusion.mlir",
"hoist_encoding_ops.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_lit_test_suite(
lit
SRCS
"attention_fuse_by_expansion.mlir"
"bubble_up_extract_slice.mlir"
"clone_producers_into_dispatch_regions.mlir"
"collapse_dimensions.mlir"
"collapse_linalg_generic_on_tensors.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// RUN: iree-opt --split-input-file --iree-dispatch-creation-bubble-up-extract-slices --iree-flow-canonicalize %s | FileCheck %s

util.func public @bubble_up_extract_rank_reduce(%arg0 : tensor<1024x7x7x2xi8>) -> tensor<1024x7x7xf32>{
%0 = tensor.empty() : tensor<1024x7x7x2xf32>
%cst = arith.constant 5.000000e-01 : f32
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1024x7x7x2xi8>) outs(%0 : tensor<1024x7x7x2xf32>) {
^bb0(%in: i8, %out: f32):
%4 = arith.extsi %in : i8 to i32
%5 = arith.sitofp %4 : i32 to f32
%6 = arith.mulf %5, %cst : f32
linalg.yield %6 : f32
} -> tensor<1024x7x7x2xf32>

%extracted_slice = tensor.extract_slice %1[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7xf32>
util.return %extracted_slice : tensor<1024x7x7xf32>
}

// CHECK-LABEL: @bubble_up_extract_rank_reduce
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: util.return %[[GENERIC]]

// -----

util.func public @bubble_up_extract(%arg0 : tensor<1024x7x7x2xi8>) -> tensor<1024x7x7x1xf32>{
%0 = tensor.empty() : tensor<1024x7x7x2xf32>
%cst = arith.constant 5.000000e-01 : f32
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1024x7x7x2xi8>) outs(%0 : tensor<1024x7x7x2xf32>) {
^bb0(%in: i8, %out: f32):
%4 = arith.extsi %in : i8 to i32
%5 = arith.sitofp %4 : i32 to f32
%6 = arith.mulf %5, %cst : f32
linalg.yield %6 : f32
} -> tensor<1024x7x7x2xf32>

%extracted_slice = tensor.extract_slice %1[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7x1xf32>
util.return %extracted_slice : tensor<1024x7x7x1xf32>
}

// CHECK-LABEL: @bubble_up_extract
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: util.return %[[GENERIC]]

// -----

util.func public @bubble_up_extract_multi_input(%arg0 : tensor<1024x7x7x2xi8>, %arg1 : tensor<1024x7x7x2xi8>) -> tensor<1024x7x7x1xf32>{
%0 = tensor.empty() : tensor<1024x7x7x2xf32>
%cst = arith.constant 5.000000e-01 : f32
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<1024x7x7x2xi8>, tensor<1024x7x7x2xi8>) outs(%0 : tensor<1024x7x7x2xf32>) {
^bb0(%in: i8, %in_0 : i8, %out: f32):
%4 = arith.extsi %in : i8 to i32
%5 = arith.sitofp %4 : i32 to f32
%6 = arith.mulf %5, %cst : f32
linalg.yield %6 : f32
} -> tensor<1024x7x7x2xf32>

%extracted_slice = tensor.extract_slice %1[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7x1xf32>
util.return %extracted_slice : tensor<1024x7x7x1xf32>
}

// CHECK-LABEL: @bubble_up_extract_multi_input
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[EXTRACT0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-DAG: %[[EXTRACT1:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXTRACT0]], %[[EXTRACT1]] : tensor<1024x7x7x1xi8>, tensor<1024x7x7x1xi8>)
// CHECK: util.return %[[GENERIC]]

// -----

util.func public @bubble_up_extract_with_use(%arg0 : tensor<1024x7x7x2xi8>) -> (tensor<1024x7x7xf32>, tensor<1024x7x7x2xf32>) {
%0 = tensor.empty() : tensor<1024x7x7x2xf32>
%cst = arith.constant 5.000000e-01 : f32
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1024x7x7x2xi8>) outs(%0 : tensor<1024x7x7x2xf32>) {
^bb0(%in: i8, %out: f32):
%4 = arith.extsi %in : i8 to i32
%5 = arith.sitofp %4 : i32 to f32
%6 = arith.mulf %5, %cst : f32
linalg.yield %6 : f32
} -> tensor<1024x7x7x2xf32>

%extracted_slice = tensor.extract_slice %1[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7xf32>
util.return %extracted_slice, %1 : tensor<1024x7x7xf32>, tensor<1024x7x7x2xf32>
}

// CHECK-LABEL: @bubble_up_extract_with_use
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[GENERIC0:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<1024x7x7x2xi8>)
//
// CHECK-DAG: %[[EXTRACT0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK-DAG: %[[GENERIC1:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXTRACT0]] : tensor<1024x7x7xi8>)
// CHECK: util.return %[[GENERIC1]], %[[GENERIC0]]

0 comments on commit d6762d4

Please sign in to comment.