Skip to content

Commit

Permalink
[mlir][linalg] Mark xfers as in-bounds when masking depthwise convs (#…
Browse files Browse the repository at this point in the history
…96771)

If this is not set the fact that the dynamic channel is in-bounds cannot
be inferred automatically (like it can for static sizes), which
eventually leads to it being marked as out-of-bounds (which prevents
some rewrites).
  • Loading branch information
MacDue authored Jun 27, 2024
1 parent bb97378 commit 34e34a0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3234,6 +3234,11 @@ struct Conv1DGenerator
auto maskType =
VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);

SmallVector<bool> inBounds(maskShape.size(), true);
auto xferOp = cast<VectorTransferOpInterface>(opToMask);
xferOp->setAttr(xferOp.getInBoundsAttrName(),
rewriter.getBoolArrayAttr(inBounds));

SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);

Expand Down
24 changes: 12 additions & 12 deletions mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,18 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_IN]] : vector<1x8x4xi1>
/// Read the input tensor
// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>
// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>

/// Create a mask for the filter tensor
// CHECK: %[[CH_DIM_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
// CHECK: %[[MASK_FLT:.*]] = vector.create_mask %[[C1]], %[[CH_DIM_FLT]] : vector<1x4xi1>
/// Read the filter tensor
// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<1x?xi8>, vector<1x4xi8> } : vector<1x4xi1> -> vector<1x4xi8>
// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : tensor<1x?xi8>, vector<1x4xi8> } : vector<1x4xi1> -> vector<1x4xi8>

/// Create a mask for the output tensor
// CHECK: %[[CH_DIM_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[C2]] : tensor<1x8x?xi8>
// CHECK: %[[MASK_OUT:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_OUT]] : vector<1x8x4xi1>
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x8x?xi8>, vector<1x8x4xi8> } : vector<1x8x4xi1> -> vector<1x8x4xi8>

/// Convolution
// CHECK: %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x4xi8> to vector<1x8x4xi8>
Expand All @@ -55,7 +55,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[MULI:.*]] = arith.muli %[[IN_1]], %[[FLT_1_B]] : vector<1x8x4xi8>
// CHECK: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[OUT_1]] : vector<1x8x4xi8>
// CHECK: %[[OUT_INS:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x4xi8> into vector<1x8x4xi8>
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x4xi8>, tensor<1x8x?xi8> } : vector<1x8x4xi1> -> tensor<1x8x?xi8>
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x8x4xi8>, tensor<1x8x?xi8> } : vector<1x8x4xi1> -> tensor<1x8x?xi8>
// CHECK: return %[[OUT]] : tensor<1x8x?xi8>

// -----
Expand Down Expand Up @@ -95,19 +95,19 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[MASK_IN:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_IN]] : vector<1x8x[4]xi1>
/// Read the input tensor
// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>

/// Create a mask for the filter tensor
// CHECK: %[[CH_DIM_FLT:.*]] = tensor.dim %[[FILTER]], %[[C1]] : tensor<1x?xi8>
// CHECK: %[[MASK_FLT:.*]] = vector.create_mask %[[C1]], %[[CH_DIM_FLT]] : vector<1x[4]xi1>
/// Read the filter tensor
// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>
// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : tensor<1x?xi8>, vector<1x[4]xi8> } : vector<1x[4]xi1> -> vector<1x[4]xi8>

/// Create a mask for the output tensor
// CHECK: %[[CH_DIM_OUT:.*]] = tensor.dim %[[OUTPUT]], %[[C2]] : tensor<1x8x?xi8>
// CHECK: %[[MASK_OUT:.*]] = vector.create_mask %[[C1]], %[[C8]], %[[CH_DIM_OUT]] : vector<1x8x[4]xi1>
/// Read the output tensor
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x8x?xi8>, vector<1x8x[4]xi8> } : vector<1x8x[4]xi1> -> vector<1x8x[4]xi8>

/// Convolution
// CHECK: %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [1, 8, 4], strides = [1, 1, 1]} : vector<1x8x[4]xi8> to vector<1x8x[4]xi8>
Expand All @@ -117,7 +117,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[MULI:.*]] = arith.muli %[[IN_1]], %[[FLT_1_B]] : vector<1x8x[4]xi8>
// CHECK: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[OUT_1]] : vector<1x8x[4]xi8>
// CHECK: %[[OUT_INS:.*]] = vector.insert_strided_slice %[[ADDI]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x8x[4]xi8> into vector<1x8x[4]xi8>
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
// CHECK: return %[[OUT]] : tensor<1x8x?xi8>

// -----
Expand Down Expand Up @@ -157,19 +157,19 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[C5:.*]] = arith.constant 5 : index
// CHECK: %[[MASK_IN:.*]] = vector.create_mask %[[C3]], %[[C5]], %[[CH_DIM_IN]] : vector<3x4x[4]xi1>
/// Read the input tensor
// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>
// CHECK: %[[VEC_IN:.*]] = vector.mask %[[MASK_IN]] { vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : memref<3x5x?xf32>, vector<3x4x[4]xf32> } : vector<3x4x[4]xi1> -> vector<3x4x[4]xf32>

/// Create a mask for the filter tensor
// CHECK: %[[CH_DIM_FLT:.*]] = memref.dim %[[FILTER]], %[[C1]] : memref<2x?xf32>
// CHECK: %[[MASK_FLT:.*]] = vector.create_mask %[[C2]], %[[CH_DIM_FLT]] : vector<2x[4]xi1>
/// Read the filter tensor
// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>
// CHECK: %[[VEC_FLT:.*]] = vector.mask %[[MASK_FLT]] { vector.transfer_read %[[FILTER]]{{\[}}%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : memref<2x?xf32>, vector<2x[4]xf32> } : vector<2x[4]xi1> -> vector<2x[4]xf32>

/// Create a mask for the output tensor
// CHECK: %[[CH_DIM_OUT:.*]] = memref.dim %[[OUTPUT]], %[[C2]] : memref<3x2x?xf32>
// CHECK: %[[MASK_OUT:.*]] = vector.create_mask %[[C3]], %[[C2]], %[[CH_DIM_OUT]] : vector<3x2x[4]xi1>
/// Read the output tensor
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>
// CHECK: %[[VEC_OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_read %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : memref<3x2x?xf32>, vector<3x2x[4]xf32> } : vector<3x2x[4]xi1> -> vector<3x2x[4]xf32>

/// Convolution
// CHECK: %[[IN_1:.*]] = vector.extract_strided_slice %[[VEC_IN]] {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x[4]xf32> to vector<3x2x[4]xf32>
Expand All @@ -182,4 +182,4 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[FLT_2_B:.*]] = vector.broadcast %[[FLT_2]] : vector<[4]xf32> to vector<3x2x[4]xf32>
// CHECK: %[[FMA_2:.*]] = vector.fma %[[IN_2]], %[[FLT_2_B]], %[[FMA_1]] : vector<3x2x[4]xf32>
// CHECK: %[[OUT_INS:.*]] = vector.insert_strided_slice %[[FMA_2]], %[[VEC_OUT]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<3x2x[4]xf32> into vector<3x2x[4]xf32>
// CHECK: vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>
// CHECK: vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x2x[4]xf32>, memref<3x2x?xf32> } : vector<3x2x[4]xi1>

0 comments on commit 34e34a0

Please sign in to comment.