diff --git a/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp b/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
index 7f8626034d18..fb1fe5b03d83 100644
--- a/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
+++ b/iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
@@ -231,6 +231,10 @@ struct FlattenBindingSubspan final
   }
 };
 
+//===----------------------------------------------------------------------===//
+// Linearizing Patterns
+//===----------------------------------------------------------------------===//
+
 /// Generates IR to perform index linearization with the given `indices`
 /// indexing into the given memref `sourceValue`.
 static Value linearizeIndices(Value sourceValue, ValueRange indices,
@@ -312,7 +316,7 @@ struct LinearizeLoadIndices final : public OpConversionPattern<memref::LoadOp> {
       ConversionPatternRewriter &rewriter) const override {
     if (!isRankOneMemRef(adaptor.memref().getType())) {
       return rewriter.notifyMatchFailure(
-          loadOp, "expected converted memref of rank <= 1");
+          loadOp, "expected converted memref of rank == 1");
     }
 
     Value linearIndex = linearizeIndices(loadOp.memref(), loadOp.getIndices(),
@@ -337,7 +341,7 @@ struct LinearizeStoreIndices final
      ConversionPatternRewriter &rewriter) const override {
     if (!isRankOneMemRef(adaptor.memref().getType())) {
       return rewriter.notifyMatchFailure(
-          storeOp, "expected converted memref of rank <= 1");
+          storeOp, "expected converted memref of rank == 1");
     }
 
     Value linearIndex = linearizeIndices(storeOp.memref(), storeOp.getIndices(),
@@ -366,7 +370,7 @@ struct LinearizeTransferReadIndices final
     }
     if (!isRankOneMemRef(adaptor.source().getType())) {
       return rewriter.notifyMatchFailure(
-          transferReadOp, "expected converted memref of rank <= 1");
+          transferReadOp, "expected converted memref of rank == 1");
     }
     Value linearIndex =
         linearizeIndices(transferReadOp.source(), transferReadOp.indices(),
@@ -397,7 +401,7 @@ struct LinearizeTransferWriteIndices final
     }
     if (!isRankOneMemRef(adaptor.source().getType())) {
       return rewriter.notifyMatchFailure(
-          transferWriteOp, "expected converted memref of rank <= 1");
+          transferWriteOp, "expected converted memref of rank == 1");
     }
     Value linearIndex =
         linearizeIndices(transferWriteOp.source(), transferWriteOp.indices(),
@@ -429,7 +433,7 @@ struct AdjustConversionCast final
 
     if (!isRankOneMemRef(input.getType())) {
       return rewriter.notifyMatchFailure(
-          castOp, "expected converted memref of rank <= 1");
+          castOp, "expected converted memref of rank == 1");
     }
     rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
         castOp, castOp.getResultTypes(), input);
@@ -441,6 +445,24 @@ struct AdjustConversionCast final
 // Folding Patterns
 //===----------------------------------------------------------------------===//
 
+/// Removes MemRef reshape ops given that we'll linearize both the source and
+/// target type to the same one.
+template <typename ReshapeOpTy>
+struct FoldMemRefReshape final : public OpConversionPattern<ReshapeOpTy> {
+  using OpConversionPattern<ReshapeOpTy>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      ReshapeOpTy op, typename ReshapeOpTy::Adaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    if (!isRankOneMemRef(adaptor.src().getType())) {
+      return rewriter.notifyMatchFailure(
+          op, "expected converted memref of rank == 1");
+    }
+    rewriter.replaceOp(op, adaptor.src());
+    return success();
+  }
+};
+
 /// Returns the number of bytes of the given `type`. Returns llvm::None if
 /// cannot deduce.
 ///
@@ -551,15 +573,17 @@ struct FlattenMemRefSubspanPass
         FlattenGlobal, FlattenGetGlobal, FlattenBindingSubspan,
         LinearizeLoadIndices, LinearizeStoreIndices,
         LinearizeTransferReadIndices, LinearizeTransferWriteIndices,
-        AdjustConversionCast>(typeConverter, &context);
+        AdjustConversionCast, FoldMemRefReshape<memref::CollapseShapeOp>,
+        FoldMemRefReshape<memref::ExpandShapeOp>>(typeConverter, &context);
 
     ConversionTarget target(context);
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
-    target.addDynamicallyLegalOp<IREE::HAL::InterfaceBindingSubspanOp,
-                                 memref::AllocaOp, memref::AllocOp,
-                                 memref::GetGlobalOp>([](Operation *op) {
-      return isRankOneMemRef(op->getResultTypes().front());
-    });
+    target.addDynamicallyLegalOp<
+        IREE::HAL::InterfaceBindingSubspanOp, memref::AllocaOp, memref::AllocOp,
+        memref::CollapseShapeOp, memref::ExpandShapeOp, memref::GetGlobalOp>(
+        [](Operation *op) {
+          return isRankOneMemRef(op->getResultTypes().front());
+        });
     target.addDynamicallyLegalOp<memref::GlobalOp>(
         [](memref::GlobalOp op) { return isRankOneMemRef(op.type()); });
     target.addDynamicallyLegalOp<memref::LoadOp>([](memref::LoadOp loadOp) {
diff --git a/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir b/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
index a76fca26dd53..fdf83331571d 100644
--- a/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
+++ b/iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
@@ -331,3 +331,47 @@ hal.interface private @io {
 // CHECK: %[[LOAD:.+]] = memref.load %[[SPAN0]][%[[INDEX0]]] : memref<?xf32>
 // CHECK: %[[INDEX1:.+]] = affine.apply #[[MAP]]()[%[[OFFSET]]]
 // CHECK: memref.store %[[LOAD]], %[[SPAN1]][%[[INDEX1]]] : memref<?xf32>
+
+// -----
+
+func @collapse_shape(%offset : index, %i0 : index, %i1 : index) -> f32 {
+  %subspan = hal.interface.binding.subspan @io::@s0b0_ro_constant[%offset] : memref<4x5x6x7xf32>
+  %collapse = memref.collapse_shape %subspan [[0, 1], [2, 3]] : memref<4x5x6x7xf32> into memref<20x42xf32>
+  %value = memref.load %collapse[%i0, %i1] : memref<20x42xf32>
+  return %value : f32
+}
+
+hal.interface @io attributes {sym_visibility = "private"} {
+  hal.interface.binding @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read"
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 42 + s1 + s2 floordiv 4)>
+// CHECK: func @collapse_shape
+// CHECK-SAME: (%[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[I1:.+]]: index)
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[SIZE:.+]] = arith.constant 840 : index
+// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_constant[%[[C0]]] : memref<?xf32>{%[[SIZE]]}
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[I0]], %[[I1]], %[[OFFSET]]]
+// CHECK: memref.load %[[SUBSPAN]][%[[INDEX]]]
+
+// -----
+
+func @expand_shape(%offset : index, %i0: index, %i1: index, %i2: index, %i3: index) -> f32 {
+  %subspan = hal.interface.binding.subspan @io::@s0b0_ro_constant[%offset] : memref<20x42xf32>
+  %expand = memref.expand_shape %subspan [[0, 1], [2, 3]] : memref<20x42xf32> into memref<4x5x6x7xf32>
+  %value = memref.load %expand[%i0, %i1, %i2, %i3] : memref<4x5x6x7xf32>
+  return %value : f32
+}
+
+hal.interface @io attributes {sym_visibility = "private"} {
+  hal.interface.binding @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read"
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * 210 + s1 * 42 + s2 * 7 + s3 + s4 floordiv 4)>
+// CHECK: func @expand_shape
+// CHECK-SAME: (%[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index)
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[SIZE:.+]] = arith.constant 840 : index
+// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_constant[%[[C0]]] : memref<?xf32>{%[[SIZE]]}
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[I0]], %[[I1]], %[[I2]], %[[I3]], %[[OFFSET]]]
+// CHECK: memref.load %[[SUBSPAN]][%[[INDEX]]]
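
Note: a minimal before/after sketch of the new FoldMemRefReshape folding,
reconstructed from the collapse_shape test above (the SSA names and the #map
identifier are illustrative, not the pass's verbatim output):

  // Before flattening: a multi-dimensional subspan feeding a reshape.
  %subspan = hal.interface.binding.subspan @io::@s0b0_ro_constant[%offset] : memref<4x5x6x7xf32>
  %collapse = memref.collapse_shape %subspan [[0, 1], [2, 3]] : memref<4x5x6x7xf32> into memref<20x42xf32>
  %value = memref.load %collapse[%i0, %i1] : memref<20x42xf32>

  // After flattening: the subspan becomes a rank-1 memref, the reshape folds
  // away (source and target linearize to the same type), and the load uses a
  // linearized index; the byte offset is scaled down by sizeof(f32) = 4.
  #map = affine_map<()[s0, s1, s2] -> (s0 * 42 + s1 + s2 floordiv 4)>
  %c0 = arith.constant 0 : index
  %c840 = arith.constant 840 : index
  %flat = hal.interface.binding.subspan @io::@s0b0_ro_constant[%c0] : memref<?xf32>{%c840}
  %index = affine.apply #map()[%i0, %i1, %offset]
  %value = memref.load %flat[%index] : memref<?xf32>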