[mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) (llvm#73523)

Updates patterns for flattening `vector.transfer_read` by relaxing the
requirement that the "collapsed" indices are all zero. This enables
collapsing cases like this one:

```mlir
  %2 = vector.transfer_read %arg4[%c0, %arg0, %arg1, %c0] ... :
    memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

Previously, only the following case, in which all indices are 0, was
considered for collapsing:

```mlir
  %2 = vector.transfer_read %arg4[%c0, %c0, %c0, %c0] ... :
    memref<1x43x4x6xi32>, vector<1x2x6xi32>
```
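
For reference, here is a sketch of the flattened IR for the example above, based on the `transfer_read_dims_mismatch_non_zero_indices` test added below (the `%pad` padding value and the SSA names are illustrative):

```mlir
  %collapsed = memref.collapse_shape %arg4 [[0], [1, 2, 3]]
    : memref<1x43x4x6xi32> into memref<1x1032xi32>
  %offset = affine.apply affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>()[%arg1, %arg0]
  %read = vector.transfer_read %collapsed[%c0, %offset], %pad {in_bounds = [true]}
    : memref<1x1032xi32>, vector<12xi32>
  %2 = vector.shape_cast %read : vector<12xi32> to vector<1x2x6xi32>
```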

Also adds some new comments and renames the `firstContiguousInnerDim`
variable to `firstDimToCollapse`, which better matches its actual
meaning.

Similar updates for `vector.transfer_write` will be implemented in a
follow-up patch.
banach-space authored Dec 5, 2023
1 parent e8dbe94 commit 2eb9e33
Showing 4 changed files with 129 additions and 11 deletions.
79 changes: 68 additions & 11 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -511,6 +511,8 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
 /// Checks that the indices corresponding to dimensions starting at
 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+/// TODO: Extract the logic that writes to outIndices so that this method
+/// simply checks one pre-condition.
 static LogicalResult
 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                  SmallVector<Value> &outIndices) {
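
As an illustration of this helper's contract (hypothetical values, not part of the patch): with `firstDimToCollapse = 2`, the indices `(%i, %j, %c0, %c0)` pass the check and `outIndices` becomes `(%i, %j, %c0)`, whereas `(%i, %j, %k, %c0)` with a non-constant `%k` makes the helper return failure.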

@@ -542,45 +544,100 @@ class FlattenContiguousRowMajorTransferReadPattern
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    Value source = transferReadOp.getSource();
+    auto source = transferReadOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
-      // Already 0D/1D, nothing to do.
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
     if (!transferReadOp.getPermutationMap().isMinorIdentity())
       return failure();
     if (transferReadOp.getMask())
       return failure();
+
     SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+    // 1. Collapse the source memref
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         dyn_cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+    // 2.2 New indices
+    // If all the collapsed indices are zero then no extra logic is needed.
+    // Otherwise, a new offset/index has to be computed.
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstDimToCollapse,
+                                                collapsedIndices))) {
+      // Copy all the leading indices
+      collapsedIndices = transferReadOp.getIndices();
+      collapsedIndices.resize(firstDimToCollapse);
+
+      // Compute the remaining trailing index/offset required for reading from
+      // the collapsed memref:
+      //
+      //    offset = 0
+      //    for (i = firstDimToCollapse; i < outputRank; ++i)
+      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+      //
+      // For this example:
+      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+      //     memref<1x43x2xi32>, vector<1x2xi32>
+      // which would be collapsed to:
+      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+      //     memref<1x86xi32>, vector<2xi32>
+      // one would get the following offset:
+      //   %offset = %arg0 * 43
+      AffineExpr offsetExpr, idxExpr;
+      bindSymbols(rewriter.getContext(), offsetExpr, idxExpr);
+
+      int64_t outputRank = transferReadOp.getIndices().size();
+      OpFoldResult offset =
+          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+
+      for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+        int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
+        offset = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, offsetExpr + dim * idxExpr,
+            {offset, transferReadOp.getIndices()[i]});
+      }
+      if (offset.is<Value>()) {
+        collapsedIndices.push_back(offset.get<Value>());
+      } else {
+        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
+            loc, *getConstantIntValue(offset)));
+      }
+    }
+
+    // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_read with the new one reading from the
+    // collapsed shape
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
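
To make the new index computation concrete, here is a hand-written sketch (not part of the patch) of the IR the updated pattern would produce for the example from the comment above, with `%pad` as an illustrative padding value:

```mlir
// Before:
%2 = vector.transfer_read %arg4[%c0, %arg0, %c0], %pad {in_bounds = [true, true]}
  : memref<1x43x2xi32>, vector<1x2xi32>

// After flattening:
%collapsed = memref.collapse_shape %arg4 [[0], [1, 2]]
  : memref<1x43x2xi32> into memref<1x86xi32>
%offset = affine.apply affine_map<()[s0] -> (s0 * 43)>()[%arg0]
%read = vector.transfer_read %collapsed[%c0, %offset], %pad {in_bounds = [true]}
  : memref<1x86xi32>, vector<2xi32>
%2 = vector.shape_cast %read : vector<2xi32> to vector<1x2xi32>
```

Note that `makeComposedFoldedAffineApply` folds the chain of per-dimension applies (and the zero terms) into the single `affine.apply` seen here and in the tests.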
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -265,6 +265,11 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
     return false;
   auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
 
+  // TODO: Add support for memref with trailing dynamic shapes. Memrefs
+  // with leading dynamic dimensions are already supported.
+  if (ShapedType::isDynamicShape(memrefShape))
+    return false;
+
   // Cond 1: A contiguous memref will always have a unit trailing stride.
   if (strides.back() != 1)
     return false;
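
For illustration (the types here are chosen for this note, not taken from the patch): since only the `vecRank` trailing dims of the memref are inspected, a leading dynamic dimension remains eligible, while a dynamic dimension among the trailing ones now bails out:

```mlir
// Trailing dims (43, 4, 6) are static: still eligible for flattening.
%0 = vector.transfer_read %src_a[%c0, %i, %j, %c0], %pad {in_bounds = [true, true, true]}
  : memref<?x43x4x6xi32>, vector<1x2x6xi32>

// Dynamic dim among the trailing ones: isContiguousSlice now returns false.
%1 = vector.transfer_read %src_b[%c0, %i, %j, %c0], %pad {in_bounds = [true, true, true]}
  : memref<1x?x4x6xi32>, vector<1x2x6xi32>
```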
55 changes: 55 additions & 0 deletions mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -41,6 +41,61 @@ func.func @transfer_read_dims_mismatch_contiguous(
 
 // -----
 
+func.func @transfer_read_dims_mismatch_non_zero_indices(
+    %idx_1: index,
+    %idx_2: index,
+    %m_in: memref<1x43x4x6xi32>,
+    %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+    memref<1x43x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
+// CHECK-SAME:    %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
+// CHECK-SAME:    %[[M_IN:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME:    %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK:         %[[C_0:.*]] = arith.constant 0 : i32
+// CHECK:         %[[C_0_IDX:.*]] = arith.constant 0 : index
+// CHECK:         %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:         %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]], %[[IDX_1]]]
+// CHECK:         %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK:         %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:         vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+    %idx_1: index,
+    %idx_2: index,
+    %m_in: memref<1x?x4x6xi32>,
+    %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+    memref<1x?x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-SAME:    %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
+// CHECK-SAME:    %[[M_IN:.*]]: memref<1x?x4x6xi32>,
+// CHECK-SAME:    %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK:         %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
+// CHECK:         %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:         %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK:         vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
+
+// -----
+
 func.func @transfer_read_dims_mismatch_non_contiguous(
     %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
   %c0 = arith.constant 0 : index
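
Note that the map `#[[$ATTR_0]]` CHECK'd in the first test above is exactly what the pattern's offset formula (`offset += sourceType.getDimSize(i) * indices[i]` over the collapsed dims) produces: offset = 43 * %idx_1 + 4 * %idx_2 + 6 * 0, i.e. `s0 * 4 + s1 * 43` with `s0 = %idx_2` and `s1 = %idx_1`.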
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -454,6 +454,7 @@ struct TestFlattenVectorTransferPatterns
   }
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect>();
+    registry.insert<affine::AffineDialect>();
   }
   void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
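
This extra registration is needed because the updated read-flattening pattern can now create `affine.apply` ops (via `makeComposedFoldedAffineApply`) when computing the collapsed offset.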
