[mlir] [memref] add more checks to the memref.reinterpret_cast (#112669)
The memref.reinterpret_cast operation used to accept input like:

%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10],
       strides: [1]
       : memref<?xf32> to memref<10xf32>

A problem arises: after lowering, the true offset of %out is %offset,
but its type says the offset is 0. Permitting this inconsistency can
produce incorrect results, because a pass may erroneously extract the
offset from the type of %out.

This patch fixes the issue by enforcing that the result type agrees
with the offset, size, and stride operands.
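
With this change, the dynamic offset must be visible in the result
type's layout. A sketch of the now-accepted form of the example above:

%out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10],
       strides: [1]
       : memref<?xf32> to memref<10xf32, strided<[1], offset: ?>>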
cxy-1993 authored Oct 26, 2024
1 parent 5f7bad0 commit 889b67c
Showing 9 changed files with 81 additions and 63 deletions.
13 changes: 12 additions & 1 deletion mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
@@ -29,6 +29,17 @@ namespace mlir {

using namespace mlir;

+static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
+  auto sourceType = cast<BaseMemRefType>(source.getType());
+  SmallVector<int64_t> staticOffsets;
+  SmallVector<Value> dynamicOffsets;
+  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
+  auto stridedLayout =
+      StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
+  return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
+                         sourceType.getMemorySpace());
+}
+
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
  if (auto *parentOp = val.getDefiningOp()) {
    builder.setInsertionPointAfter(parentOp);
@@ -98,7 +109,7 @@ static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
  auto &&[base, offset, ignore] =
      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
-  auto retType = cast<MemRefType>(base.getType());
+  MemRefType retType = inferCastResultType(base, offset);
  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
                                                    std::nullopt, std::nullopt);
}
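Note: under the stricter verifier, getFlatMemref can no longer reuse the
flattened source's plain type when the folded offset is dynamic, so
inferCastResultType builds a 0-d result type whose strided layout carries
the offset. Roughly, the decomposed accesses now look like this (a sketch;
value names are illustrative, cf. the decompose-memrefs tests below):

%ptr = memref.reinterpret_cast %base to offset: [%idx], sizes: [], strides: []
       : memref<f32> to memref<f32, strided<[], offset: ?>>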
27 changes: 15 additions & 12 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1892,11 +1892,12 @@ LogicalResult ReinterpretCastOp::verify() {
  // Match sizes in result memref type and in static_sizes attribute.
  for (auto [idx, resultSize, expectedSize] :
       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
-    if (!ShapedType::isDynamic(resultSize) &&
-        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
+    if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
-             << expectedSize << " instead of " << resultSize
-             << " in dim = " << idx;
+             << (ShapedType::isDynamic(expectedSize)
+                     ? std::string("dynamic")
+                     : std::to_string(expectedSize))
+             << " instead of " << resultSize << " in dim = " << idx;
  }

// Match offset and strides in static_offset and static_strides attributes. If
@@ -1910,20 +1911,22 @@ LogicalResult ReinterpretCastOp::verify() {

  // Match offset in result memref type and in static_offsets attribute.
  int64_t expectedOffset = getStaticOffsets().front();
-  if (!ShapedType::isDynamic(resultOffset) &&
-      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
+  if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
-           << expectedOffset << " instead of " << resultOffset;
+           << (ShapedType::isDynamic(expectedOffset)
+                   ? std::string("dynamic")
+                   : std::to_string(expectedOffset))
+           << " instead of " << resultOffset;

  // Match strides in result memref type and in static_strides attribute.
  for (auto [idx, resultStride, expectedStride] :
       llvm::enumerate(resultStrides, getStaticStrides())) {
-    if (!ShapedType::isDynamic(resultStride) &&
-        !ShapedType::isDynamic(expectedStride) &&
-        resultStride != expectedStride)
+    if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
      return emitError("expected result type with stride = ")
-             << expectedStride << " instead of " << resultStride
-             << " in dim = " << idx;
+             << (ShapedType::isDynamic(expectedStride)
+                     ? std::string("dynamic")
+                     : std::to_string(expectedStride))
+             << " instead of " << resultStride << " in dim = " << idx;
  }

  return success();
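The verifier now also flags the case where an operand is dynamic but the
result type pins a static value, printing "dynamic" in the diagnostic.
A sketch of the size variant (%size is an illustrative index value; the
offset variant is tested in invalid.mlir below):

// error: expected result type with size = dynamic instead of 10 in dim = 0
%out = memref.reinterpret_cast %in to offset: [0], sizes: [%size], strides: [1]
       : memref<?xf32> to memref<10xf32>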
22 changes: 18 additions & 4 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -89,7 +89,8 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
    strides.resize(rank);

    Location loc = op.getLoc();
-    Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stride = nullptr;
+    int64_t staticStride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      Value size;
      // Load dynamic sizes from the shape input, use constants for static dims.
@@ -105,9 +106,22 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
        size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
        sizes[i] = sizeAttr;
      }
-      strides[i] = stride;
-      if (i > 0)
-        stride = rewriter.create<arith::MulIOp>(loc, stride, size);
+      if (stride)
+        strides[i] = stride;
+      else
+        strides[i] = rewriter.getIndexAttr(staticStride);
+
+      if (i > 0) {
+        if (stride) {
+          stride = rewriter.create<arith::MulIOp>(loc, stride, size);
+        } else if (op.getType().isDynamicDim(i)) {
+          stride = rewriter.create<arith::MulIOp>(
+              loc, rewriter.create<arith::ConstantIndexOp>(loc, staticStride),
+              size);
+        } else {
+          staticStride *= op.getType().getDimSize(i);
+        }
+      }
    }
    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
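In effect, trailing static strides now stay as attributes, and arith ops
are materialized only once a dynamic size enters the running stride
product. For the reshape to memref<?x?x8xf32> exercised in expand-ops.mlir
below, the expansion now emits roughly (a sketch; value names are
illustrative):

%c8 = arith.constant 8 : index
%stride0 = arith.muli %c8, %size1 : index
%result = memref.reinterpret_cast %src to offset: [0],
    sizes: [%size0, %size1, 8], strides: [%stride0, 8, 1]
    : memref<*xf32> to memref<?x?x8xf32>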
17 changes: 7 additions & 10 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -507,6 +507,8 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,

  SmallVector<OpFoldResult> groupStrides;
  ArrayRef<int64_t> srcShape = sourceType.getShape();
+
+  OpFoldResult lastValidStride = nullptr;
  for (int64_t currentDim : reassocGroup) {
    // Skip size-of-1 dimensions, since right now their strides may be
    // meaningless.
@@ -517,11 +519,11 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
      continue;

    int64_t currentStride = strides[currentDim];
-    groupStrides.push_back(ShapedType::isDynamic(currentStride)
-                               ? origStrides[currentDim]
-                               : builder.getIndexAttr(currentStride));
+    lastValidStride = ShapedType::isDynamic(currentStride)
+                          ? origStrides[currentDim]
+                          : builder.getIndexAttr(currentStride);
  }
-  if (groupStrides.empty()) {
+  if (!lastValidStride) {
    // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
    // but we still have to make the type system happy.
    MemRefType collapsedType = collapseShape.getResultType();
@@ -543,12 +545,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
    return {builder.getIndexAttr(finalStride)};
  }

-  // For the general case, we just want the minimum stride
-  // since the collapsed dimensions are contiguous.
-  auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
-                                                  builder.getContext());
-  return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
-                                      groupStrides)};
+  return {lastValidStride};
}

/// From `reshape_like(memref, subSizes, subStrides))` compute
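Because the dimensions of a collapse group are contiguous, the collapsed
stride equals the stride of the innermost non-size-1 dimension of the
group, i.e. the last one visited, so the affine.min over all group strides
was unnecessary. In the simplify_collapse test below, the middle group's
stride becomes the constant 42 instead of min(s1, s2, 42).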
@@ -425,8 +425,6 @@ func.func @collapse_shape_dynamic_with_non_identity_layout(
// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[STRIDE0_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0]] : i64 to index
-// CHECK: %[[STRIDE0:.*]] = builtin.unrealized_conversion_cast %[[STRIDE0_TO_IDX]] : index to i64
// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE1]], %[[SIZE2]] : i64
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
@@ -548,23 +546,19 @@ func.func @collapse_shape_dynamic(%arg0 : memref<1x2x?xf32>) -> memref<1x?xf32>
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[SIZE2:.*]] = llvm.extractvalue %[[MEM]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
// CHECK: %[[FINAL_SIZE1:.*]] = llvm.mul %[[SIZE2]], %[[C2]] : i64
// CHECK: %[[SIZE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE1]] : i64 to index
// CHECK: %[[FINAL_SIZE1:.*]] = builtin.unrealized_conversion_cast %[[SIZE1_TO_IDX]] : index to i64
-// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK: %[[MIN_STRIDE1:.*]] = llvm.intr.smin(%[[STRIDE1]], %[[C1]]) : (i64, i64) -> i64
-// CHECK: %[[MIN_STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1]] : i64 to index
-// CHECK: %[[MIN_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[MIN_STRIDE1_TO_IDX]] : index to i64
// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[FINAL_SIZE1]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[MIN_STRIDE1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC6]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<1x?xf32>
// CHECK: return %[[RES]] : memref<1x?xf32>
// CHECK: }
12 changes: 6 additions & 6 deletions mlir/test/Dialect/GPU/decompose-memrefs.mlir
@@ -7,8 +7,8 @@
// CHECK: gpu.launch
// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -33,8 +33,8 @@ func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
// CHECK: gpu.launch
// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]], %[[STRIDES]]#2]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32>
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -59,8 +59,8 @@ func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, stride
// CHECK: gpu.launch
// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
-// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32>
-// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32>
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32, strided<[], offset: ?>>
// CHECK: "test.test"(%[[RES]]) : (f32) -> ()
func.func @decompose_load(%arg0 : memref<?x?x?xf32>) {
%c0 = arith.constant 0 : index
13 changes: 6 additions & 7 deletions mlir/test/Dialect/MemRef/expand-ops.mlir
@@ -52,20 +52,19 @@ func.func @memref_reshape(%input: memref<*xf32>,
// CHECK-SAME: [[SRC:%.*]]: memref<*xf32>,
// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref<?x?x8xf32> {

-// CHECK: [[C1:%.*]] = arith.constant 1 : index
-// CHECK: [[C8:%.*]] = arith.constant 8 : index
-// CHECK: [[STRIDE_1:%.*]] = arith.muli [[C1]], [[C8]] : index

-// CHECK: [[C1_:%.*]] = arith.constant 1 : index
-// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1_]]] : memref<3xi32>
+// CHECK: [[C1:%.*]] = arith.constant 1 : index
+// CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1]]] : memref<3xi32>
// CHECK: [[SIZE_1:%.*]] = arith.index_cast [[DIM_1]] : i32 to index
-// CHECK: [[STRIDE_0:%.*]] = arith.muli [[STRIDE_1]], [[SIZE_1]] : index

+// CHECK: [[C8_:%.*]] = arith.constant 8 : index
+// CHECK: [[STRIDE_0:%.*]] = arith.muli [[C8_]], [[SIZE_1]] : index

// CHECK: [[C0:%.*]] = arith.constant 0 : index
// CHECK: [[DIM_0:%.*]] = memref.load [[SHAPE]]{{\[}}[[C0]]] : memref<3xi32>
// CHECK: [[SIZE_0:%.*]] = arith.index_cast [[DIM_0]] : i32 to index

// CHECK: [[RESULT:%.*]] = memref.reinterpret_cast [[SRC]]
// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
-// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
+// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], 8, 1]
// CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>
21 changes: 6 additions & 15 deletions mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -931,19 +931,15 @@ func.func @extract_aligned_pointer_as_index_of_unranked_source(%arg0: memref<*xf
// = min(7, 1)
// = 1
//
-// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0] -> (s0)>
-// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
-// CHECK-DAG: #[[$STRIDE1_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1, 42)>
+// CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
// CHECK-LABEL: func @simplify_collapse(
// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
//
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
//
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0]
-// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
-// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.min #[[$STRIDE1_MIN_MAP]]()[%[[STRIDES]]#1, %[[STRIDES]]#2]
+// CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
//
-// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], 1]
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1]
func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
-> memref<?x?x42xi32> {

@@ -1046,15 +1042,12 @@ func.func @simplify_collapse_with_dim_of_size1_and_non_1_stride
// We just return the first dynamic one for this group.
//
//
-// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1)>
// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride(
// CHECK-SAME: %[[ARG:.*]]: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2]
//
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:5, %[[STRIDES:.*]]:5 = memref.extract_strided_metadata %[[ARG]] : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
//
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0, %[[STRIDES]]#1]
-//
-// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[DYN_STRIDE0]], %[[STRIDES]]#2]
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[STRIDES]]#1, %[[STRIDES]]#2]
func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
(%arg0: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>)
-> memref<6x1xi32, strided<[?, ?], offset: ?>> {
@@ -1083,8 +1076,7 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
// Stride 2 = origStride5
// = 1
//
-// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
-// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
// CHECK-LABEL: func @extract_strided_metadata_of_collapse(
// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
//
@@ -1094,10 +1086,9 @@ func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
//
// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
//
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MAP]]()[%[[STRIDES]]#0]
// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
//
-// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[DYN_STRIDE0]], %[[C42]], %[[C1]]
+// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]]
func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
-> (memref<i32>, index,
index, index, index,
9 changes: 9 additions & 0 deletions mlir/test/Dialect/MemRef/invalid.mlir
@@ -217,6 +217,15 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref<?xf32>) {

// -----

+func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref<?xf32>, %offset : index) {
+  // expected-error @+1 {{expected result type with offset = dynamic instead of 0}}
+  %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1]
+        : memref<?xf32> to memref<10xf32>
+  return
+}
+
+// -----
+
func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref<?xf32>) {
// expected-error @+1 {{expected result type with stride = 10 instead of 1 in dim = 0}}
%out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [10]