Skip to content

Commit

Permalink
[mlir][vector] Add a new TD Op for patterns leveraging ShapeCastOp (l…
Browse files Browse the repository at this point in the history
…lvm#110525)

Adds a new Transform Dialect Op that collects patters for dropping unit
dims from various Ops:
  * `transform.apply_patterns.vector.drop_unit_dims_with_shape_cast`.

It excludes patterns for vector.transfer Ops - these are collected
under:
  * `apply_patterns.vector.rank_reducing_subview_patterns`,

and use ShapeCastOp _and_ SubviewOp to reduce the rank (and to eliminate
unit dims).

This new TD Ops allows us to test the "ShapeCast folder" pattern in
isolation. I've extracted the only test that I could find for that
folder from "vector-transforms.mlir" and moved it to a dedicated file:
"shape-cast-folder.mlir". I also added a test case with scalable
vectors.

Changes in VectorTransforms.cpp are not needed (added a comment with
a TODO + ordered the patterns alphabetically). I am Including them here
to avoid a separate PR.
  • Loading branch information
banach-space authored and VitaNuo committed Oct 2, 2024
1 parent 9dad082 commit 17f6d51
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,22 @@ def ApplyRankReducingSubviewPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplyDropUnitDimWithShapeCastPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.drop_unit_dims_with_shape_cast",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Apply vector patterns to fold unit dims with vector.shape_cast Ops:
- DropUnitDimFromElementwiseOps
- DropUnitDimsFromScfForOp
- DropUnitDimsFromTransposeOp

Excludes patterns for vector.transfer Ops. This is complemented by
shape_cast folding patterns.
}];

let assemblyFormat = "attr-dict";
}

def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.transfer_permutation_patterns",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
}

void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateDropUnitDimWithShapeCastPatterns(patterns);
}

void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorBitCastLoweringPatterns(patterns);
Expand Down
9 changes: 7 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2056,8 +2056,13 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,

void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
ShapeCastOpFolder, DropUnitDimsFromScfForOp>(
// TODO: Consider either:
// * including DropInnerMostUnitDimsTransferRead and
// DropInnerMostUnitDimsTransferWrite, or
// * better naming to distinguish this and
// populateVectorTransferCollapseInnerMostContiguousDimsPatterns.
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
DropUnitDimsFromTransposeOp, ShapeCastOpFolder>(
patterns.getContext(), benefit);
}

Expand Down
38 changes: 38 additions & 0 deletions mlir/test/Dialect/Vector/shape-cast-folder.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

///----------------------------------------------------------------------------------------
/// [Pattern: ShapeCastOpFolder]
///----------------------------------------------------------------------------------------

// CHECK-LABEL: func @fixed_width
// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
// CHECK-NOT: vector.shape_cast
// CHECK: return %[[A0]] : vector<2x4xf32>
func.func @fixed_width(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
%0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
%1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
return %1 : vector<2x4xf32>
}

// CHECK-LABEL: func @scalable
// CHECK-SAME: %[[A0:.*0]]: vector<2x[4]xf32>
// CHECK-NOT: vector.shape_cast
// CHECK: return %[[A0]] : vector<2x[4]xf32>
func.func @scalable(%arg0 : vector<2x[4]xf32>) -> vector<2x[4]xf32> {
%0 = vector.shape_cast %arg0 : vector<2x[4]xf32> to vector<[8]xf32>
%1 = vector.shape_cast %0 : vector<[8]xf32> to vector<2x[4]xf32>
return %1 : vector<2x[4]xf32>
}

// ============================================================================
// TD sequence
// ============================================================================
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.drop_unit_dims_with_shape_cast
} : !transform.op<"func.func">
transform.yield
}
}
9 changes: 0 additions & 9 deletions mlir/test/Dialect/Vector/vector-transforms.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,6 @@ func.func @vector_transfers(%arg0: index, %arg1: index) {
return
}

// CHECK-LABEL: func @cancelling_shape_cast_ops
// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
// CHECK: return %[[A0]] : vector<2x4xf32>
func.func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
%0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
%1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
return %1 : vector<2x4xf32>
}

// CHECK-LABEL: func @elementwise_unroll
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32>, %[[ARG1:.*]]: memref<4x4xf32>)
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
Expand Down

0 comments on commit 17f6d51

Please sign in to comment.