Skip to content

Commit

Permalink
[GPU] Do not generate insert_strided_slice for 0-d vectors (#19149)
Browse files Browse the repository at this point in the history
Co-authored-by: saienduri <[email protected]>
  • Loading branch information
Groverkss and saienduri authored Nov 14, 2024
1 parent bf711a1 commit eef2c3a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,20 @@ struct DistributeTransferRead final
rewriter, indices, offsets, vectorLayout, readOp.getPermutationMap(),
warpIndices, threadIndices);

Value slicedRead = rewriter.create<vector::TransferReadOp>(
VectorValue slicedRead = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), innerVectorType, readOp.getSource(), slicedIndices,
readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
readOp.getInBoundsAttr());

acc = rewriter.create<vector::InsertStridedSliceOp>(
readOp.getLoc(), slicedRead, acc, offsets, strides);
if (acc.getType().getRank() == 0) {
// TODO: This should really be a folding pattern in
// insert_strided_slice, but instead insert_strided_slice just doesn't
// support 0-d vectors...
acc = slicedRead;
} else {
acc = rewriter.create<vector::InsertStridedSliceOp>(
readOp.getLoc(), slicedRead, acc, offsets, strides);
}
}

replaceOpWithDistributedValues(rewriter, readOp, acc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,42 @@ builtin.module attributes { transform.with_named_sequence } {

// -----

// Nested layout for a rank-0 (0-d) vector: every tile and stride list is
// empty, so there is nothing to distribute across subgroups, threads,
// batches, or elements — the whole value lives in a single vector<f16>.
#layout = #iree_vector_ext.nested_layout<
subgroup_tile = [],
batch_tile = [],
outer_tile = [],
thread_tile = [],
element_tile = [],

subgroup_strides = [],
thread_strides = []
>

// CHECK-LABEL: @distribute_transfer_read_0d
// Regression test: a transfer_read producing a 0-d vector<f16>. Distribution
// must forward the sliced read directly instead of generating a
// vector.insert_strided_slice, which does not support 0-d vectors.
func.func @distribute_transfer_read_0d(%arg0: memref<128xf16>) -> vector<f16> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f16
// 0-d read: empty in_bounds list, scalar padding value.
%root = vector.transfer_read %arg0[%c0], %cst
{in_bounds = []} : memref<128xf16>, vector<f16>
// Attach the empty nested layout so the distribution pass picks this up.
%rootl = iree_vector_ext.to_layout %root to layout(#layout) : vector<f16>
func.return %rootl : vector<f16>
}


// Transform-dialect driver: matches the func.func above and runs the GPU
// vector-distribution test pass on it (read-only match, no payload mutation
// by the matcher itself).
builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK: %[[RD:.+]] = vector.transfer_read %{{.*}}[%c0]
// CHECK-SAME: memref<128xf16>, vector<f16>
// CHECK: iree_vector_ext.to_simd %[[RD]]

// -----

#layout_row_major = #iree_vector_ext.nested_layout<
subgroup_tile = [1, 1],
batch_tile = [2, 2],
Expand Down

0 comments on commit eef2c3a

Please sign in to comment.