[MLIR] Add continuous tiling to transform dialect (#82792)
This patch enables continuous tiling of a target structured op using
diminishing tile sizes. When a tensor dimension is not exactly divisible
by the tile size, the leftover tensor chunk ends up irregularly tiled.
This approach tiles the leftover chunk with a smaller tile size and
repeats the process recursively with exponentially diminishing tile
sizes, which eventually generates a chain of loops, each applying one of
the diminishing tile sizes.

Adds `continuous_tile_sizes` op to the transform dialect. This op, when
given a tile size and a dimension, computes a series of diminishing tile
sizes that can be used to tile the target along the given dimension.
This op also generates a series of chunk sizes that the
corresponding tile sizes should be applied to along the given dimension.

Adds `multiway` attribute to `transform.structured.split` that enables
multiway splitting of a single target op along the given dimension, as
specified in a list enumerating the chunk sizes.
muneebkhan85 authored Jun 21, 2024
1 parent 74a105a commit a9efcbf
Showing 11 changed files with 839 additions and 98 deletions.
85 changes: 72 additions & 13 deletions mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1396,29 +1396,43 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
DeclareOpInterfaceMethods<TransformOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
- Indicates that the given `target` op should be split into two complementary
+ Splits the given `target` op into two or more complementary
parts, which combined cover the entire iteration domain of the original op.
The split is performed along the iteration space dimension provided as
- attribute. In case of dimension overflow, the transformation fails. The
- split is performed at the dimension iterator value specified as either the
- static split point attribute when it is known at transform IR construction
- time or as the handle to an operation producing a single index-typed value
- when it is computed by payload IR. In the latter case, the static split
+ chunk size attribute specifying the size of the lower part; the remaining
+ range in the iteration space is assigned as the upper part. In case of
+ dimension overflow, the transformation fails. The split is performed at the
+ dimension iterator value specified as either the static chunk size
+ attribute when it is known at transform IR construction time or
+ as the handle to an operation producing a single index-typed value
+ when it is computed by payload IR. In the latter case, the chunk size
point must be set to `ShapedType::kDynamic` and the dynamic size handle
must point to as many value-producing operations as there are structured
operations pointed to by the target handle.

- The operation consumes the target handle, but preserves the split point
- handle if provided. It produces two new handles pointing to the two parts
- of the structured op after splitting, in the same order as the target
- operand, with the first handle corresponding to the part with lower
- iteration space indices.
+ The operation consumes the target handle, but preserves the chunk size
+ handle if provided. Without the `multiway` attribute, it produces two
+ new handles pointing to the two parts of the structured op after splitting,
+ in the same order as the target operand, with the first handle
+ corresponding to the part with lower iteration space indices.
+
+ Multiway split mode is enabled by specifying the `multiway` attribute.
+ In this mode a single `target` op is split into multiple parts covering
+ the iteration space of the specified dimension. `static_chunk_sizes` and
+ `dynamic_chunk_sizes` in this case is a list of chunk sizes that the given
+ dimension should be split into. With `multiway` it produces two handles;
+ the first handle is a list of the multiple parts of the structured op
+ after splitting, where the target dimensions for each linalg op in the
+ list corresponds to the chunk sizes specified in the input split list.
+ If the chunk sizes do not cover the entire iteration space, the leftover
+ chunk is the last payload in the first handle. The second handle is empty.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
I64Attr:$dimension,
- Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
- I64Attr:$static_split_point);
+ Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_chunk_sizes,
+ I64Attr:$static_chunk_sizes,
+ UnitAttr:$multiway);
let results = (outs TransformHandleTypeInterface:$first,
TransformHandleTypeInterface:$second);
let hasCustomAssemblyFormat = 1;
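
The multiway mode described in the `SplitOp` documentation above can be pictured with a small standalone model. This is only a sketch in plain C++, not the MLIR implementation: `RangeSketch` and `multiwaySplitSketch` are hypothetical names, and the assertion that the chunks fit inside the dimension is an assumption of the sketch rather than documented behaviour of the transform. It shows how a list of chunk sizes maps to consecutive parts of one dimension, with any uncovered remainder becoming the last part.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper type: a half-open range [begin, end) along the
// dimension being split.
using RangeSketch = std::pair<int64_t, int64_t>;

// Models multiway splitting of one dimension of size `extent` into
// consecutive parts of the given chunk sizes. If the chunks do not cover
// the whole extent, the leftover becomes the last part, mirroring the
// "leftover chunk is the last payload" wording in the op description.
inline std::vector<RangeSketch>
multiwaySplitSketch(int64_t extent, const std::vector<int64_t> &chunkSizes) {
  std::vector<RangeSketch> parts;
  int64_t offset = 0;
  for (int64_t chunk : chunkSizes) {
    // Assumption of this sketch only: chunks are non-negative and fit.
    assert(chunk >= 0 && offset + chunk <= extent &&
           "chunk sizes exceed the dimension");
    parts.emplace_back(offset, offset + chunk);
    offset += chunk;
  }
  if (offset < extent)
    parts.emplace_back(offset, extent); // leftover chunk
  return parts;
}
```

For a dimension of size 25 and chunk sizes 18, 4, 2, the model yields the parts [0, 18), [18, 22), [22, 24) and a leftover [24, 25).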
@@ -1819,6 +1833,51 @@ def TileReductionUsingForallOp :

}

+ //===----------------------------------------------------------------------===//
+ // ContinuousTileSizesOp
+ //===----------------------------------------------------------------------===//
+
+ def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_sizes",
+ [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ ReportTrackingListenerFailuresOpTrait]> {
+ let description = [{
+ This transform emits the IR computing the list of (1) exponentially
+ diminishing tile sizes that are powers of 2; and (2) the corresponding
+ chunk-sizes the target op should be split into along the given dimension.
+
+ For example, for `target_size` 9, and `dimension` 0 for the following
+ linalg op as target
+
+ ```
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
+ outs(%arg2: tensor<25x25xf32>)
+ ```
+
+ the first result `tile_sizes` will be a list of diminishing tile sizes
+ 9, 4, 2, 1; and the second result will be a list of chunk sizes
+ 18, 4, 2, 1 that the corresponding dimension should be split into.
+
+ After the target op has been split along the given dimension (for example
+ using multiway split), each chunk can be tiled with the corresponding tile
+ size in the `tile_sizes` list generated as a result of this op.
+
+ Specifying the output type as !transform.param<i64> will cause `tile_sizes`
+ and `chunk_sizes` to be computed statically and not dynamically.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
+ ConfinedAttr<I64Attr, [IntNonNegative]>:$target_size);
+ let results = (outs TransformAnyParamTypeOrAnyHandle:$tile_sizes,
+ TransformAnyParamTypeOrAnyHandle:$chunk_sizes);
+ let hasVerifier = 1;
+ let assemblyFormat =
+ "$target attr-dict `:` custom<ContinuousTileSizeTypes>("
+ "type($target), type($tile_sizes), type($chunk_sizes))";
+
+ }

//===----------------------------------------------------------------------===//
// TileUsingForOp
//===----------------------------------------------------------------------===//
21 changes: 21 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -801,6 +801,15 @@ struct MultiSizeSpecificationBase {
/// Number of tiles associated with each size.
T lowTripCount, highTripCount;
};

+ template <typename T>
+ struct ContinuousTileSizeSpecificationBase {
+ /// Tile sizes.
+ SmallVector<T> tileSizes;
+ /// Number of tiles associated with each size.
+ SmallVector<T> tripCounts;
+ };

} // namespace detail

/// A description of a multi-size tiling comprising tile sizes and numbers of
@@ -811,6 +820,11 @@ struct MultiSizeSpecification
struct StaticMultiSizeSpecification
: public detail::MultiSizeSpecificationBase<int64_t> {};

+ struct ContinuousTileSizeSpecification
+ : public detail::ContinuousTileSizeSpecificationBase<Value> {};
+ struct StaticContinuousTileSizeSpecification
+ : public detail::ContinuousTileSizeSpecificationBase<int64_t> {};
+
/// Emits the IR computing the multi-sized tiling specification with two tile
/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
/// that there exist numbers of tiles with these sizes that fully cover the
@@ -846,6 +860,13 @@ FailureOr<StaticMultiSizeSpecification>
computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
int64_t divisor);

+ FailureOr<StaticContinuousTileSizeSpecification>
+ computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
+ unsigned targetSize);
+ FailureOr<ContinuousTileSizeSpecification>
+ computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
+ unsigned dimension, OpFoldResult targetSize,
+ bool emitAssertions);
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
/// tiling by `numThreads`.
/// If non-empty, the `mapping` is added as an attribute to the
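
To illustrate how the `tileSizes` and `tripCounts` fields of the new specification structs relate to the "chain of loops" mentioned in the commit message, here is a standalone sketch in plain C++. It does not use the MLIR types: `StaticSpecSketch`, `TileSketch` and `enumerateTilesSketch` are hypothetical names, and the trip counts in the usage note are derived by dividing the chunk sizes from the op's documented example by the tile sizes, which is an assumption about how the fields line up.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical mirror of the two fields of the static specification.
struct StaticSpecSketch {
  std::vector<int64_t> tileSizes;  // one entry per loop in the chain
  std::vector<int64_t> tripCounts; // number of tiles of each size
};

// One tile of the iteration space: starting offset and extent.
struct TileSketch {
  int64_t offset;
  int64_t size;
};

// Walks the chain of loops: loop i runs tripCounts[i] times with tile
// size tileSizes[i], starting where the previous loop stopped.
inline std::vector<TileSketch>
enumerateTilesSketch(const StaticSpecSketch &spec) {
  assert(spec.tileSizes.size() == spec.tripCounts.size() &&
         "expected one trip count per tile size");
  std::vector<TileSketch> tiles;
  int64_t offset = 0;
  for (std::size_t i = 0; i < spec.tileSizes.size(); ++i) {
    for (int64_t t = 0; t < spec.tripCounts[i]; ++t) {
      tiles.push_back({offset, spec.tileSizes[i]});
      offset += spec.tileSizes[i];
    }
  }
  return tiles;
}
```

For tile sizes 9, 4, 2, 1 with trip counts 2, 1, 1, 1, the sketch enumerates five tiles at offsets 0, 9, 18, 22 and 24, which together cover a dimension of size 25.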
10 changes: 7 additions & 3 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -624,7 +624,10 @@ def ForeachOp : TransformDialectOp<"foreach",
Each iteration gets executed by co-indexing the payloads of the arguments
and mapping the body's arguments to these tuples, as though iterating over
the zipped together `targets`. As such, in each iteration, the size of the
- payload of each of the body's block arguments is exactly one.
+ payload of each of the body's block arguments is exactly one. The attribute
+ `zip_shortest` can be used if the targets vary in their number of payloads;
+ this will limit the iterations to only the number of payloads found in the
+ shortest target.

This op always reads the target handles. Furthermore, it consumes a handle
if there is a transform op in the body that consumes the corresponding
@@ -645,11 +648,12 @@ def ForeachOp : TransformDialectOp<"foreach",
rollback capabilities.
}];

- let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
+ let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets,
+ UnitAttr:$zip_shortest);
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
"$targets `:` type($targets) (`->` type($results)^)? $body attr-dict";
"$targets attr-dict `:` type($targets) (`->` type($results)^)? $body";
let hasVerifier = 1;

let extraClassDeclaration = [{
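
The `zip_shortest` behaviour described in the `ForeachOp` documentation above amounts to stopping co-indexed iteration at the shortest input. A minimal sketch of that idea in plain C++ follows; `zipShortestSketch` is a hypothetical helper that handles only two inputs, whereas the op takes a variadic list of targets, and nothing here reflects how the transform interpreter actually implements the attribute.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Pairs up elements of `a` and `b` by index, stopping at the end of the
// shorter input, so extra elements of the longer input are not visited.
template <typename A, typename B>
std::vector<std::pair<A, B>> zipShortestSketch(const std::vector<A> &a,
                                               const std::vector<B> &b) {
  const std::size_t n = std::min(a.size(), b.size());
  std::vector<std::pair<A, B>> zipped;
  zipped.reserve(n);
  for (std::size_t i = 0; i < n; ++i)
    zipped.emplace_back(a[i], b[i]);
  assert(zipped.size() == n);
  return zipped;
}
```

Zipping a three-element list with a two-element list yields two pairs; the third element of the longer list is never paired with anything.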