diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94e0ed319cae83..836dcb8f329e70 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1080,7 +1080,37 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
     auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
 
     if (!reshape)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dim, "Dim op is not defined by a reshape op.");
+
+    // dim of a memref reshape can be folded if dim.getIndex() dominates the
+    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
+    // cheaply check that either of the following conditions hold:
+    //      1. dim.getIndex() is defined in the same block as reshape but before
+    //      reshape.
+    //      2. dim.getIndex() is defined in a parent block of
+    //      reshape.
+
+    // Check condition 1
+    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
+      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
+        if (reshape->isBeforeInBlock(definingOp)) {
+          return rewriter.notifyMatchFailure(
+              dim,
+              "dim.getIndex is not defined before reshape in the same block.");
+        }
+      } // else dim.getIndex is a block argument to reshape->getBlock and
+        // dominates reshape
+    } // Check condition 2
+    else if (dim->getBlock() != reshape->getBlock() &&
+             !dim.getIndex().getParentRegion()->isProperAncestor(
+                 reshape->getParentRegion())) {
+      // If dim and reshape are in the same block but dim.getIndex() isn't, we
+      // already know dim.getIndex() dominates reshape without calling
+      // `isProperAncestor`
+      return rewriter.notifyMatchFailure(
+          dim, "dim.getIndex does not dominate reshape.");
+    }
 
     // Place the load directly after the reshape to ensure that the shape memref
     // was not mutated.
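Reviewer note on the hunk above: the comment motivates the hand-rolled structural check as a cheap stand-in for `DominanceInfo`. For readers unfamiliar with that trade-off, a roughly equivalent (but costlier) formulation, a sketch only and not part of the patch, would slot into the same spot in `matchAndRewrite`:

    // Hypothetical DominanceInfo-based variant (requires
    // #include "mlir/IR/Dominance.h"). Building dominance info on every
    // pattern application has non-trivial cost, which is why the patch
    // instead hand-rolls the two common-case checks.
    DominanceInfo domInfo(dim->getParentOp());
    if (!domInfo.properlyDominates(dim.getIndex(), reshape))
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");

The two conditions in the patch approximate exactly this query for the cases that matter in practice: the index is defined earlier in the reshape's own block (or is a block argument of that block), or it is defined in a region that properly contains the reshape.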
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a854da466c3130..dc8843aa4e1e13 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,37 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a tensor reshape operation to an extract into the reshape's
+/// shape operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Since tensors are immutable we don't need to worry about where to place
+    // the extract call
+    rewriter.setInsertionPointAfter(dim);
+    Location loc = dim.getLoc();
+    Value extract =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (extract.getType() != dim.getType())
+      extract =
+          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
+    rewriter.replaceOp(dim, extract);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index b1e92e54d561da..506ed1f1c10b10 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -313,6 +313,59 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
 
 // -----
 
+// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
+// CHECK-SAME:    %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME:    %[[SHP:[0-9a-z]+]]: memref<?xindex>,
+// CHECK-SAME:    %[[IDX:[0-9a-z]+]]: index
+// CHECK-NEXT:    %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NOT:     memref.dim
+// CHECK:         return %[[DIM]] : index
+func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  %dim = memref.dim %reshape, %arg2 : memref<*xf32>
+  return %dim : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_for(
+// CHECK:         memref.reshape
+// CHECK:         memref.dim
+// CHECK-NOT:     memref.load
+func.func @dim_of_memref_reshape_for(%arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+
+  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+    %2 = memref.dim %0, %arg2 : memref<*xf32>
+    %3 = arith.muli %arg3, %2 : index
+    scf.yield %3 : index
+  }
+  return %1 : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
+// CHECK:         memref.reshape
+// CHECK:         memref.dim
+// CHECK-NOT:     memref.load
+func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+  %c4 = arith.constant 4 : index
+  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  %0 = arith.muli %arg2, %c4 : index
+  %dim = memref.dim %reshape, %0 : memref<*xf32>
+  return %dim : index
+}
+
+// -----
+
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>
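Reviewer note before the tensor tests: unlike the memref pattern above, `DimOfReshapeOp` needs no dominance guard, because tensor values are immutable and the `extract` can be materialized anywhere the shape operand is visible. A minimal before/after sketch of the rewrite, in illustrative IR with made-up value names (%t, %shape, %i):

    // Before: the dim is answered through the reshape result.
    %r = tensor.reshape %t(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
    %d = tensor.dim %r, %i : tensor<*xf32>

    // After DimOfReshapeOp fires: the requested extent is read directly from
    // the shape operand; the reshape becomes dead if it has no other uses.
    %d = tensor.extract %shape[%i] : tensor<?xindex>

This asymmetry is why the `_for` and `_undominated` tensor tests below expect the fold to happen, while their memref counterparts above expect it to be blocked.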
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 70f5d61bd802fd..e5374f031be553 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2287,3 +2287,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> tensor<10x20x4x4xf32> {
 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
 // CHECK:         return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+// CHECK-SAME:    %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+// CHECK-SAME:    %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+// CHECK-NEXT:    %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT:    %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+// CHECK-NOT:     tensor.store
+// CHECK-NOT:     tensor.dim
+// CHECK-NOT:     tensor.reshape
+// CHECK:         return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // Update the shape to test that the extract ends up in the right place.
+  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+// CHECK:         tensor.extract
+// CHECK-NEXT:    %[[CAST:.*]] = arith.index_cast
+// CHECK-NOT:     tensor.dim
+// CHECK-NOT:     tensor.reshape
+// CHECK:         return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+// CHECK:         scf.for
+// CHECK-NEXT:    tensor.extract
+// CHECK-NOT:     tensor.dim
+// CHECK-NOT:     tensor.reshape
+func.func @dim_of_reshape_for(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+    %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+    %3 = arith.muli %arg3, %2 : index
+    scf.yield %3 : index
+  }
+  return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+// CHECK:         arith.muli
+// CHECK-NEXT:    tensor.extract
+// CHECK-NOT:     tensor.dim
+// CHECK-NOT:     tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+  %c4 = arith.constant 4 : index
+  %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %0 = arith.muli %arg2, %c4 : index
+  %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+  return %dim : index
+}
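Usage note: both test files are driven by the RUN lines at their tops, which sit outside these hunks and are not shown here. Assuming the usual canonicalization-test setup, something along the lines of

    mlir-opt -split-input-file -canonicalize mlir/test/Dialect/Tensor/canonicalize.mlir | FileCheck mlir/test/Dialect/Tensor/canonicalize.mlir

run from a build with `mlir-opt` and `FileCheck` on PATH is enough to watch the new patterns fire; the authoritative flags are whatever the files' existing RUN lines specify.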