[mlir][tensor] Generalize/restrict `GeneralizeOuterUnitDimsPackOpPattern`

This PR _restricts_ `GeneralizeOuterUnitDimsPackOpPattern` by making it
follow its intended use (as per the documentation), i.e. to:

  > require all outer dims of tensor.pack to be 1.

There was one test in-tree that violated that assumption (and happened
to work), see `@simple_KCRS_to_KRSCsr` in "generalize-tensor-pack.mlir".
That test has been updated to satisfy the new requirements of the
pattern.

By making the pattern follow its intended design (i.e. making it
stricter), the calculation of shapes and sizes for the various Ops that
the pattern generates (PadOp + ExtractSliceOp + EmptyOp + TransposeOp +
InsertSliceOp) becomes much simpler and easier to document. It also
helped _generalize_ the pattern to support cases like the one below:

```mlir
func.func @simple_pad_and_pack_dynamic_tile_cst(
    %src: tensor<5x1xf32>,
    %dest: tensor<1x1x?x2xf32>,
    %pad: f32) -> tensor<1x1x?x2xf32> {

  %tile_dim_0 = arith.constant 8 : index
  %0 = tensor.pack %src
    padding_value(%pad : f32)
    inner_dims_pos = [0, 1]
    inner_tiles = [%tile_dim_0, 2]
    into %dest : tensor<5x1xf32> -> tensor<1x1x?x2xf32>

  return %0 : tensor<1x1x?x2xf32>
}
```

Note that the inner tile size is dynamic, but a compile-time constant.
`getPackOpSourceOrPaddedSource` - which is used to generate the PadOp -
is able to see that and generates a PadOp with static shapes. This is a
good optimization, but it means that all shapes/sizes for Ops generated
by `GeneralizeOuterUnitDimsPackOpPattern` also have to be updated to be
constant/static. By _restricting_ the pattern and making the size/shape
calculation more straightforward, supporting the case above becomes much
easier (see the sketch below).
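
For reference, a rough sketch (not verbatim compiler output) of the kind of
IR the pattern is expected to produce for the example above; the identity
ExtractSliceOp/TransposeOp that the pattern also creates are omitted here,
assuming they fold away:

```mlir
// %tile_dim_0 is a compile-time constant (8), so the padding amounts and all
// slice sizes below become static.
%padded = tensor.pad %src low[0, 0] high[3, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad : f32
} : tensor<5x1xf32> to tensor<8x2xf32>
%res = tensor.insert_slice %padded
  into %dest[0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
  : tensor<8x2xf32> into tensor<1x1x?x2xf32>
```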

Notable implementation changes:
  * PadOp processes the original source (no change in dimensions/rank).
    ExtractSliceOp extracts the tile to pack and it may reduce the rank.
    All ops that follow operate on the tile extracted by ExtractSliceOp
    (possibly rank-reduced).
  * All shape/size calculations assume that trailing dims match `inner_tiles`
    from `tensor.pack`. All the leading dims (i.e. outer dims) are
    assumed to be 1.
  * Dynamic sizes for ops like ExtractSliceOp are taken from
    `inner_tiles` rather than computed as e.g. `tensor.dim %dest, 2`
    (see the sketch after this list). It is up to the "producers" of
    `tensor.pack` to make sure that the dimensions in `%dest` match the
    specified tile sizes.
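
To illustrate the last point, a hypothetical snippet (the SSA names `%tile`
and `%tile_size` are made up for illustration) contrasting how the dynamic
size for InsertSliceOp used to be obtained versus how it is obtained now:

```mlir
// Previously: the dynamic size was re-computed from the destination tensor.
%c2 = arith.constant 2 : index
%sz = tensor.dim %dest, %c2 : tensor<1x1x?x2xf32>
%r0 = tensor.insert_slice %tile into %dest[0, 0, 0, 0] [1, 1, %sz, 2] [1, 1, 1, 1]
  : tensor<?x2xf32> into tensor<1x1x?x2xf32>

// Now: the size is taken directly from the `inner_tiles` operand of tensor.pack.
%r1 = tensor.insert_slice %tile into %dest[0, 0, 0, 0] [1, 1, %tile_size, 2] [1, 1, 1, 1]
  : tensor<?x2xf32> into tensor<1x1x?x2xf32>
```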
banach-space committed Oct 30, 2024
1 parent e61a7dc commit 7f79675
Showing 3 changed files with 239 additions and 84 deletions.
40 changes: 37 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1515,9 +1515,43 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
const SmallVector<Value> &dynSizes) const;
};

/// Rewrites a tensor::PackOp into a sequence of tensor.pad + linalg.transpose +
/// tensor.insert_slice ops, where the tensor::PackOp has outer dims being all
/// 1s.
/// Rewrites a tensor::PackOp into a sequence of:
/// * tensor::PadOp + linalg::TransposeOp + tensor::ExtractSliceOp +
/// tensor::EmptyOp + tensor::InsertSliceOp ops.
///
/// Requires that all the outer dims of the input tensor::PackOp are 1.
///
/// Before:
/// ```
/// %packed = tensor.pack %input
/// padding_value(%pad : f32)
/// inner_dims_pos = [1, 0]
/// inner_tiles = [2, %high]
/// into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
/// ```
///
/// After:
/// ```
/// // PadOp
/// %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
/// ^bb0(...):
/// tensor.yield %arg2 : f32
/// } : tensor<5x1xf32> to tensor<?x2xf32>
/// // ExtractSliceOp
/// %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1, 1]
/// : tensor<?x2xf32> to tensor<?x2xf32>
/// // EmptyOp + TransposeOp
/// %empty = tensor.empty(%arg3) : tensor<2x?xf32>
/// %transposed = linalg.transpose
/// ins(%extracted_slice : tensor<?x2xf32>)
/// outs(%empty : tensor<2x?xf32>)
/// permutation = [1, 0]
/// // InsertSliceOp
/// %inserted_slice = tensor.insert_slice %transposed
/// into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
/// : tensor<2x?xf32> into tensor<1x1x2x?xf32>
/// ```
struct GeneralizeOuterUnitDimsPackOpPattern
: public OpRewritePattern<tensor::PackOp> {
using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
125 changes: 92 additions & 33 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -27,6 +27,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1138,6 +1139,29 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
return perm;
}

// A helper function to generate a dim-and-size pair for Ops like
// ExtractSliceOp that require both:
// * dims to specify the output shape, and
// * sizes for the sizes attribute (or similar).
// For dynamic sizes, if the corresponding size is a compile time constant:
// * the return size becomes the attribute encapsulating the known size, and
// * dim is updated from kDynamic to its actual known value.
static std::pair<int64_t, OpFoldResult>
getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, PatternRewriter &rewriter) {
int64_t tileSizeForShape =
getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);

OpFoldResult tileSizeOfrSimplified;
if (tileSizeForShape != ShapedType::kDynamic) {
tileSizeOfrSimplified = rewriter.getIndexAttr(tileSizeForShape);
} else {
tileSizeOfrSimplified = tileSizeOfr;
}

return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
tileSizeOfrSimplified);
}

LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
tensor::PackOp packOp, PatternRewriter &rewriter) const {
// TODO: support the case that outer dimensions are not all 1s. A
@@ -1148,69 +1172,104 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
packOp, "require the tiled outer dimensions of the result are all 1s");
}

// 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
// outer dims.
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();

Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
auto inputShape = packOp.getSourceType().getShape();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
int64_t srcRank = packOp.getSourceRank();

int64_t destRank = packOp.getDestRank();
size_t numTiles = destRank - srcRank;

// 1. Use rank-reduced tensor.extract_slice op to extract the tile:
// %extracted_tile = tensor.extract_slice(%pack_op_input)
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> readSizes;
SmallVector<OpFoldResult> transShapeForEmpty;
SmallVector<int64_t> readShapeForExtractSlice;

// The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
// all outer dims are 1.
SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
// The shape of the output for ExtractSliceOp. All leading unit dims are
// effectively rank-reduced, hence skipped.
SmallVector<int64_t> outputShapeForExtractSlice;

// Extract the trailing sizes and shape dims for ExtractSliceOp. These should
// be equal to the inner tile sizes.
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
readShapeForExtractSlice.push_back(
getConstantIntValue(dimAndTileMapping[i])
.value_or(ShapedType::kDynamic));
readSizes.push_back(dimAndTileMapping[i]);
transShapeForEmpty.push_back(dimAndTileMapping[i]);
continue;
}
if (ShapedType::isDynamic(inputShape[i])) {
readSizes.push_back(
rewriter.create<tensor::DimOp>(loc, input, i).getResult());
} else {
readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
}
if (inputShape[i] != 1) {
readShapeForExtractSlice.push_back(inputShape[i]);
transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
auto [tileSize, tileSizeOfr] =
getSimplifiedDimSizePair(dimAndTileMapping[i], rewriter);
extractSliceSizes.push_back(tileSizeOfr);
outputShapeForExtractSlice.push_back(tileSize);
}
}

Type elemType = packOp.getSourceType().getElementType();
auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);

Value tile = rewriter.create<tensor::ExtractSliceOp>(
loc, readType, input, readOffsets, readSizes, readStrides);
loc, readType, input, readOffsets, extractSliceSizes, readStrides);

// 2. Transpose the tile to match the inner tile order.
// 2. Transpose the tile to match the inner tile order:
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
// NOTE: Outer dims are 1 and hence effectively ignored.
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());

LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););

applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
// 2.1 Create tensor.empty (init value for TransposeOp)
SmallVector<OpFoldResult> transShapeForEmptyOpDynamic;
SmallVector<int64_t> transShapeForEmptyOpStatic;

// Acquire tensor shape required to create EmptyOp. This will match the inner
// tile sizes, but the actual data format will depend on whether the tile
// sizes are static or dynamic (each case leads to a different builder for
// EmptyOp). Conservatively, prepare for both scenarios.
size_t idx = numTiles;
while (idx != 0) {
transShapeForEmptyOpDynamic.push_back(extractSliceSizes[srcRank - idx]);
transShapeForEmptyOpStatic.push_back(
outputShapeForExtractSlice[numTiles - idx]);
idx--;
}

Value empty =
rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
applyPermutationToVector<int64_t>(transShapeForEmptyOpStatic, perm);
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);

Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic)
? rewriter.create<tensor::EmptyOp>(
loc, transShapeForEmptyOpDynamic, elemType)
: rewriter.create<tensor::EmptyOp>(
loc, transShapeForEmptyOpStatic, elemType);

// 2.2 Create linalg.transpose
auto transposedOp =
rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);

// 3. Insert the inner tile to the destination.
int64_t destRank = packOp.getDestRank();
// 3. Insert the inner tile to the destination:
// %inserted_tile = tensor.insert_slice(%transposed_tile)
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
SmallVector<OpFoldResult> writeSizes =
tensor::getMixedSizes(rewriter, loc, packOp.getDest());
// Outer dims are all 1s!
SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
oneIdxAttr);
SmallVector<int64_t> writeShape;

for (auto tileSize : packOp.getMixedTiles()) {
auto [tileSizeStatic, tileSizeOfr] =
getSimplifiedDimSizePair(tileSize, rewriter);
writeSizes.push_back(tileSizeOfr);
writeShape.push_back(tileSizeStatic);
}

// 4. Replace tensor.packOp with tensor.insert_slice created above
auto insert = rewriter.create<tensor::InsertSliceOp>(
loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
writeSizes, writeStrides);
