Skip to content

Commit

Permalink
[Codegen] Add control options in pack unpack decomposition (#18469)
Browse files Browse the repository at this point in the history
Some early patterns in pack/unpack decomposition avoid generating
reshape ops for unit dim packs and unpacks. However, the TileAndFuse
pipeline uses these reshape ops to propagate expanded shapes to other
fusable ops. This PR adds an option to the DecomposePackUnPackOps pass
to create reshape ops anyway for unit dim cases.

The reason these unit dims show up right now is that the
iree_linalg_ext.im2col op of a unit-batched conv will have a unit
dimension in the batch dim. Ultimately, it would be good to allow for
batchless im2col ops, but in general it is good to support ops that have
required unit dimensions. When prototyping new ops, it can be easiest to
not support rank-reducing cases at first (winograd ops are another
example), so these unit dims may appear again in the future.

This PR also adds an optional control function to the pass options,
which controls which packs and unpacks get decomposed. The control
function is currently expected to be used when the `useOnlyReshapes`
option is true, since there is no control function in some upstream
patterns yet, but adding the control function upstream and fixing this
is left as a TODO.

---------

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Sep 20, 2024
1 parent d834aa7 commit 891f438
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-decompose-pack-unpack-ops"
Expand All @@ -35,15 +36,25 @@ namespace {
struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;

explicit LowerPackPattern(MLIRContext *context,
std::optional<PackUnPackControlFn> controlFn)
: OpRewritePattern(context), controlFn(controlFn) {}

LogicalResult matchAndRewrite(tensor::PackOp op,
PatternRewriter &rewriter) const override {
if (controlFn && failed(controlFn.value()(op))) {
return failure();
}
FailureOr<linalg::LowerPackResult> res = linalg::lowerPack(rewriter, op);
if (failed(res)) {
return rewriter.notifyMatchFailure(
op, "cannot lower to pad + expand + transpose");
}
return success();
}

private:
std::optional<PackUnPackControlFn> controlFn;
};

/// A warpper pattern that calls linalg::lowerUnPack on tensor::UnPackOp. It
Expand All @@ -52,8 +63,15 @@ struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
struct LowerUnPackPattern : public OpRewritePattern<tensor::UnPackOp> {
using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;

explicit LowerUnPackPattern(MLIRContext *context,
std::optional<PackUnPackControlFn> controlFn)
: OpRewritePattern(context), controlFn(controlFn) {}

LogicalResult matchAndRewrite(tensor::UnPackOp op,
PatternRewriter &rewriter) const override {
if (controlFn && failed(controlFn.value()(op))) {
return failure();
}
FailureOr<linalg::LowerUnPackOpResult> res =
linalg::lowerUnPack(rewriter, op);
if (failed(res)) {
Expand All @@ -62,21 +80,31 @@ struct LowerUnPackPattern : public OpRewritePattern<tensor::UnPackOp> {
}
return success();
}

private:
std::optional<PackUnPackControlFn> controlFn;
};

struct DecomposePackUnPackOpsPass final
: impl::DecomposePackUnPackOpsPassBase<DecomposePackUnPackOpsPass> {
using impl::DecomposePackUnPackOpsPassBase<
DecomposePackUnPackOpsPass>::DecomposePackUnPackOpsPassBase;
explicit DecomposePackUnPackOpsPass(bool tileOuterToOne) {
explicit DecomposePackUnPackOpsPass(
bool tileOuterToOne, bool useOnlyReshapes,
std::optional<PackUnPackControlFn> controlFn) {
this->tileOuterToOne = tileOuterToOne;
this->useOnlyReshapes = useOnlyReshapes;
this->controlFn = controlFn;
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, arith::ArithDialect, scf::SCFDialect,
tensor::TensorDialect>();
}

void runOnOperation() override;

private:
std::optional<PackUnPackControlFn> controlFn;
};

} // namespace
Expand All @@ -86,7 +114,7 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
auto funcOp = getOperation();
// Generalization patterns for outer unit dims have higher priority because
// they do not generate reshape ops.
{
if (!useOnlyReshapes) {
RewritePatternSet patterns(ctx);
patterns.add<linalg::GeneralizeOuterUnitDimsPackOpPattern,
linalg::GeneralizeOuterUnitDimsUnPackOpPattern>(ctx);
Expand All @@ -102,7 +130,7 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
// tiled to one.
if (!tileOuterToOne) {
RewritePatternSet patterns(ctx);
patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx);
patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx, controlFn);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
funcOp.emitError(
"failed to apply generalization patterns on pack/unpack ops for "
Expand Down Expand Up @@ -136,6 +164,9 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
return tileSizes;
}));
funcOp->walk([&](tensor::PackOp op) {
if (controlFn && failed(controlFn.value()(op))) {
return;
}
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
scf::tileConsumerAndFuseProducersUsingSCF(
rewriter, cast<TilingInterface>(op.getOperation()), packOptions);
Expand All @@ -161,6 +192,9 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
return tileSizes;
});
funcOp->walk([&](tensor::UnPackOp op) {
if (controlFn && failed(controlFn.value()(op))) {
return;
}
FailureOr<scf::SCFTilingResult> tilingResult =
scf::tileUsingSCF(rewriter, cast<TilingInterface>(op.getOperation()),
unpackTilingOptions);
Expand Down Expand Up @@ -197,17 +231,23 @@ void DecomposePackUnPackOpsPass::runOnOperation() {

{
RewritePatternSet patterns(ctx);
patterns.add<linalg::GeneralizeOuterUnitDimsPackOpPattern,
linalg::GeneralizeOuterUnitDimsUnPackOpPattern>(ctx);
if (useOnlyReshapes) {
patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx, controlFn);
} else {
patterns.add<linalg::GeneralizeOuterUnitDimsPackOpPattern,
linalg::GeneralizeOuterUnitDimsUnPackOpPattern>(ctx);
}
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createDecomposePackUnPackOpsPass(bool tileOuterToOne) {
return std::make_unique<DecomposePackUnPackOpsPass>(tileOuterToOne);
createDecomposePackUnPackOpsPass(bool tileOuterToOne, bool useOnlyReshapes,
std::optional<PackUnPackControlFn> controlFn) {
return std::make_unique<DecomposePackUnPackOpsPass>(
tileOuterToOne, useOnlyReshapes, controlFn);
}

} // namespace mlir::iree_compiler
10 changes: 9 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,16 @@ using ConfigFn =
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createConvolutionToIGEMMPass(ConfigFn configFn);

using PackUnPackControlFn = std::function<LogicalResult(Operation *)>;
/// Pass to decompose pack and unpack ops into pad/extract_slice and reshape
/// ops. If specified, `controlFn` controls which ops get decomposed. The
/// `controlFn` should be used with `useOnlyReshapes` set to true.
/// TODO(Max191): Add a controlFn upstream for `GeneralizeOuterUnitDim*`
/// patterns and remove the need to have `useOnlyReshapes = true` when using
/// `controlFn`.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createDecomposePackUnPackOpsPass(bool tileOuterToOne);
createDecomposePackUnPackOpsPass(bool tileOuterToOne, bool useOnlyReshapes,
std::optional<PackUnPackControlFn> controlFn);

std::unique_ptr<Pass> createDecomposeSoftmaxPass(bool useFusion);

Expand Down
4 changes: 3 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def DecomposePackUnPackOpsPass :
let summary = "Decompose pack/unpack ops into vectorizable ops";
let options = [
Option<"tileOuterToOne", "tile-outer-to-one", "bool", "false",
"Always apply tiling to make outer dimension be ones">
"Always apply tiling to make outer dimension be ones">,
Option<"useOnlyReshapes", "use-only-reshapes", "bool", "false",
"Use decomposition into reshape ops, even when packing unit dimensions.">
];
}

Expand Down
Loading

0 comments on commit 891f438

Please sign in to comment.