From 546ae623ebdd3aed0bce896de7f73cac050a50b3 Mon Sep 17 00:00:00 2001 From: Nilaykumar Patel Date: Wed, 30 Oct 2024 11:27:12 +0530 Subject: [PATCH] Make reduction on 32 rows at a time instead of 16. (#14344) #12507 Make reduction on 32 rows at a time instead of 16. Signed-off-by: Nilaykumar K Patel --- .../max_pool_multi_core_large_kernel.cpp | 41 ++++++++++--------- ...core_sharded_with_halo_large_kernel_v2.cpp | 12 ++++-- .../max_pool2d_multi_core_program_factory.cpp | 18 ++++---- 3 files changed, 39 insertions(+), 32 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp index 8a03fb05ef6..90a7d6c0a40 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/compute/max_pool_multi_core_large_kernel.cpp @@ -67,7 +67,7 @@ inline void reduce_h_fused( const uint32_t out_cb_id, const uint32_t unpA_face_r_dim) { constexpr uint32_t num_output_tiles = out_ntiles_c * nblocks; - constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2; + uint32_t num_faces_in_input_tile = is_partial_tile ? 1 : unpA_face_r_dim < 32 ? 
2 : 4; constexpr uint32_t num_out_rows = 1; for (uint32_t out_elem_i = 0; out_elem_i < nblocks; ++out_elem_i) { const uint32_t curr_in_cb_id = @@ -78,10 +78,10 @@ inline void reduce_h_fused( in_scalar_cb_id, num_tiles_for_reduction, 0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/, - num_faces_in_tile /* unpack 1 or 2 faces ) */, + num_faces_in_input_tile /* unpack 1 or 2 faces ) */, unpA_face_r_dim); for (uint32_t c_i = 0; c_i < num_tiles_for_reduction; ++c_i) { - reduce_tile_math(in_ntiles_c * out_elem_i + c_i, num_faces_in_tile /* reduce 1 or 2 faces */); + reduce_tile_math(in_ntiles_c * out_elem_i + c_i, num_faces_in_input_tile /* reduce 1 or 2 faces */); } cb_pop_front(curr_in_cb_id, 1); } @@ -104,6 +104,7 @@ void MAIN { constexpr uint32_t nsticks_per_core_by_nblocks = get_compile_time_arg_val(13); constexpr uint32_t in_c = get_compile_time_arg_val(14); + constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(16); constexpr uint32_t num_output_tiles = out_ntiles_c * nblocks; constexpr uint32_t in_cb_id = tt::CB::c_in0; // and tt::CB::c_in1 for split reader @@ -114,9 +115,9 @@ void MAIN { constexpr bool is_partial_tile = in_c < 32; static_assert((!is_partial_tile || (in_c == 16)), "Partial tile must have c_dim 16"); - constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2; + constexpr uint32_t num_faces_in_input_tile = is_partial_tile ? 1 : max_rows_for_reduction < 32 ? 2 : 4; + constexpr uint32_t num_faces_in_output_tile = is_partial_tile ? 
1 : 2; constexpr uint32_t num_out_rows = 1; - constexpr uint32_t MAX_ROWS_FOR_REDUCTION = 16; constexpr uint32_t MAX_TILES_PER_REDUCTION = 8; constexpr uint32_t num_tiles_for_reduction = @@ -132,10 +133,10 @@ void MAIN { in_scalar_cb_id, num_tiles_for_reduction, interm_reduction_cb_id, - num_faces_in_tile, - MAX_ROWS_FOR_REDUCTION); + num_faces_in_input_tile, + max_rows_for_reduction); - uint32_t interm_reduction_chunks = window_size_hw / MAX_ROWS_FOR_REDUCTION; + uint32_t interm_reduction_chunks = window_size_hw / max_rows_for_reduction; cb_wait_front(in_scalar_cb_id, 1); cb_reserve_back(out_cb_id, 1); for (uint32_t i = 0; i < nsticks_per_core_by_nblocks; ++i) { @@ -144,8 +145,8 @@ void MAIN { // TODO: subblocking to support this. uint32_t out_write_idx = i * num_8_tiles_blocks + j; - pack_untilize_dst_init_short( - interm_reduction_cb_id, num_out_rows, num_faces_in_tile); + pack_untilize_dst_init_short( + interm_reduction_cb_id, num_out_rows, num_faces_in_output_tile); cb_reserve_back(interm_reduction_cb_id, 1); for (uint32_t h = 0; h <= interm_reduction_chunks; h++) { tile_regs_acquire(); @@ -156,22 +157,22 @@ void MAIN { num_tiles_for_reduction, i, interm_reduction_cb_id, - MAX_ROWS_FOR_REDUCTION); + max_rows_for_reduction); tile_regs_commit(); tile_regs_wait(); - pack_untilize_dst( + pack_untilize_dst( interm_reduction_cb_id, 1 /*out_subblock_h*/, h, num_out_rows, - num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */ + num_faces_in_output_tile); /* pack 1 row (1x16 or 1x32) */ tile_regs_release(); } cb_push_back(interm_reduction_cb_id, 1); pack_untilize_uninit(interm_reduction_cb_id); cb_wait_front(interm_reduction_cb_id, 1); - pack_untilize_dst_init_short( - out_cb_id, num_out_rows, num_faces_in_tile); + pack_untilize_dst_init_short( + out_cb_id, num_out_rows, num_faces_in_output_tile); tile_regs_acquire(); unpack_tilizeA_B_block( @@ -179,20 +180,20 @@ void MAIN { in_scalar_cb_id, num_tiles_for_reduction, 0 /*tile idx for Src b is 0 because only 1 tile of 
constants is loaded*/, - num_faces_in_tile /* unpack 1 or 2 faces ) */, - MAX_ROWS_FOR_REDUCTION); + num_faces_in_input_tile /* unpack 1 or 2 faces ) */, + max_rows_for_reduction); for (uint32_t c_i = 0; c_i < num_tiles_for_reduction; ++c_i) { - reduce_tile_math(c_i, num_faces_in_tile /* reduce 1 or 2 faces */); + reduce_tile_math(c_i, num_faces_in_input_tile /* reduce 1 or 2 faces */); } tile_regs_commit(); tile_regs_wait(); - pack_untilize_dst( + pack_untilize_dst( out_cb_id, 1 /*out_subblock_h*/, out_write_idx, num_out_rows, - num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */ + num_faces_in_output_tile); /* pack 1 row (1x16 or 1x32) */ tile_regs_release(); cb_pop_front(interm_reduction_cb_id, 1); pack_untilize_uninit(out_cb_id); diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp index 3a63cc8d04d..c6b0ea9f930 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/kernels/dataflow/reader_max_pool_2d_multi_core_sharded_with_halo_large_kernel_v2.cpp @@ -67,11 +67,12 @@ void kernel_main() { // value of 1 in bf16 in a uin32_t constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12); + constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(14); + // static_assert(0 == reader_nindices%2, "reader_nindices must be multiple of 2"); constexpr uint32_t TILE_SIZE = 32 * 32; constexpr uint32_t MAX_TILES_PER_REDUCTION = 8; - constexpr uint32_t MAX_ROWS_FOR_REDUCTION = 16; constexpr uint32_t MAX_ELE_PER_REDUCTION = 512; constexpr uint32_t in_cb_id = (reader_id == 1) ? 
tt::CB::c_in1 : tt::CB::c_in0; @@ -112,7 +113,7 @@ void kernel_main() { } uint32_t counter = reader_id; uint32_t total_elems_to_reduce = window_h * window_w; - uint32_t remaining_elems = total_elems_to_reduce % MAX_ROWS_FOR_REDUCTION; + uint32_t remaining_elems = total_elems_to_reduce % max_rows_for_reduction; while (counter < reader_nindices) { for (uint32_t j = 0; j < num_8_tile_blocks; j++) { for (uint32_t i = 0; i < nblocks; ++i) { @@ -122,6 +123,9 @@ void kernel_main() { uint32_t out_l1_write_addr_base = get_write_ptr(in_cb_id); uint32_t out_l1_write_addr = out_l1_write_addr_base; cb_reserve_back(in_cb_id, npages_to_reserve); + // If next is last chunk, fill whole buffer with -inf. + if ((total_elems_to_reduce - processed_rows) < max_rows_for_reduction) + fill_with_val(out_l1_write_addr, TILE_SIZE * MAX_TILES_PER_REDUCTION, minus_inf); for (uint32_t h = 0; h < window_h; ++h, h_multiples += in_w_padded) { uint32_t stick_offset = top_left_local_index + h_multiples; uint32_t read_offset = @@ -131,14 +135,14 @@ void kernel_main() { out_l1_write_addr += read_bytes; read_offset += in_nbytes_c; processed_rows++; - if ((processed_rows % MAX_ROWS_FOR_REDUCTION) == 0) { + if ((processed_rows % max_rows_for_reduction) == 0) { noc_async_read_barrier(); cb_push_back(in_cb_id, npages_to_reserve); out_l1_write_addr_base = get_write_ptr(in_cb_id); out_l1_write_addr = out_l1_write_addr_base; cb_reserve_back(in_cb_id, npages_to_reserve); // If next is last chunk, fill whole buffer with -inf. 
- if ((total_elems_to_reduce - processed_rows) < MAX_ROWS_FOR_REDUCTION) + if ((total_elems_to_reduce - processed_rows) < max_rows_for_reduction) fill_with_val(out_l1_write_addr, TILE_SIZE * MAX_TILES_PER_REDUCTION, minus_inf); } } diff --git a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp index 061c10aa0eb..c21873f788d 100644 --- a/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/pool/maxpool/device/max_pool2d_multi_core_program_factory.cpp @@ -2,12 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 - - - -#include "impl/buffers/buffer_constants.hpp" #include "max_pool2d_device_op.hpp" -// #include "max_pool2d_multi_core_program_factory.hpp" #include "ttnn/operations/reduction/generic/device/reduce_op.hpp" // for reduce_op_utils /** @@ -272,6 +267,10 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_ } #endif + uint32_t max_rows_for_reduction = tt::constants::TILE_HEIGHT; + /* For GRAYSKULL, make reduction for 16 rows at a time.*/ + if (device->arch() == tt::ARCH::GRAYSKULL) + max_rows_for_reduction /= 2; /** * Reader Kernel: input rows -> input cb */ @@ -292,7 +291,8 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_ split_reader, // enable split reader 0, // split reader id bf16_one_u32, - in_nblocks_c}; + in_nblocks_c, + max_rows_for_reduction}; std::vector reader1_ct_args = { out_nhw_per_core, @@ -308,7 +308,8 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_ split_reader, // enable split reader 1, // split reader id bf16_one_u32, - in_nblocks_c}; + in_nblocks_c, + max_rows_for_reduction}; std::string reader_kernel_fname; if (is_large_kernel) { @@ -350,7 +351,8 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_ split_reader, // enable 
split reader out_nhw_per_core / nblocks, // loop count with blocks input_shape[3] / num_shards_c, - in_nblocks_c}; + in_nblocks_c, + max_rows_for_reduction}; auto reduce_op = tt::tt_metal::ReduceOpMath::MAX; auto reduce_dim = tt::tt_metal::ReduceOpDim::H;