Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DispatchCreation] Extend multi-use producer fusion #18551

Merged
merged 7 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/pkgci_regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ jobs:
--goldentime-rocm-unet-ms 419.0 \
--goldentime-rocm-clip-ms 18.5 \
--goldentime-rocm-vae-ms 337.0 \
--goldendispatch-rocm-unet 1551 \
--goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-unet 1545 \
--goldendispatch-rocm-clip 1225 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2280000 \
--goldensize-rocm-clip-bytes 860000 \
Expand All @@ -241,8 +241,8 @@ jobs:
--goldentime-rocm-unet-ms 95.0 \
--goldentime-rocm-clip-ms 15.5 \
--goldentime-rocm-vae-ms 80.0 \
--goldendispatch-rocm-unet 1551 \
--goldendispatch-rocm-clip 1139 \
--goldendispatch-rocm-unet 1545 \
--goldendispatch-rocm-clip 1225 \
--goldendispatch-rocm-vae 248 \
--goldensize-rocm-unet-bytes 2270000 \
--goldensize-rocm-clip-bytes 860000 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
Expand Down Expand Up @@ -107,25 +108,6 @@ static bool isEmptyFillContractionDAGRootOp(
return true;
}

/// Check that a given operation is "horizontal" to the group. The operation
/// is horizontal if the `slice` of the operation does not contain any op
/// from the group.
static bool isHorizontalToGroup(Operation *op,
const llvm::SetVector<Operation *> &currGroup,
const DominanceInfo &dominanceInfo,
Operation *seedOp) {
BackwardSliceOptions options;
// Limit the slice to the seed to make sure the slice is small.
options.filter = [&](Operation *op) {
return !dominanceInfo.properlyDominates(op, seedOp);
};
llvm::SetVector<Operation *> slice;
getBackwardSlice(op, &slice, options);
return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
return slice.contains(groupedOp);
});
}

/// Get user of operation that is a truncate operation.
static std::optional<linalg::GenericOp>
getTruncateOp(Operation *op,
Expand All @@ -149,8 +131,8 @@ getTruncateOp(Operation *op,
if (!checkOperationEquivalence(genericOp, seedTruncateOp.value())) {
return std::nullopt;
}
if (!isHorizontalToGroup(genericOp, groupedOperations, dominanceInfo,
seedTruncateOp.value())) {
if (!isHorizontalToGroup(genericOp, groupedOperations.getArrayRef(),
dominanceInfo, seedTruncateOp.value())) {
return std::nullopt;
}
}
Expand Down Expand Up @@ -226,7 +208,8 @@ static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(
if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) {
return false;
}
if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) {
if (!isHorizontalToGroup(linalgOp, allOps.getArrayRef(), dominanceInfo,
seedOp)) {
return false;
}
return true;
Expand Down Expand Up @@ -346,40 +329,6 @@ static AffineMap getConcatenatedIndexingMap(RewriterBase &rewriter,
return newIndexingMap.insertResult(rewriter.getAffineDimExpr(0), 0);
}

/// During horizontal fusion, there might be operands of the fused operations
/// whose definitions are interspersed between the fused operations. For groups
/// chosen to fuse horizontally, such operations can be moved before the
/// seed contraction operation (where the fused operation is generated).
template <typename T>
static LogicalResult
moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
Operation *insertionPoint, DominanceInfo &dominanceInfo,
ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
BackwardSliceOptions options;
llvm::DenseSet<Operation *> ignoreOperationsSet;
ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
options.filter = [&](Operation *op) {
return !dominanceInfo.properlyDominates(op, insertionPoint) &&
!ignoreOperationsSet.contains(op);
};
// Set inclusive to true cause the slice is computed from the operand, and
// we want to include the defining op (which is the point here)
options.inclusive = true;

llvm::SetVector<Operation *> slice;
for (auto op : operations) {
for (auto operand : op->getOperands()) {
getBackwardSlice(operand, &slice, options);
}
}

mlir::topologicalSort(slice);
for (auto op : slice) {
rewriter.moveOpBefore(op, insertionPoint);
}
return success();
}

/// On finding this pattern
/// ```
/// %0 = linalg.matmul ins(%arg0, %arg1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "iree/compiler/DispatchCreation/FusionUtils.h"
#include "iree/compiler/DispatchCreation/Passes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand All @@ -45,25 +49,55 @@ static llvm::cl::opt<int64_t> clLinalgMaxConstantFoldElements(
llvm::cl::desc("Maximum number of elements to try to constant fold."),
llvm::cl::init(0));

static Operation *getMostDominantUse(Operation *op,
const DominanceInfo &dominanceInfo) {
auto uses = op->getUses();
auto it = llvm::find_if(uses, [&](OpOperand &source) {
Operation *sourceOp = source.getOwner();

return llvm::all_of(uses, [&](OpOperand &target) {
Operation *targetOp = target.getOwner();
return dominanceInfo.dominates(sourceOp, targetOp);
});
});
if (it != uses.end()) {
return it->getOwner();
}
return nullptr;
}

/// Check if any of the use dominates all other uses of the operation.
static std::optional<OpOperand *> getFusableUse(Operation *op,
DominanceInfo &dominanceInfo) {
static Operation *getFusableUse(Operation *op,
const DominanceInfo &dominanceInfo) {
auto uses = op->getUses();
Operation *fusableUse = nullptr;
for (OpOperand &source : uses) {
Operation *sourceOp = source.getOwner();
bool dominatesAllUsers = true;
for (OpOperand &target : uses) {

bool dominatesAllFusableOps = llvm::all_of(uses, [&](OpOperand &target) {
Operation *targetOp = target.getOwner();
if (!dominanceInfo.dominates(sourceOp, targetOp)) {
dominatesAllUsers = false;
break;
}
}
if (dominatesAllUsers) {
return &source;
return !isa<linalg::GenericOp>(targetOp) ||
dominanceInfo.dominates(sourceOp, targetOp);
});
if (dominatesAllFusableOps) {
fusableUse = sourceOp;
break;
}
}
return std::nullopt;
Operation *mostDominantOp = getMostDominantUse(op, dominanceInfo);
if (!fusableUse || !mostDominantOp) {
return nullptr;
}

// If `fusableUse` dominates all other users, there's nothing else to do.
if (fusableUse == mostDominantOp) {
return fusableUse;
}

SmallVector<Operation *> users(op->getUsers().begin(), op->getUsers().end());
return isHorizontalToGroup(fusableUse, users, dominanceInfo, mostDominantOp)
? fusableUse
: nullptr;
}

static OpOperand *getFirstUseInConsumer(Operation *producer,
Expand Down Expand Up @@ -91,6 +125,7 @@ static SmallVector<OpOperand *> getAllUsesInConsumer(Operation *producer,
/// using elementwise fusion.
static LogicalResult doMultiUseFusion(Operation *rootOp,
llvm::SetVector<Operation *> &fusableOps,
const DominanceInfo &dominanceInfo,
RewriterBase &rewriter) {
assert(rootOp && "root op cant be null");

Expand All @@ -112,11 +147,20 @@ static LogicalResult doMultiUseFusion(Operation *rootOp,
Operation *consumerOp = rootOp;
OpBuilder::InsertionGuard g(rewriter);
for (Operation *producerOp : llvm::reverse(fusedOpsVec)) {
Operation *mostDominantUser = getMostDominantUse(producerOp, dominanceInfo);
// Fuse all uses from producer -> consumer. It has been checked
// before that all uses are fusable.
while (OpOperand *fusedOperand =
getFirstUseInConsumer(producerOp, consumerOp)) {
rewriter.setInsertionPoint(consumerOp);

if (consumerOp != mostDominantUser &&
failed(moveOperandDefs(rewriter, ArrayRef<Operation *>{consumerOp},
mostDominantUser, dominanceInfo))) {
return rewriter.notifyMatchFailure(consumerOp,
"failed to move operand defs");
}
rewriter.moveOpBefore(consumerOp, mostDominantUser);
FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
linalg::fuseElementwiseOps(rewriter, fusedOperand);
if (failed(fusionResult)) {
Expand Down Expand Up @@ -190,9 +234,8 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,
}

// 6. Check that the `genericOp` dominates all uses of `producer`.
std::optional<OpOperand *> fusableUse =
getFusableUse(producer, dominanceInfo);
if (!fusableUse || fusableUse.value()->getOwner() != genericOp) {
Operation *fusableUse = getFusableUse(producer, dominanceInfo);
if (!fusableUse || fusableUse != genericOp) {
continue;
}

Expand Down Expand Up @@ -232,7 +275,8 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,

IRRewriter rewriter(context);
for (auto it = fusedOps.rbegin(), ie = fusedOps.rend(); it != ie; ++it) {
if (failed(doMultiUseFusion(it->first, it->second, rewriter))) {
if (failed(
doMultiUseFusion(it->first, it->second, dominanceInfo, rewriter))) {
return funcOp->emitOpError("failed multi use fusion");
}
}
Expand Down
35 changes: 35 additions & 0 deletions compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
#include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h"
#include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/RegionUtils.h"

namespace mlir::iree_compiler::DispatchCreation {

Expand Down Expand Up @@ -97,4 +101,35 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
return true;
}

// Returns true when an operation `op` is horizontal to `currGroup` when
// considering the program slice between `seedOp` and `op`.
IanWood1 marked this conversation as resolved.
Show resolved Hide resolved
bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
const DominanceInfo &dominanceInfo,
Operation *seedOp) {
assert(dominanceInfo.properlyDominates(seedOp, op) &&
op->getParentRegion() == seedOp->getParentRegion());
BackwardSliceOptions options;
// Limit the slice to the seed to make sure the slice is small.
options.filter = [&](Operation *op) {
return !dominanceInfo.properlyDominates(op, seedOp);
};
llvm::SetVector<Operation *> slice;
getBackwardSlice(op, &slice, options);

// `getBackwardSlice` doesnt track uses from within an ops region, so make
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow this comment. Can you explain more?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

%0 = linalg.generic
%1 = linalg.generic %0
%2 = linalg.generic %1 ins(%0 : tensor<...>) {
  ^bb0(%in: i64, %out: i64):
    %extracted = tensor.extract %1 [...]
    linalg.yield %extracted : f16
  } -> tensor<...>

In the above, the backward slice starting at %2 would only include %2 (but not %1 because it isn't a direct operand). Which makes it seem like %2 can be moved before %1. Just looking at the slice doesn't tell you if there is a dependency between the ops.

This was causing issues with the open llama regression tests since (i think) there are multiple gathers

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to use a slice of the values that are captured from above to make sure that there is no dependence?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think so. Maybe this could be a option of getBackwardsSlice?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe... lets leave this this way for now.

Copy link
Contributor

@nirvedhmeshram nirvedhmeshram Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I am seeing the test added in this PR correctly there wasnt a test added for this? The failure in #18879 seems related to this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nirvedhmeshram There were several other tests that failed related to the backward slice problem (regression tests), but none were because the consumer was using values defined above which is why this slipped through. I'll make sure to add more testing when fixing this

// sure there are no values defined above.
for (Operation *sliceOp : slice) {
bool usesValuesFromAbove = false;
mlir::visitUsedValuesDefinedAbove(
sliceOp->getRegions(), [&](void *) { usesValuesFromAbove = true; });
if (usesValuesFromAbove) {
return false;
}
}

return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
return slice.contains(groupedOp);
});
}

} // namespace mlir::iree_compiler::DispatchCreation
44 changes: 44 additions & 0 deletions compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"

namespace mlir::iree_compiler::DispatchCreation {
Expand All @@ -19,4 +23,44 @@ namespace mlir::iree_compiler::DispatchCreation {
bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
bool fuseMultiReduction);

/// Check that a given operation is "horizontal" to the group. The operation
/// is horizontal if the program slice of the operation (from op back to seedOp)
/// does not contain any op from the group.
bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
const DominanceInfo &dominanceInfo, Operation *seedOp);

/// Moves the operands and transitive defs for each op in `operations` directly
/// after `insertionPoint`. Note: this does not check if it is legal to move the
/// operands.
template <typename T>
static LogicalResult
moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
Operation *insertionPoint, const DominanceInfo &dominanceInfo,
ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
BackwardSliceOptions options;
llvm::DenseSet<Operation *> ignoreOperationsSet;
ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
options.filter = [&](Operation *op) {
return !dominanceInfo.properlyDominates(op, insertionPoint) &&
!ignoreOperationsSet.contains(op);
};
// Set inclusive to true cause the slice is computed from the operand, and
// we want to include the defining op (which is the point here)
options.inclusive = true;

llvm::SetVector<Operation *> slice;
for (auto op : operations) {
assert(insertionPoint->getBlock() == op->getBlock());
for (auto operand : op->getOperands()) {
getBackwardSlice(operand, &slice, options);
}
}

mlir::topologicalSort(slice);
for (auto op : slice) {
rewriter.moveOpBefore(op, insertionPoint);
}
return success();
}

} // namespace mlir::iree_compiler::DispatchCreation
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,28 @@ util.func public @math_sin() {
// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
// CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#0,
// CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#1,

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
%cst = arith.constant 1.000000e+00 : f32
%cst_0 = arith.constant 2.000000e+00 : f32
%cst_1 = arith.constant 3.000000e+00 : f32
%4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
^bb0(%arg2: f32, %arg3: f32):
%8 = arith.addf %arg2, %cst : f32
linalg.yield %8 : f32
} -> tensor<5x5xf32>
// expected-note @below {{prior use here}}
%collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
%5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
^bb0(%arg2: f32, %arg3: f32):
%8 = arith.subf %arg2, %cst_0 : f32
linalg.yield %8 : f32
} -> tensor<5x5xf32>
util.return %5, %collapsed: tensor<5x5xf32>, tensor<25xf32>
}
// CHECK-LABEL: util.func public @fuse_by_moving_consumer
// CHECK: linalg.generic
// CHECK-NOT: linalg.generic
Loading