Skip to content

Commit

Permalink
[mlir] Fix a zero stride canonicalizer crash (llvm#74200)
Browse files Browse the repository at this point in the history
This PR fixes llvm#73383 and is
another shot at the refactoring proposed in
llvm#72885.

---------

Co-authored-by: Kai Sasaki <[email protected]>
  • Loading branch information
rikhuijzer and Lewuathe authored Dec 6, 2023
1 parent df7545e commit 68f0bc6
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 25 deletions.
30 changes: 27 additions & 3 deletions mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,36 @@ SmallVector<int64_t>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
llvm::function_ref<bool(Attribute, Attribute)> compare);

/// Helper function to check whether the passed in `sizes` or `offsets` are
/// valid. This can be used to re-check whether dimensions are still valid
/// after constant folding the dynamic dimensions.
bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets);

/// Helper function to check whether the passed in `strides` are valid. This
/// can be used to re-check whether dimensions are still valid after constant
/// folding the dynamic dimensions.
bool hasValidStrides(SmallVector<int64_t> strides);

/// Returns "success" when any of the elements in `ofrs` is a constant value. In
/// that case the value is replaced by an attribute. Returns "failure" when no
/// folding happened. If `onlyNonNegative` is set, only non-negative constant
/// values are folded.
/// folding happened. If `onlyNonNegative` and `onlyNonZero` are set, only
/// non-negative and non-zero constant values are folded respectively.
LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
bool onlyNonNegative = false);
bool onlyNonNegative = false,
bool onlyNonZero = false);

/// Returns "success" when any of the elements in `offsetsOrSizes` is a
/// constant value. In that case the value is replaced by an attribute. Returns
/// "failure" when no folding happened. Invalid values are not folded to avoid
/// canonicalization crashes.
LogicalResult
foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes);

/// Returns "success" when any of the elements in `strides` is a constant
/// value. In that case the value is replaced by an attribute. Returns
/// "failure" when no folding happened. Invalid values are not folded to avoid
/// canonicalization crashes.
LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides);

/// Return the number of iterations for a loop with a lower bound `lb`, upper
/// bound `ub` and step `step`.
Expand Down
17 changes: 6 additions & 11 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2582,17 +2582,12 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

// If one of the offsets or sizes is invalid, fail the canonicalization.
// These checks also occur in the verifier, but they are needed here
// because some dynamic dimensions may have been constant folded.
for (int64_t offset : staticOffsets)
if (offset < 0 && !ShapedType::isDynamic(offset))
return {};
for (int64_t size : staticSizes)
if (size < 0 && !ShapedType::isDynamic(size))
return {};

if (!hasValidSizesOffsets(staticOffsets))
return {};
if (!hasValidSizesOffsets(staticSizes))
return {};
if (!hasValidStrides(staticStrides))
return {};
return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
staticSizes, staticStrides);
}
Expand Down
17 changes: 7 additions & 10 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1447,13 +1447,8 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
SmallVector<int64_t> newShape;
operandsAndShape(resultType, dynamicExtents, newOperands, newShape);

for (int64_t newdim : newShape) {
// This check also occurs in the verifier, but we need it here too
// since intermediate passes may have replaced some dynamic dimensions
// by constants.
if (newdim < 0 && !ShapedType::isDynamic(newdim))
return failure();
}
if (!hasValidSizesOffsets(newShape))
return failure();

if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
return failure();
Expand Down Expand Up @@ -2549,9 +2544,9 @@ class InsertSliceOpConstantArgumentFolder final
SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

// No constant operands were folded, just return;
if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
failed(foldDynamicIndexList(mixedStrides)))
if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
failed(foldDynamicOffsetSizeList(mixedSizes)) &&
failed(foldDynamicStrideList(mixedStrides)))
return failure();

// Create the new op in canonical form.
Expand Down Expand Up @@ -2692,6 +2687,8 @@ struct InsertSliceOpSourceCastInserter final
newSrcShape[i] = *constInt;
}
}
if (!hasValidSizesOffsets(newSrcShape))
return failure();

RankedTensorType newSrcType =
RankedTensorType::get(newSrcShape, srcType.getElementType());
Expand Down
27 changes: 26 additions & 1 deletion mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,20 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
}

bool hasValidSizesOffsets(SmallVector<int64_t> sizesOrOffsets) {
return llvm::none_of(sizesOrOffsets, [](int64_t value) {
return !ShapedType::isDynamic(value) && value < 0;
});
}

bool hasValidStrides(SmallVector<int64_t> strides) {
return llvm::none_of(strides, [](int64_t value) {
return !ShapedType::isDynamic(value) && value == 0;
});
}

LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
bool onlyNonNegative) {
bool onlyNonNegative, bool onlyNonZero) {
bool valuesChanged = false;
for (OpFoldResult &ofr : ofrs) {
if (ofr.is<Attribute>())
Expand All @@ -267,11 +279,24 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
// Note: All ofrs have index type.
if (onlyNonNegative && *getConstantIntValue(attr) < 0)
continue;
if (onlyNonZero && *getConstantIntValue(attr) == 0)
continue;
ofr = attr;
valuesChanged = true;
}
}
return success(valuesChanged);
}

LogicalResult
foldDynamicOffsetSizeList(SmallVectorImpl<OpFoldResult> &offsetsOrSizes) {
return foldDynamicIndexList(offsetsOrSizes, /*onlyNonNegative=*/true,
/*onlyNonZero=*/false);
}

LogicalResult foldDynamicStrideList(SmallVectorImpl<OpFoldResult> &strides) {
return foldDynamicIndexList(strides, /*onlyNonNegative=*/false,
/*onlyNonZero=*/true);
}

} // namespace mlir
12 changes: 12 additions & 0 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,18 @@ func.func @no_fold_subview_negative_size(%input: memref<4x1024xf32>) -> memref<?

// -----

// CHECK-LABEL: func @no_fold_subview_zero_stride
// CHECK: %[[SUBVIEW:.+]] = memref.subview
// CHECK: return %[[SUBVIEW]]
func.func @no_fold_subview_zero_stride(%arg0 : memref<10xf32>) -> memref<1xf32, strided<[?], offset: 1>> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = memref.subview %arg0[1] [1] [%c0] : memref<10xf32> to memref<1xf32, strided<[?], offset: 1>>
return %1 : memref<1xf32, strided<[?], offset: 1>>
}

// -----

// CHECK-LABEL: func @no_fold_of_store
// CHECK: %[[cst:.+]] = memref.cast %arg
// CHECK: memref.store %[[cst]]
Expand Down

0 comments on commit 68f0bc6

Please sign in to comment.