diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b5710bd78f0089..a8662a3d6f63be 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1515,9 +1515,43 @@ struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
                              const SmallVector<Value> &dynSizes) const;
 };
 
-/// Rewrites a tensor::PackOp into a sequence of tensor.pad + linalg.transpose +
-/// tensor.insert_slice ops, where the tensor::PackOp has outer dims being all
-/// 1s.
+/// Rewrites a tensor::PackOp into a sequence of:
+///   * tensor::PadOp + linalg::TransposeOp + tensor::ExtractSliceOp +
+///     tensor::EmptyOp + tensor::InsertSliceOp ops.
+///
+/// Requires that all the outer dims of the input tensor::PackOp are 1.
+///
+/// Before:
+/// ```
+///   %packed = tensor.pack %input
+///     padding_value(%pad : f32)
+///     inner_dims_pos = [1, 0]
+///     inner_tiles = [2, %high]
+///     into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
+/// ```
+///
+/// After:
+/// ```
+///   // PadOp
+///   %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
+///     ^bb0(...):
+///       tensor.yield %arg2 : f32
+///   } : tensor<5x1xf32> to tensor<?x2xf32>
+///   // ExtractSliceOp
+///   %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1, 1]
+///     : tensor<?x2xf32> to tensor<?x2xf32>
+///   // EmptyOp + TransposeOp
+///   %empty = tensor.empty(%arg3) : tensor<2x?xf32>
+///   %transposed = linalg.transpose
+///     ins(%extracted_slice : tensor<?x2xf32>)
+///     outs(%empty : tensor<2x?xf32>)
+///     permutation = [1, 0]
+///   // InsertSliceOp
+///   %inserted_slice = tensor.insert_slice %transposed
+///     into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1]
+///     : tensor<2x?xf32> into tensor<1x1x2x?xf32>
+/// ```
 struct GeneralizeOuterUnitDimsPackOpPattern
     : public OpRewritePattern<tensor::PackOp> {
   using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
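For reviewers who want to exercise the pattern in isolation, here is a minimal sketch of driving it through the greedy rewriter. This is not part of the patch: `generalizePackOps` is a hypothetical wrapper, and the `test-generalize-tensor-pack` option used by the test file below presumably wires things up in essentially this way.

```
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper (not in the patch): applies the pattern declared above
// to all tensor.pack ops nested under `root`. The pattern only fires when the
// tiled outer dims of the pack result are all 1s; other packs are untouched.
static mlir::LogicalResult generalizePackOps(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  patterns.add<mlir::linalg::GeneralizeOuterUnitDimsPackOpPattern>(
      root->getContext());
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```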
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index da5233049aaf69..ed5f1bd602d7f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1138,6 +1139,29 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
   return perm;
 }
 
+// A helper function to generate a dim-and-size pair for Ops like
+// ExtractSliceOp that require both:
+//  * dims to specify the output shape, and
+//  * sizes for the sizes attribute (or similar).
+// For dynamic sizes, if the corresponding size is a compile time constant:
+//  * the return size becomes the attribute encapsulating the known size, and
+//  * dim is updated from kDynamic to its actual known value.
+static std::pair<int64_t, OpFoldResult>
+getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, PatternRewriter &rewriter) {
+  int64_t tileSizeForShape =
+      getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
+
+  OpFoldResult tileSizeOfrSimplified;
+  if (tileSizeForShape != ShapedType::kDynamic) {
+    tileSizeOfrSimplified = rewriter.getIndexAttr(tileSizeForShape);
+  } else {
+    tileSizeOfrSimplified = tileSizeOfr;
+  }
+
+  return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
+                                          tileSizeOfrSimplified);
+}
+
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
   // TODO: support the case that outer dimensions are not all 1s. A
@@ -1148,69 +1172,104 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
         packOp, "require the tiled outer dimensions of the result are all 1s");
   }
 
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
-  // outer dims.
+  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
+  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   Location loc = packOp.getLoc();
+
   Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
   auto inputShape = packOp.getSourceType().getShape();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       packOp.getDimAndTileMapping();
-  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
-  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   int64_t srcRank = packOp.getSourceRank();
+
+  int64_t destRank = packOp.getDestRank();
+  size_t numTiles = destRank - srcRank;
+
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
+  //    %extracted_tile = tensor.extract_slice(%pack_op_input)
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
-  SmallVector<OpFoldResult> readSizes;
-  SmallVector<OpFoldResult> transShapeForEmpty;
-  SmallVector<int64_t> readShapeForExtractSlice;
+
+  // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
+  // all outer dims are 1.
+  SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
+  // The shape of the output for ExtractSliceOp. All leading unit dims are
+  // effectively rank-reduced, hence skipped.
+  SmallVector<int64_t> outputShapeForExtractSlice;
+
+  // Extract the trailing sizes and shape dims for ExtractSliceOp. These
+  // should be equal to the inner tile sizes.
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      readShapeForExtractSlice.push_back(
-          getConstantIntValue(dimAndTileMapping[i])
-              .value_or(ShapedType::kDynamic));
-      readSizes.push_back(dimAndTileMapping[i]);
-      transShapeForEmpty.push_back(dimAndTileMapping[i]);
-      continue;
-    }
-    if (ShapedType::isDynamic(inputShape[i])) {
-      readSizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, input, i).getResult());
-    } else {
-      readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
-    }
-    if (inputShape[i] != 1) {
-      readShapeForExtractSlice.push_back(inputShape[i]);
-      transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+      auto [tileSize, tileSizeOfr] =
+          getSimplifiedDimSizePair(dimAndTileMapping[i], rewriter);
+      extractSliceSizes.push_back(tileSizeOfr);
+      outputShapeForExtractSlice.push_back(tileSize);
     }
   }
 
   Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
+  auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
 
   Value tile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, input, readOffsets, readSizes, readStrides);
+      loc, readType, input, readOffsets, extractSliceSizes, readStrides);
 
-  // 2. Transpose the tile to match the inner tile order.
+  // 2. Transpose the tile to match the inner tile order:
+  //    %init = tensor.empty()
+  //    %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
+  // NOTE: Outer dims are 1 and hence effectively ignored.
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
       inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
 
-  applyPermutationToVector(transShapeForEmpty, perm);
+  // 2.1 Create tensor.empty (init value for TransposeOp)
+  SmallVector<OpFoldResult> transShapeForEmptyOpDynamic;
+  SmallVector<int64_t> transShapeForEmptyOpStatic;
+
+  // Acquire tensor shape required to create EmptyOp. This will match the
+  // inner tile sizes, but the actual data format will depend on whether the
+  // tile sizes are static or dynamic (each case leads to a different builder
+  // for EmptyOp). Conservatively, prepare for both scenarios.
+  size_t idx = numTiles;
+  while (idx != 0) {
+    transShapeForEmptyOpDynamic.push_back(extractSliceSizes[srcRank - idx]);
+    transShapeForEmptyOpStatic.push_back(
+        outputShapeForExtractSlice[numTiles - idx]);
+    idx--;
+  }
 
-  Value empty =
-      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
+  applyPermutationToVector(transShapeForEmptyOpStatic, perm);
+  applyPermutationToVector(transShapeForEmptyOpDynamic, perm);
+
+  Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic)
+                    ? rewriter.create<tensor::EmptyOp>(
+                          loc, transShapeForEmptyOpDynamic, elemType)
+                    : rewriter.create<tensor::EmptyOp>(
+                          loc, transShapeForEmptyOpStatic, elemType);
+
+  // 2.2 Create linalg.transpose
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 
-  // 3. Insert the inner tile to the destination.
-  int64_t destRank = packOp.getDestRank();
+  // 3. Insert the inner tile to the destination:
+  //    %inserted_tile = tensor.insert_slice(%transposed_tile)
   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> writeSizes =
-      tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+  // Outer dims are all 1s!
+  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
+                                       oneIdxAttr);
+  SmallVector<int64_t> writeShape;
+
+  for (auto tileSize : packOp.getMixedTiles()) {
+    auto [tileSizeStatic, tileSizeOfr] =
+        getSimplifiedDimSizePair(tileSize, rewriter);
+    writeSizes.push_back(tileSizeOfr);
+    writeShape.push_back(tileSizeStatic);
+  }
+
+  // 4. Replace tensor.pack with the tensor.insert_slice created above
   auto insert = rewriter.create<tensor::InsertSliceOp>(
       loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
       writeSizes, writeStrides);
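To make `getSimplifiedDimSizePair`'s contract concrete, a hedged usage sketch (not part of the patch; `inferTileType` is a hypothetical consumer of the returned pair):

```
// Hypothetical illustration of the helper's two outcomes. If `tileSize`
// folds to a constant (e.g. an IndexAttr, or the result of
// `arith.constant 8 : index`), `dim` is the known extent and `size` is the
// equivalent IndexAttr, so the resulting type is static, e.g. tensor<8xf32>.
// Otherwise `dim` stays ShapedType::kDynamic (printed as "?") and `size`
// passes the original OpFoldResult through unchanged.
static RankedTensorType inferTileType(OpFoldResult tileSize, Type elemType,
                                      PatternRewriter &rewriter) {
  auto [dim, size] = getSimplifiedDimSizePair(tileSize, rewriter);
  (void)size; // would feed e.g. ExtractSliceOp's mixed `sizes`.
  return RankedTensorType::get({dim}, elemType);
}
```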
diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
index 7f6b5e279f6857..8abf7a11bed5c9 100644
--- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir
@@ -1,21 +1,32 @@
 // RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s
 
-func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x1x1x8x32xf32>) -> tensor<1x1x1x1x8x32xf32> {
-  %0 = tensor.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x32x8xf32> -> tensor<1x1x1x1x8x32xf32>
-  return %0 : tensor<1x1x1x1x8x32xf32>
+
+func.func @simple_KCRS_to_KCRSsr(%arg0: tensor<?x?xi32>, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> {
+  %c8 = arith.constant 8 : index
+  %c5 = arith.constant 5 : i32
+  %pack = tensor.pack %arg0 padding_value(%c5 : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %arg1 : tensor<?x?xi32> -> tensor<1x1x?x1xi32>
+  return %pack : tensor<1x1x?x1xi32>
 }
-// CHECK-LABEL: func.func @simple_KCRS_to_KCRSsr
-// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
-// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
-// CHECK:         %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME:      ins(%[[TILE]] : tensor<32x8xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x32xf32>)
-// CHECK-SAME:      permutation = [1, 0]
-// CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
-// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
-// CHECK:         return %[[INSERT]]
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (-s0 + 8)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<()[s0] -> (-s0 + 1)>
+
+// CHECK-LABEL:   func.func @simple_KCRS_to_KCRSsr(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<?x?xi32>,
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32>
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 5 : i32
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[SRC]], %[[VAL_4]] : tensor<?x?xi32>
+// CHECK:           %[[VAL_6:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[VAL_5]]]
+// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[SRC]], %[[VAL_2]] : tensor<?x?xi32>
+// CHECK:           %[[VAL_8:.*]] = affine.apply #[[$ATTR_1]](){{\[}}%[[VAL_7]]]
+// CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[VAL_6]], %[[VAL_8]]] {
+// CHECK:           ^bb0(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: index):
+// CHECK:             tensor.yield %[[VAL_3]] : i32
+// CHECK:           } : tensor<?x?xi32> to tensor<8x1xi32>
+// CHECK:           %[[INSERT:.*]] = tensor.insert_slice %[[PAD]] into %[[DEST]][0, 0, 0, 0] [1, 1, 8, 1] [1, 1, 1, 1] : tensor<8x1xi32> into tensor<1x1x?x1xi32>
+// CHECK:           return %[[INSERT]] : tensor<1x1x?x1xi32>
 
 // -----
 
@@ -39,26 +50,59 @@ func.func @simple_pad_and_pack_static_tiles(%input: tensor<5x1xf32>, %output: te
 
 /// Same as example above, but with 1 dynamic tile size.
 
-func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
-  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %tile_dim_0: index) -> tensor<1x1x?x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
   return %0 : tensor<1x1x?x2xf32>
 }
-
 // CHECK-LABEL:   func.func @simple_pad_and_pack_dynamic_tile(
 // CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]
 // CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]
-// CHECK-SAME:      %[[HIGH_VAL:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x2xf32> {
-// CHECK:           %[[C2:.*]] = arith.constant 2 : index
-// CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
+// CHECK-SAME:      %[[TILE_DIM_0:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x2xf32> {
+// CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[TILE_DIM_0]]]
 // CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose
-// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[VAL_10:.*]][0, 0] {{\[}}%[[HIGH_VAL]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK:           %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
-// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD]][0, 0] {{\[}}%[[TILE_DIM_0]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
+
+func.func @simple_pad_and_pack_dynamic_tile_cst(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+  %tile_dim_0 = arith.constant 8 : index
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+// CHECK-LABEL:   func.func @simple_pad_and_pack_dynamic_tile_cst(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
+// CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high[3, 1] {
+// CHECK:             tensor.yield %[[PAD_VAL]] : f32
+// CHECK:           } : tensor<5x1xf32> to tensor<8x2xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[DEST]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] : tensor<8x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
+
+func.func @simple_pad_and_pack_dynamic_tile_transpose(%input: tensor<5x1xf32>, %output: tensor<1x1x2x?xf32>, %pad: f32, %tile_dim_1: index) -> tensor<1x1x2x?xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [1, 0] inner_tiles = [2, %tile_dim_1] into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32>
+  return %0 : tensor<1x1x2x?xf32>
+}
+// CHECK-LABEL:   func.func @simple_pad_and_pack_dynamic_tile_transpose(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME:      %[[TILE_DIM_1:[a-zA-Z0-9]+]]: index) -> tensor<1x1x2x?xf32> {
+// CHECK:           %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[TILE_DIM_1]]]
+// CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK:             tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NEXT:      } : tensor<5x1xf32> to tensor<?x2xf32>
+// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD]][0, 0] {{\[}}%[[TILE_DIM_1]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM_1]]) : tensor<2x?xf32>
+// CHECK:           %[[TR:.*]] = linalg.transpose
+// CHECK-SAME:        ins(%[[SLICE]] : tensor<?x2xf32>) outs(%[[EMPTY]] : tensor<2x?xf32>)
+// CHECK-SAME:        permutation = [1, 0]
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32>
+// CHECK:           return %[[RES]] : tensor<1x1x2x?xf32>
+
 /// Same as example above, but with 1 scalable tile size.
 
 /// NOTE: For this example to make sense in practice, the "?" in the output shape
@@ -77,7 +121,6 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t
 // CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
 // CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
 // CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
-// CHECK-DAG:       %[[C2:.+]] = arith.constant 2 : index
 // CHECK-DAG:       %[[C8:.+]] = arith.constant 8 : index
 // CHECK-DAG:       %[[VS:.+]] = vector.vscale
 // CHECK:           %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
@@ -86,37 +129,56 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose
 // CHECK:           %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
-// CHECK:           %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
-// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
 
 /// Same as example above, but with both tile sizes dynamic.
 
-func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x?x?xf32>, %pad: f32, %high_1: index, %high_2: index) -> tensor<1x1x?x?xf32> {
-  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high_1, %high_2] into %output : tensor<5x1xf32> -> tensor<1x1x?x?xf32>
+func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x?x?xf32>, %pad: f32, %tile_dim_0: index, %tile_dim_1: index) -> tensor<1x1x?x?xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%tile_dim_0, %tile_dim_1] into %output : tensor<5x1xf32> -> tensor<1x1x?x?xf32>
   return %0 : tensor<1x1x?x?xf32>
 }
 // CHECK-LABEL:   func.func @simple_pad_and_pack_dynamic_tiles(
 // CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
 // CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x?xf32>,
 // CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32,
-// CHECK-SAME:      %[[HIGH_VAL_1:[a-zA-Z0-9]+]]: index,
-// CHECK-SAME:      %[[HIGH_VAL_2:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x?xf32> {
-// CHECK:           %[[C3:.*]] = arith.constant 3 : index
-// CHECK:           %[[C2:.*]] = arith.constant 2 : index
-// CHECK:           %[[PAD_HIGH_1:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL_1]]]
-// CHECK:           %[[PAD_HIGH_2:.*]] = affine.apply #[[$ATTR_1]](){{\[}}%[[HIGH_VAL_2]]]
+// CHECK-SAME:      %[[TILE_DIM_0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[TILE_DIM_1:[a-zA-Z0-9]+]]: index) -> tensor<1x1x?x?xf32> {
+// CHECK:           %[[PAD_HIGH_1:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[TILE_DIM_0]]]
+// CHECK:           %[[PAD_HIGH_2:.*]] = affine.apply #[[$ATTR_1]](){{\[}}%[[TILE_DIM_1]]]
 // CHECK:           %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH_1]], %[[PAD_HIGH_2]]] {
 // CHECK:             tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT:       linalg.transpose
-// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[HIGH_VAL_1]], %[[HIGH_VAL_2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK:           %[[DIM_1:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x?xf32>
-// CHECK:           %[[DIM_2:.*]] = tensor.dim %[[DEST]], %[[C3]] : tensor<1x1x?x?xf32>
-// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM_1]], %[[DIM_2]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
+// CHECK:           %[[SLICE:.*]] = tensor.extract_slice %[[PAD]][0, 0] {{\[}}%[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK:           %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<1x1x?x?xf32>
 // CHECK:           return %[[RES]] : tensor<1x1x?x?xf32>
 
 // -----
 
+func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x1x5x1xf32>, %output: tensor<1x1x1x1x2x?xf32>, %pad: f32, %high: index) -> tensor<1x1x1x1x2x?xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) outer_dims_perm = [1, 0, 2, 3] inner_dims_pos = [3, 2] inner_tiles = [2, %high] into %output : tensor<1x1x5x1xf32> -> tensor<1x1x1x1x2x?xf32>
+  return %0 : tensor<1x1x1x1x2x?xf32>
+}
+// CHECK: #[[$ATTR_2:.+]] = affine_map<()[s0] -> (s0 - 5)>
+// CHECK-LABEL:   func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x1x5x1xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<1x1x1x1x2x?xf32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: f32,
+// CHECK-SAME:      %[[VAL_3:.*]]: index) -> tensor<1x1x1x1x2x?xf32> {
+// CHECK:           %[[VAL_4:.*]] = affine.apply #[[$ATTR_2]](){{\[}}%[[VAL_3]]]
+// CHECK:           %[[VAL_5:.*]] = tensor.pad %[[VAL_0]] low[0, 0, 0, 0] high[0, 0, %[[VAL_4]], 1] {
+// CHECK:           ^bb0(%[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index):
+// CHECK:             tensor.yield %[[VAL_2]] : f32
+// CHECK:           } : tensor<1x1x5x1xf32> to tensor<1x1x?x2xf32>
+// CHECK:           %[[VAL_10:.*]] = tensor.extract_slice %[[VAL_5]][0, 0, 0, 0] [1, 1, %[[VAL_3]], 2] [1, 1, 1, 1] : tensor<1x1x?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[VAL_12:.*]] = tensor.empty(%[[VAL_3]]) : tensor<2x?xf32>
+// CHECK:           %[[VAL_13:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor<?x2xf32>) outs(%[[VAL_12]] : tensor<2x?xf32>) permutation = [1, 0]
+// CHECK:           %[[VAL_14:.*]] = tensor.insert_slice %[[VAL_13]] into %[[VAL_1]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, %[[VAL_3]]] [1, 1, 1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x1x1x2x?xf32>
+// CHECK:           return %[[VAL_14]] : tensor<1x1x1x1x2x?xf32>
+// CHECK:         }
+
+// -----
+
 func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
   %0 = tensor.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8] into %arg1 : tensor<32x8xf32> -> tensor<1x1x32x8xf32>
   return %0 : tensor<1x1x32x8xf32>
 }
@@ -149,19 +211,19 @@ func.func @simple_CHW_to_CHWhwc(%arg0: tensor<3x5x7xf32>, %arg1: tensor<1x1x1x5x
 
 // -----
 
-func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<3x1x32x8xf32>, %arg1: tensor<3x1x1x1x8x32xf32>) -> tensor<3x1x1x1x8x32xf32> {
-  %0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<3x1x32x8xf32> -> tensor<3x1x1x1x8x32xf32>
-  return %0 : tensor<3x1x1x1x8x32xf32>
+func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x1x1x8x32xf32>) -> tensor<1x1x1x1x8x32xf32> {
+  %0 = tensor.pack %arg0 outer_dims_perm = [0, 2, 3, 1] inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<1x1x32x8xf32> -> tensor<1x1x1x1x8x32xf32>
+  return %0 : tensor<1x1x1x1x8x32xf32>
 }
 // CHECK-LABEL: func.func @simple_KCRS_to_KRSCsr
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
-// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [3, 1, 32, 8] [1, 1, 1, 1]
-// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<3x8x32xf32>
+// CHECK:         %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32>
 // CHECK:         %[[TRANSP:.+]] = linalg.transpose
-// CHECK-SAME:      ins(%[[TILE]] : tensor<3x32x8xf32>)
-// CHECK-SAME:      outs(%[[EMPTY]] : tensor<3x8x32xf32>)
-// CHECK-SAME:      permutation = [0, 2, 1]
+// CHECK-SAME:      ins(%[[TILE]] : tensor<32x8xf32>)
+// CHECK-SAME:      outs(%[[EMPTY]] : tensor<8x32xf32>)
+// CHECK-SAME:      permutation = [1, 0]
 // CHECK:         %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
-// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [3, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK-SAME:      [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
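On the step 2.1 comment about each case leading to "a different builder for EmptyOp": a hedged sketch of the two tensor::EmptyOp builders that the `ShapedType::isDynamicShape` check selects between (not part of the patch; `someIndexValue` is a placeholder for any SSA value of `index` type):

```
// All tile sizes known at compile time -> the ArrayRef<int64_t> builder;
// no operands, fully static result type, e.g. tensor<2x8xf32>.
Value staticEmpty = rewriter.create<tensor::EmptyOp>(
    loc, ArrayRef<int64_t>{2, 8}, elemType);

// At least one size only known at runtime -> the ArrayRef<OpFoldResult>
// builder; dynamic extents become operands of the op, e.g.
// tensor.empty(%n) : tensor<2x?xf32>.
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(2),
                                   OpFoldResult(someIndexValue)};
Value dynamicEmpty = rewriter.create<tensor::EmptyOp>(loc, sizes, elemType);
```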