From d6762d431a3930c97961bdd743bd44b83be7e4ac Mon Sep 17 00:00:00 2001 From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> Date: Tue, 27 Aug 2024 18:29:46 -0700 Subject: [PATCH] Add pass to bubble-up extract_slice operations. (#18332) This adds pass to replace a `tensor.extract_slice` operation with a slice of the producer. In general there might be more opportunities to use this pass more aggressively (like when an operation has a single use which is a slice), but for now this is being done only for bit-extend operations. Co-authored-by: Ian Wood --- .../compiler/DispatchCreation/BUILD.bazel | 1 + .../BubbleUpExtractSlices.cpp | 134 ++++++++++++++++++ .../compiler/DispatchCreation/CMakeLists.txt | 1 + .../iree/compiler/DispatchCreation/Passes.cpp | 1 + .../iree/compiler/DispatchCreation/Passes.td | 7 + .../DispatchCreation/test/BUILD.bazel | 1 + .../DispatchCreation/test/CMakeLists.txt | 1 + .../test/bubble_up_extract_slice.mlir | 96 +++++++++++++ 8 files changed, 242 insertions(+) create mode 100644 compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp create mode 100644 compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir diff --git a/compiler/src/iree/compiler/DispatchCreation/BUILD.bazel b/compiler/src/iree/compiler/DispatchCreation/BUILD.bazel index 3be481779e54..9df4709f5320 100644 --- a/compiler/src/iree/compiler/DispatchCreation/BUILD.bazel +++ b/compiler/src/iree/compiler/DispatchCreation/BUILD.bazel @@ -16,6 +16,7 @@ iree_compiler_cc_library( name = "DispatchCreation", srcs = [ "BubbleUpExpandShapes.cpp", + "BubbleUpExtractSlices.cpp", "CloneProducersIntoDispatchRegions.cpp", "CollapseDimensions.cpp", "CollapseReductionDimensions.cpp", diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp new file mode 100644 index 000000000000..4b03b60f5779 --- /dev/null +++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExtractSlices.cpp @@ -0,0 +1,134 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::iree_compiler::DispatchCreation { + +#define GEN_PASS_DEF_BUBBLEUPEXTRACTSLICESPASS +#include "iree/compiler/DispatchCreation/Passes.h.inc" + +namespace { + +// Convert extract_slice(dequant) to dequant(extract_slice) +// +// Because `extract_slice` ops and dequantize-like ops get cloned into regions +// later, it's okay to bubble up through multi-use dequant ops. +struct BubbleUpExtract : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, + PatternRewriter &rewriter) const final { + Value source = sliceOp.getSource(); + auto genericOp = source.getDefiningOp(); + if (!genericOp || genericOp->getNumResults() != 1) { + return rewriter.notifyMatchFailure( + sliceOp, "expected source to implement `linalg::LinalgOp` and have a " + "single result"); + } + + if (!IREE::LinalgExt::isBitExtendOp(genericOp)) { + return rewriter.notifyMatchFailure( + sliceOp, "expected source to be dequantize-like"); + } + + if (!sliceOp.hasUnitStride()) { + return rewriter.notifyMatchFailure(sliceOp, "expected unit stride"); + } + + if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) { + return map.isProjectedPermutation(); + })) { + return rewriter.notifyMatchFailure( + genericOp, + "expected generic op to have all projected permutation maps"); + } + + if (genericOp.hasIndexSemantics()) { + return rewriter.notifyMatchFailure( + genericOp, "pattern doesn't support index semantics"); + } + + Value replacement; + linalg::GenericOp swappedOp; + { + FailureOr tilingResult = + tensor::replaceExtractSliceWithTiledProducer(rewriter, sliceOp, + genericOp->getResult(0)); + assert(succeeded(tilingResult) && "failed to swap extract_slice with op"); + assert(tilingResult->tiledOps.size() == 1); + replacement = tilingResult->tiledValues[0]; + swappedOp = cast(tilingResult->tiledOps[0]); + } + + // Check if this is a rank-reducing slice, if so we need to fold the unit + // dimensions of the op. + // This is necessary because `replaceExtractSliceWithTiledProducer` does not + // take into account the `extract_slice`'s implicit rank reduction. The + // operations generated by that function will have any unit dims that were + // removed by the original `extract_slice`. Folding them away ensures that + // the types match. + if (sliceOp.getSourceType().getRank() != + sliceOp.getResultType().getRank()) { + + llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims(); + // Get the indexing map for the result. + AffineMap resultMap = + swappedOp.getIndexingMapMatchingResult(swappedOp->getResult(0)); + linalg::ControlDropUnitDims options; + options.rankReductionStrategy = linalg::ControlDropUnitDims:: + RankReductionStrategy::ExtractInsertSlice; + options.controlFn = [&](Operation *op) -> SmallVector { + SmallVector droppedDimsVec; + for (auto [index, expr] : llvm::enumerate(resultMap.getResults())) { + if (!droppedDims.test(index)) { + continue; + } + auto dimExpr = cast(expr); + droppedDimsVec.push_back(dimExpr.getPosition()); + } + return droppedDimsVec; + }; + FailureOr dropUnitDims = + linalg::dropUnitDims(rewriter, swappedOp, options); + assert(succeeded(dropUnitDims) && + "failed to drop unit dims of produced operation"); + swappedOp = dropUnitDims->resultOp; + replacement = swappedOp->getResult(0); + } + rewriter.replaceOp(sliceOp, replacement); + return success(); + } +}; + +struct BubbleUpExtractSlicesPass + : impl::BubbleUpExtractSlicesPassBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + { + RewritePatternSet patterns(context); + patterns.insert(context); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } + } +}; +} // namespace + +} // namespace mlir::iree_compiler::DispatchCreation diff --git a/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt b/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt index d8cbd3b93e7c..9f7b2b063bd6 100644 --- a/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt +++ b/compiler/src/iree/compiler/DispatchCreation/CMakeLists.txt @@ -18,6 +18,7 @@ iree_cc_library( "Passes.h" SRCS "BubbleUpExpandShapes.cpp" + "BubbleUpExtractSlices.cpp" "CloneProducersIntoDispatchRegions.cpp" "CollapseDimensions.cpp" "CollapseReductionDimensions.cpp" diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp index 9f18cebc8d91..01d132e894c2 100644 --- a/compiler/src/iree/compiler/DispatchCreation/Passes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/Passes.cpp @@ -142,6 +142,7 @@ void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) { // elementwise operation into higher dimensions for more fusion // opportunities. .addPass(DispatchCreation::createBubbleUpExpandShapesPass) + .addPass(DispatchCreation::createBubbleUpExtractSlicesPass) .addPass(IREE::Flow::createCanonicalizerPass) .addPass(mlir::createCSEPass) diff --git a/compiler/src/iree/compiler/DispatchCreation/Passes.td b/compiler/src/iree/compiler/DispatchCreation/Passes.td index a89579e05e71..32e0660e6b95 100644 --- a/compiler/src/iree/compiler/DispatchCreation/Passes.td +++ b/compiler/src/iree/compiler/DispatchCreation/Passes.td @@ -143,6 +143,13 @@ def TransposeGenericOpsPass : ]; } +def BubbleUpExtractSlicesPass : Pass<"iree-dispatch-creation-bubble-up-extract-slices"> { + let summary = "Bubble up `extract_slice` ops through Linalg-like ops"; + let dependentDialects = [ + "mlir::affine::AffineDialect", + ]; +} + //===---------------------------------------------------------------------===// // Dispatch region creation passes. //===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel index 8ca4ad2f6d5f..9b21c78a8618 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel +++ b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel @@ -27,6 +27,7 @@ iree_lit_test_suite( "form_dispatch_regions.mlir", "dispatch_linalg_on_tensors.mlir", "convert_region_to_workgroups.mlir", + "bubble_up_extract_slice.mlir", "form_dispatch_workgroups.mlir", "dispatch_linalg_ext_fusion.mlir", "hoist_encoding_ops.mlir", diff --git a/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt b/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt index 152c1ac86788..36d0ceee98c4 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt @@ -15,6 +15,7 @@ iree_lit_test_suite( lit SRCS "attention_fuse_by_expansion.mlir" + "bubble_up_extract_slice.mlir" "clone_producers_into_dispatch_regions.mlir" "collapse_dimensions.mlir" "collapse_linalg_generic_on_tensors.mlir" diff --git a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir new file mode 100644 index 000000000000..a5b7ea13ee27 --- /dev/null +++ b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir @@ -0,0 +1,96 @@ +// RUN: iree-opt --split-input-file --iree-dispatch-creation-bubble-up-extract-slices --iree-flow-canonicalize %s | FileCheck %s + +util.func public @bubble_up_extract_rank_reduce(%arg0 : tensor<1024x7x7x2xi8>) -> tensor<1024x7x7xf32>{ + %0 = tensor.empty() : tensor<1024x7x7x2xf32> + %cst = arith.constant 5.000000e-01 : f32 + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1024x7x7x2xi8>) outs(%0 : tensor<1024x7x7x2xf32>) { + ^bb0(%in: i8, %out: f32): + %4 = arith.extsi %in : i8 to i32 + %5 = arith.sitofp %4 : i32 to f32 + %6 = arith.mulf %5, %cst : f32 + linalg.yield %6 : f32 + } -> tensor<1024x7x7x2xf32> + + %extracted_slice = tensor.extract_slice %1[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7xf32> + util.return %extracted_slice : tensor<1024x7x7xf32> +} + +// CHECK-LABEL: @bubble_up_extract_rank_reduce +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: util.return %[[GENERIC]] + +// ----- + +util.func public @bubble_up_extract(%arg0 : tensor<1024x7x7x2xi8>) -> tensor<1024x7x7x1xf32>{ + %0 = tensor.empty() : tensor<1024x7x7x2xf32> + %cst = arith.constant 5.000000e-01 : f32 + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1024x7x7x2xi8>) outs(%0 : tensor<1024x7x7x2xf32>) { + ^bb0(%in: i8, %out: f32): + %4 = arith.extsi %in : i8 to i32 + %5 = arith.sitofp %4 : i32 to f32 + %6 = arith.mulf %5, %cst : f32 + linalg.yield %6 : f32 + } -> tensor<1024x7x7x2xf32> + + %extracted_slice = tensor.extract_slice %1[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7x1xf32> + util.return %extracted_slice : tensor<1024x7x7x1xf32> +} + +// CHECK-LABEL: @bubble_up_extract +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: util.return %[[GENERIC]] + +// ----- + +util.func public @bubble_up_extract_multi_input(%arg0 : tensor<1024x7x7x2xi8>, %arg1 : tensor<1024x7x7x2xi8>) -> tensor<1024x7x7x1xf32>{ + %0 = tensor.empty() : tensor<1024x7x7x2xf32> + %cst = arith.constant 5.000000e-01 : f32 + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<1024x7x7x2xi8>, tensor<1024x7x7x2xi8>) outs(%0 : tensor<1024x7x7x2xf32>) { + ^bb0(%in: i8, %in_0 : i8, %out: f32): + %4 = arith.extsi %in : i8 to i32 + %5 = arith.sitofp %4 : i32 to f32 + %6 = arith.mulf %5, %cst : f32 + linalg.yield %6 : f32 + } -> tensor<1024x7x7x2xf32> + + %extracted_slice = tensor.extract_slice %1[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7x1xf32> + util.return %extracted_slice : tensor<1024x7x7x1xf32> +} + +// CHECK-LABEL: @bubble_up_extract_multi_input +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[EXTRACT0:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-DAG: %[[EXTRACT1:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[EXTRACT0]], %[[EXTRACT1]] : tensor<1024x7x7x1xi8>, tensor<1024x7x7x1xi8>) +// CHECK: util.return %[[GENERIC]] + +// ----- + +util.func public @bubble_up_extract_with_use(%arg0 : tensor<1024x7x7x2xi8>) -> (tensor<1024x7x7xf32>, tensor<1024x7x7x2xf32>) { + %0 = tensor.empty() : tensor<1024x7x7x2xf32> + %cst = arith.constant 5.000000e-01 : f32 + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<1024x7x7x2xi8>) outs(%0 : tensor<1024x7x7x2xf32>) { + ^bb0(%in: i8, %out: f32): + %4 = arith.extsi %in : i8 to i32 + %5 = arith.sitofp %4 : i32 to f32 + %6 = arith.mulf %5, %cst : f32 + linalg.yield %6 : f32 + } -> tensor<1024x7x7x2xf32> + + %extracted_slice = tensor.extract_slice %1[0, 0, 0, 1] [1024, 7, 7, 1] [1, 1, 1, 1] : tensor<1024x7x7x2xf32> to tensor<1024x7x7xf32> + util.return %extracted_slice, %1 : tensor<1024x7x7xf32>, tensor<1024x7x7x2xf32> +} + +// CHECK-LABEL: @bubble_up_extract_with_use +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[GENERIC0:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]] : tensor<1024x7x7x2xi8>) +// +// CHECK-DAG: %[[EXTRACT0:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-DAG: %[[GENERIC1:.+]] = linalg.generic +// CHECK-SAME: ins(%[[EXTRACT0]] : tensor<1024x7x7xi8>) +// CHECK: util.return %[[GENERIC1]], %[[GENERIC0]]