Handle reshape ops in FlattenMemRefSubspanPass (#7649)
These are essentially no-ops, given that we flatten both the
source and target memref types to the same rank-1 type.
antiagainst committed Nov 12, 2021
1 parent f363d32 commit bd2001a
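A minimal sketch of why (hypothetical IR, not part of this change): once both sides of a reshape are converted to the same rank-1 type, the op conveys no information and its result can simply reuse its source.

// Before the pass: a reshape between two multi-dimensional views.
%view = memref.expand_shape %buf [[0, 1], [2, 3]]
    : memref<20x42xf32> into memref<4x5x6x7xf32>
// After flattening, %buf and %view both become the same linearized
// memref<?xf32>, so uses of %view can be replaced with %buf.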
Showing 2 changed files with 79 additions and 11 deletions.
46 changes: 35 additions & 11 deletions iree/compiler/Codegen/Common/FlattenMemRefSubspanPass.cpp
@@ -231,6 +231,10 @@ struct FlattenBindingSubspan final
}
};

//===----------------------------------------------------------------------===//
// Linearizing Patterns
//===----------------------------------------------------------------------===//

/// Generates IR to perform index linearization with the given `indices`
/// indexing into the given memref `sourceValue`.
static Value linearizeIndices(Value sourceValue, ValueRange indices,
@@ -312,7 +316,7 @@ struct LinearizeLoadIndices final : public OpConversionPattern<memref::LoadOp> {
ConversionPatternRewriter &rewriter) const override {
if (!isRankOneMemRef(adaptor.memref().getType())) {
return rewriter.notifyMatchFailure(
loadOp, "expected converted memref of rank <= 1");
loadOp, "expected converted memref of rank == 1");
}

Value linearIndex = linearizeIndices(loadOp.memref(), loadOp.getIndices(),
@@ -337,7 +341,7 @@ struct LinearizeStoreIndices final
ConversionPatternRewriter &rewriter) const override {
if (!isRankOneMemRef(adaptor.memref().getType())) {
return rewriter.notifyMatchFailure(
storeOp, "expected converted memref of rank <= 1");
storeOp, "expected converted memref of rank == 1");
}

Value linearIndex = linearizeIndices(storeOp.memref(), storeOp.getIndices(),
@@ -366,7 +370,7 @@ struct LinearizeTransferReadIndices final
}
if (!isRankOneMemRef(adaptor.source().getType())) {
return rewriter.notifyMatchFailure(
transferReadOp, "expected converted memref of rank <= 1");
transferReadOp, "expected converted memref of rank == 1");
}
Value linearIndex =
linearizeIndices(transferReadOp.source(), transferReadOp.indices(),
@@ -397,7 +401,7 @@ struct LinearizeTransferWriteIndices final
}
if (!isRankOneMemRef(adaptor.source().getType())) {
return rewriter.notifyMatchFailure(
transferWriteOp, "expected converted memref of rank <= 1");
transferWriteOp, "expected converted memref of rank == 1");
}
Value linearIndex =
linearizeIndices(transferWriteOp.source(), transferWriteOp.indices(),
@@ -429,7 +433,7 @@ struct AdjustConversionCast final

if (!isRankOneMemRef(input.getType())) {
return rewriter.notifyMatchFailure(
castOp, "expected converted memref of rank <= 1");
castOp, "expected converted memref of rank == 1");
}
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
castOp, castOp.getResultTypes(), input);
@@ -441,6 +445,24 @@ struct AdjustConversionCast final
// Folding Patterns
//===----------------------------------------------------------------------===//

/// Removes MemRef reshape ops, given that we linearize both the source and
/// target types into the same rank-1 type.
template <typename ReshapeOpTy>
struct FoldMemRefReshape final : public OpConversionPattern<ReshapeOpTy> {
using OpConversionPattern<ReshapeOpTy>::OpConversionPattern;

LogicalResult matchAndRewrite(
ReshapeOpTy op, typename ReshapeOpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!isRankOneMemRef(adaptor.src().getType())) {
return rewriter.notifyMatchFailure(
op, "expected converted memref of rank == 1");
}
rewriter.replaceOp(op, adaptor.src());
return success();
}
};

/// Returns the size in bytes of the given `type`. Returns llvm::None if it
/// cannot be deduced.
///
@@ -551,15 +573,17 @@ struct FlattenMemRefSubspanPass
FlattenGlobal, FlattenGetGlobal, FlattenBindingSubspan,
LinearizeLoadIndices, LinearizeStoreIndices,
LinearizeTransferReadIndices, LinearizeTransferWriteIndices,
- AdjustConversionCast>(typeConverter, &context);
+ AdjustConversionCast, FoldMemRefReshape<memref::CollapseShapeOp>,
+ FoldMemRefReshape<memref::ExpandShapeOp>>(typeConverter, &context);

ConversionTarget target(context);
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
- target.addDynamicallyLegalOp<IREE::HAL::InterfaceBindingSubspanOp,
- memref::AllocaOp, memref::AllocOp,
- memref::GetGlobalOp>([](Operation *op) {
- return isRankOneMemRef(op->getResultTypes().front());
- });
+ target.addDynamicallyLegalOp<
+ IREE::HAL::InterfaceBindingSubspanOp, memref::AllocaOp, memref::AllocOp,
+ memref::CollapseShapeOp, memref::ExpandShapeOp, memref::GetGlobalOp>(
+ [](Operation *op) {
+ return isRankOneMemRef(op->getResultTypes().front());
+ });
target.addDynamicallyLegalOp<memref::GlobalOp>(
[](memref::GlobalOp op) { return isRankOneMemRef(op.type()); });
target.addDynamicallyLegalOp<memref::LoadOp>([](memref::LoadOp loadOp) {
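Note that memref.CollapseShapeOp and memref.ExpandShapeOp are registered as dynamically legal only when their result is a rank-1 memref, so the conversion driver keeps applying FoldMemRefReshape until every reshape between flattened memrefs has been folded away.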
44 changes: 44 additions & 0 deletions iree/compiler/Codegen/Common/test/flatten_memref_subspan.mlir
@@ -331,3 +331,47 @@ hal.interface private @io {
// CHECK: %[[LOAD:.+]] = memref.load %[[SPAN0]][%[[INDEX0]]] : memref<?xf32>
// CHECK: %[[INDEX1:.+]] = affine.apply #[[MAP]]()[%[[OFFSET]]]
// CHECK: memref.store %[[LOAD]], %[[SPAN1]][%[[INDEX1]]] : memref<?xf32>

// -----

func @collapse_shape(%offset : index, %i0 : index, %i1 : index) -> f32 {
%subspan = hal.interface.binding.subspan @io::@s0b0_ro_constant[%offset] : memref<4x5x6x7xf32>
%collapse = memref.collapse_shape %subspan[[0, 1], [2, 3]] : memref<4x5x6x7xf32> into memref<20x42xf32>
%value = memref.load %collapse[%i0, %i1] : memref<20x42xf32>
return %value : f32
}

hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read"
}

// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 42 + s1 + s2 floordiv 4)>
// CHECK: func @collapse_shape
// CHECK-SAME: (%[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[I1:.+]]: index)
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[SIZE:.+]] = arith.constant 840 : index
// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_constant[%[[C0]]] : memref<?xf32>{%[[SIZE]]}
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[I0]], %[[I1]], %[[OFFSET]]]
// CHECK: memref.load %[[SUBSPAN]][%[[INDEX]]]
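Here i0 * 42 + i1 linearizes the index into the 20x42 view (42 is the stride of the leading dimension), and offset floordiv 4 converts the subspan's byte offset into an f32 element index.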

// -----

func @expand_shape(%offset : index, %i0: index, %i1: index, %i2: index, %i3: index) -> f32 {
%subspan = hal.interface.binding.subspan @io::@s0b0_ro_constant[%offset] : memref<20x42xf32>
%expand = memref.expand_shape %subspan[[0, 1], [2, 3]] : memref<20x42xf32> into memref<4x5x6x7xf32>
%value = memref.load %expand[%i0, %i1, %i2, %i3] : memref<4x5x6x7xf32>
return %value : f32
}

hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @s0b0_ro_constant, set=0, binding=0, type="StorageBuffer", access="Read"
}

// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * 210 + s1 * 42 + s2 * 7 + s3 + s4 floordiv 4)>
// CHECK: func @expand_shape
// CHECK-SAME: (%[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index)
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[SIZE:.+]] = arith.constant 840 : index
// CHECK: %[[SUBSPAN:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_constant[%[[C0]]] : memref<?xf32>{%[[SIZE]]}
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[I0]], %[[I1]], %[[I2]], %[[I3]], %[[OFFSET]]]
// CHECK: memref.load %[[SUBSPAN]][%[[INDEX]]]
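The strides 210, 42, and 7 follow from the 4x5x6x7 shape (5*6*7 = 210, 6*7 = 42), with s4 floordiv 4 again turning the byte offset into an element index. Assuming the file keeps the usual lit setup (the RUN line sits outside the shown hunk, so the exact pass flag is an assumption), these cases would be exercised with something like:

// RUN: iree-opt -split-input-file -iree-codegen-flatten-memref-subspan %s | FileCheck %s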
