Figure out untiled shapes from tiled and distributed loop nest (#7648)
This commit changes the existing logic for figuring out the untiled
result shapes to use the tiled and distributed loop nest. This is needed
for tensor cases where the root op can be consumed by element-wise ops.
Those element-wise ops sit between the root op and the
`flow.dispatch.tensor.store` ops, breaking the direct connection.
Walking use chains and recognizing consumer ops to bridge that gap is
unwieldy and fragile.

So this instead uses the surrounding tiled and distributed loops to
get back the original untiled shapes.

It might not matter much for matmul, given that we can query M, N, and K
from the inputs through `flow.dispatch.tensor.load` ops. But for conv ops
we cannot, so we need to query the result.
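
As a rough illustration of the approach (not the helper added in this commit):
each tiled and distributed `scf.for` loop in the dispatch steps from a
workgroup-derived offset up to the original extent in tile-sized increments,
so reading the loops' constant upper bounds (e.g. `%c112`, `%c112`, `%c32` in
the conv test below) recovers the untiled sizes without tracing uses down to
the `flow.dispatch.tensor.store` ops. A minimal C++ sketch of that idea,
written against upstream MLIR APIs, could look like the following; accessor
names such as `getUpperBound()` and the dynamic-size sentinel differ across
MLIR versions, and mapping each bound back to its result dimension is elided:

```cpp
// Illustration only: collect the constant upper bounds of the enclosing
// scf.for loops, innermost first. A real implementation would restrict this
// to the tiled-and-distributed loops and map each bound to a result dim.
#include "mlir/Dialect/SCF/SCF.h"   // scf::ForOp (header path differs in newer MLIR)
#include "mlir/IR/BuiltinTypes.h"   // ShapedType
#include "mlir/IR/Matchers.h"       // matchPattern, m_Constant
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static SmallVector<int64_t> getEnclosingLoopUpperBounds(Operation *tiledOp) {
  SmallVector<int64_t> bounds;
  for (auto forOp = tiledOp->getParentOfType<scf::ForOp>(); forOp;
       forOp = forOp->getParentOfType<scf::ForOp>()) {
    IntegerAttr ub;
    if (matchPattern(forOp.getUpperBound(), m_Constant(&ub)))
      bounds.push_back(ub.getInt());              // e.g. 32, 112, 112
    else
      bounds.push_back(ShapedType::kDynamicSize); // bound is not a constant
  }
  return bounds;
}
```

Because the untiled shape is now computed rather than read off an existing
type, the first two hunks below also switch callers from `ArrayRef<int64_t>`
to an owning `SmallVector<int64_t>`.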
antiagainst authored Nov 12, 2021
1 parent f4025e0 commit f363d32
Showing 8 changed files with 262 additions and 44 deletions.
2 changes: 1 addition & 1 deletion iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -175,7 +175,7 @@ static LogicalResult setRootDefaultConfig(FuncOp entryPoint, Operation *op) {
vectorSize = 1;
break;
}
-ArrayRef<int64_t> shape = getUntiledResultShape(
+SmallVector<int64_t> shape = getUntiledResultShape(
cast<linalg::LinalgOp>(op), outputOperand.index());
if (llvm::any_of(shape, ShapedType::isDynamic)) {
vectorSize = 1;
2 changes: 1 addition & 1 deletion iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -37,7 +37,7 @@ LogicalResult setConvOpConfig(linalg::LinalgOp linalgOp,
const int64_t subgroupSize,
const int64_t bestTilingFactor) {
ArrayRef<int64_t> inputShape = getUntiledShape(linalgOp.inputs()[0]);
-ArrayRef<int64_t> outputShape = getUntiledResultShape(linalgOp, 0);
+SmallVector<int64_t> outputShape = getUntiledResultShape(linalgOp, 0);
if (llvm::any_of(inputShape, ShapedType::isDynamic)) return success();
if (llvm::any_of(outputShape, ShapedType::isDynamic)) return success();

1 change: 1 addition & 0 deletions iree/compiler/Codegen/SPIRV/test/BUILD
@@ -21,6 +21,7 @@ iree_lit_test_suite(
[
"config_adreno_conv.mlir",
"config_adreno_matmul.mlir",
"config_default_conv.mlir",
"config_default_linalg_ext_ops.mlir",
"config_default_linalg_ops.mlir",
"config_default_matmul.mlir",
1 change: 1 addition & 0 deletions iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -16,6 +16,7 @@ iree_lit_test_suite(
SRCS
"config_adreno_conv.mlir"
"config_adreno_matmul.mlir"
"config_default_conv.mlir"
"config_default_linalg_ext_ops.mlir"
"config_default_linalg_ops.mlir"
"config_default_matmul.mlir"
110 changes: 110 additions & 0 deletions iree/compiler/Codegen/SPIRV/test/config_default_conv.mlir
@@ -0,0 +1,110 @@
// RUN: iree-opt -split-input-file -pass-pipeline='hal.executable(hal.executable.variant(iree-spirv-lower-executable-target-pass{test-lowering-configuration=true}))' %s | IreeFileCheck %s

// Convolution with consumer pointwise ops

#map0 = affine_map<()[s0, s1] -> (s0 * s1)>
#map1 = affine_map<(d0)[s0] -> (s0, -d0 + 112)>
#map2 = affine_map<(d0)[s0] -> (s0, -d0 + 32)>
#map3 = affine_map<(d0) -> (d0 * 2)>
#map4 = affine_map<(d0, d1) -> (d0 * 2 + 1, d1 * -2 + 225)>
#map5 = affine_map<(d0)[s0] -> (-d0 + 32, s0)>
#map6 = affine_map<(d0)[s0] -> (-d0 + 112, s0)>
#map7 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

hal.executable private @conv_pointwise_112x112x32 {
hal.interface public @io {
hal.interface.binding public @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b2_ro_external, set=0, binding=2, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b3_xw_external, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
}
hal.executable.variant public @vulkan_spirv_fb, target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {
spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
max_compute_shared_memory_size = 16384 : i32,
max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>,
subgroup_size = 32 : i32}>
}> {
hal.executable.entry_point public @conv_pointwise_112x112x32 attributes {interface = @io, ordinal = 0 : index}
builtin.module {
func @conv_pointwise_112x112x32() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%c112 = arith.constant 112 : index
%c32 = arith.constant 32 : index
%0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : !flow.dispatch.tensor<readonly:1x112x112x32xf32>
%1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : !flow.dispatch.tensor<readonly:1x225x225x3xf32>
%2 = hal.interface.binding.subspan @io::@s0b2_ro_external[%c0] : !flow.dispatch.tensor<readonly:3x3x3x32xf32>
%3 = hal.interface.binding.subspan @io::@s0b3_xw_external[%c0] : !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
%workgroup_size_x = hal.interface.workgroup.size[0] : index
%workgroup_size_y = hal.interface.workgroup.size[1] : index
%workgroup_size_z = hal.interface.workgroup.size[2] : index
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%workgroup_id_z = hal.interface.workgroup.id[2] : index
%workgroup_count_z = hal.interface.workgroup.count[2] : index
%4 = affine.apply #map0()[%workgroup_id_z, %workgroup_size_z]
%5 = affine.apply #map0()[%workgroup_count_z, %workgroup_size_z]
scf.for %arg0 = %4 to %c112 step %5 {
%6 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
%7 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
scf.for %arg1 = %6 to %c112 step %7 {
%8 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
%9 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
scf.for %arg2 = %8 to %c32 step %9 {
%10 = affine.min #map1(%arg0)[%workgroup_size_z]
%11 = affine.min #map1(%arg1)[%workgroup_size_y]
%12 = affine.min #map2(%arg2)[%workgroup_size_x]
%13 = flow.dispatch.tensor.load %0, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %10, %11, %12], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x112x112x32xf32> -> tensor<1x?x?x?xf32>
%14 = linalg.init_tensor [1, %10, %11, %12] : tensor<1x?x?x?xf32>
%15 = affine.apply #map3(%arg0)
%16 = affine.min #map4(%10, %arg0)
%17 = affine.apply #map3(%arg1)
%18 = affine.min #map4(%11, %arg1)
%19 = flow.dispatch.tensor.load %1, offsets = [0, %15, %17, 0], sizes = [1, %16, %18, 3], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:1x225x225x3xf32> -> tensor<1x?x?x3xf32>
%20 = affine.min #map5(%arg2)[%workgroup_size_x]
%21 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, %arg2], sizes = [3, 3, 3, %20], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:3x3x3x32xf32> -> tensor<3x3x3x?xf32>
%22 = affine.min #map6(%arg0)[%workgroup_size_z]
%23 = affine.min #map6(%arg1)[%workgroup_size_y]
%24 = linalg.init_tensor [1, %22, %23, %20] : tensor<1x?x?x?xf32>
%25 = linalg.fill(%cst, %24) : f32, tensor<1x?x?x?xf32> -> tensor<1x?x?x?xf32>
%26 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%19, %21 : tensor<1x?x?x3xf32>, tensor<3x3x3x?xf32>) outs(%25 : tensor<1x?x?x?xf32>) -> tensor<1x?x?x?xf32>
%27 = linalg.generic {indexing_maps = [#map7, #map7, #map7], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%26, %13 : tensor<1x?x?x?xf32>, tensor<1x?x?x?xf32>) outs(%14 : tensor<1x?x?x?xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%28 = arith.subf %arg3, %arg4 : f32
linalg.yield %28 : f32
} -> tensor<1x?x?x?xf32>
flow.dispatch.tensor.store %27, %3, offsets = [0, %arg0, %arg1, %arg2], sizes = [1, %10, %11, %12], strides = [1, 1, 1, 1] : tensor<1x?x?x?xf32> -> !flow.dispatch.tensor<writeonly:1x112x112x32xf32>
}
}
}
return
}
hal.interface private @io {
hal.interface.binding public @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b2_ro_external, set=0, binding=2, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b3_xw_external, set=0, binding=3, type="StorageBuffer", access="Write|Discard"
}
}
}
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[0, 4, 4, 32], [0, 2, 2, 4], [0, 0, 0, 0, 1, 1, 4]{{\]}}, native_vector_size = []>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [32, 4, 4]>
// CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)
// CHECK-DAG: #[[MAP_YZ:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
// CHECK: hal.executable.entry_point public @conv_pointwise_112x112x32
// CHECK-SAME: translation.info = #[[TRANSLATION]]
// CHECK-SAME: workgroup_size = [8 : index, 2 : index, 2 : index]
// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %[[Z:.+]]: index):
// CHECK-NEXT: %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
// CHECK-NEXT: %[[Y_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Y]]]
// CHECK-NEXT: %[[Z_COUNT:.+]] = affine.apply #[[MAP_YZ]]()[%[[Z]]]
// CHECK-NEXT: hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[Z_COUNT]]

// CHECK: func @conv_pointwise_112x112x32()
// CHECK: linalg.conv_2d_nhwc_hwcf
// CHECK-SAME: lowering.config = #[[CONFIG]]
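
As a quick sanity check on the CHECK lines above: with workload_per_wg = [32, 4, 4] and the static 1x112x112x32 output, the workgroup counts are ceildiv(32, 32) = 1 along x and ceildiv(112, 4) = 28 along y and z, which is exactly what the ceildiv maps compute.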
101 changes: 101 additions & 0 deletions iree/compiler/Codegen/SPIRV/test/config_default_matmul.mlir
@@ -340,3 +340,104 @@ hal.executable @matmul_25x546 {
// CHECK: func @matmul_25x546()
// CHECK: linalg.matmul
// CHECK-SAME: lowering.config = #[[CONFIG]]

// -----

// Matmul with consumer pointwise ops

#map0 = affine_map<()[s0, s1] -> (s0 * s1)>
#map1 = affine_map<(d0)[s0] -> (s0, -d0 + 256)>
#map2 = affine_map<(d0)[s0] -> (s0, -d0 + 1024)>
#map3 = affine_map<(d0)[s0] -> (-d0 + 256, s0)>
#map4 = affine_map<(d0)[s0] -> (-d0 + 1024, s0)>
#map5 = affine_map<(d0, d1) -> (d0, d1)>

hal.executable private @matmul_pointwise_256x1024 {
hal.interface public @io {
hal.interface.binding public @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b2_ro_external, set=0, binding=2, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b3_ro_external, set=0, binding=3, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b4_xw_external, set=0, binding=4, type="StorageBuffer", access="Write|Discard"
}
hal.executable.variant public @vulkan_spirv_fb, target = #hal.executable.target<"vulkan", "vulkan-spirv-fb", {
spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
max_compute_shared_memory_size = 16384 : i32,
max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>,
subgroup_size = 32 : i32}>
}> {
hal.executable.entry_point public @matmul_pointwise_256x1024 attributes {interface = @io, ordinal = 0 : index}
builtin.module {
func @matmul_pointwise_256x1024() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%c256 = arith.constant 256 : index
%c1024 = arith.constant 1024 : index
%0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : !flow.dispatch.tensor<readonly:256x1024xf16>
%1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : !flow.dispatch.tensor<readonly:256x1024xf16>
%2 = hal.interface.binding.subspan @io::@s0b2_ro_external[%c0] : !flow.dispatch.tensor<readonly:256x128xf16>
%3 = hal.interface.binding.subspan @io::@s0b3_ro_external[%c0] : !flow.dispatch.tensor<readonly:128x1024xf16>
%4 = hal.interface.binding.subspan @io::@s0b4_xw_external[%c0] : !flow.dispatch.tensor<writeonly:256x1024xf16>
%workgroup_size_x = hal.interface.workgroup.size[0] : index
%workgroup_size_y = hal.interface.workgroup.size[1] : index
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%5 = affine.apply #map0()[%workgroup_id_y, %workgroup_size_y]
%6 = affine.apply #map0()[%workgroup_count_y, %workgroup_size_y]
scf.for %arg0 = %5 to %c256 step %6 {
%7 = affine.apply #map0()[%workgroup_id_x, %workgroup_size_x]
%8 = affine.apply #map0()[%workgroup_count_x, %workgroup_size_x]
scf.for %arg1 = %7 to %c1024 step %8 {
%9 = affine.min #map1(%arg0)[%workgroup_size_y]
%10 = affine.min #map2(%arg1)[%workgroup_size_x]
%11 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [%9, %10], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<?x?xf16>
%12 = flow.dispatch.tensor.load %1, offsets = [%arg0, %arg1], sizes = [%9, %10], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x1024xf16> -> tensor<?x?xf16>
%13 = linalg.init_tensor [%9, %10] : tensor<?x?xf16>
%14 = affine.min #map3(%arg0)[%workgroup_size_y]
%15 = flow.dispatch.tensor.load %2, offsets = [%arg0, 0], sizes = [%14, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:256x128xf16> -> tensor<?x128xf16>
%16 = affine.min #map4(%arg1)[%workgroup_size_x]
%17 = flow.dispatch.tensor.load %3, offsets = [0, %arg1], sizes = [128, %16], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x1024xf16> -> tensor<128x?xf16>
%18 = linalg.init_tensor [%14, %16] : tensor<?x?xf16>
%19 = linalg.fill(%cst, %18) : f16, tensor<?x?xf16> -> tensor<?x?xf16>
%20 = linalg.matmul ins(%15, %17 : tensor<?x128xf16>, tensor<128x?xf16>) outs(%19 : tensor<?x?xf16>) -> tensor<?x?xf16>
%21 = linalg.generic {indexing_maps = [#map5, #map5, #map5, #map5], iterator_types = ["parallel", "parallel"]} ins(%20, %11, %12 : tensor<?x?xf16>, tensor<?x?xf16>, tensor<?x?xf16>) outs(%13 : tensor<?x?xf16>) {
^bb0(%arg2: f16, %arg3: f16, %arg4: f16, %arg5: f16): // no predecessors
%22 = arith.divf %arg2, %arg3 : f16
%23 = arith.subf %22, %arg4 : f16
linalg.yield %23 : f16
} -> tensor<?x?xf16>
flow.dispatch.tensor.store %21, %4, offsets = [%arg0, %arg1], sizes = [%9, %10], strides = [1, 1] : tensor<?x?xf16> -> !flow.dispatch.tensor<writeonly:256x1024xf16>
}
}
return
}
hal.interface private @io {
hal.interface.binding public @s0b0_ro_external, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b1_ro_external, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b2_ro_external, set=0, binding=2, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b3_ro_external, set=0, binding=3, type="StorageBuffer", access="Read"
hal.interface.binding public @s0b4_xw_external, set=0, binding=4, type="StorageBuffer", access="Write|Discard"
}
}
}
}

// CHECK-DAG: #[[CONFIG:.+]] = #iree_codegen.lowering.config<tile_sizes = {{\[}}[16, 256], [8, 8], [0, 0, 4]{{\]}}, native_vector_size = []>
// CHECK-DAG: #[[MAP_X:.+]] = affine_map<()[s0] -> (s0 ceildiv 256)>
// CHECK-DAG: #[[MAP_Y:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
// CHECK-DAG: #[[TRANSLATION:.+]] = #iree_codegen.translation.info<"SPIRVVectorize", workload_per_wg = [256, 16]>
// CHECK: hal.executable.entry_point public @matmul_pointwise_256x1024
// CHECK-SAME: translation.info = #[[TRANSLATION]]
// CHECK-SAME: workgroup_size = [32 : index, 2 : index, 1 : index]
// CHECK-NEXT: ^{{.+}}(%[[X:.+]]: index, %[[Y:.+]]: index, %{{.+}}: index):
// CHECK-NEXT: %[[ONE:.+]] = arith.constant 1 : index
// CHECK-NEXT: %[[X_COUNT:.+]] = affine.apply #[[MAP_X]]()[%[[X]]]
// CHECK-NEXT: %[[Y_COUNT:.+]] = affine.apply #[[MAP_Y]]()[%[[Y]]]
// CHECK-NEXT: hal.return %[[X_COUNT]], %[[Y_COUNT]], %[[ONE]]

// CHECK: func @matmul_pointwise_256x1024()
// CHECK: linalg.matmul
// CHECK-SAME: lowering.config = #[[CONFIG]]
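
Similarly, for the 256x1024 matmul output with workload_per_wg = [256, 16], the expected counts are ceildiv(1024, 256) = 4 along x, ceildiv(256, 16) = 16 along y, and the constant 1 along z, matching the maps and the `hal.return` checked above.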