Skip to content

Commit

Permalink
Make reduction on 32 rows at a time instead of 16. (tenstorrent#14344)
Browse files Browse the repository at this point in the history
tenstorrent#12507  Make reduction on 32 rows at a time instead of 16.

Signed-off-by: Nilaykumar K Patel <[email protected]>
  • Loading branch information
nkpatel-tt authored Oct 30, 2024
1 parent db8aa1c commit 546ae62
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ inline void reduce_h_fused(
const uint32_t out_cb_id,
const uint32_t unpA_face_r_dim) {
constexpr uint32_t num_output_tiles = out_ntiles_c * nblocks;
constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2;
uint32_t num_faces_in_input_tile = is_partial_tile ? 1 : unpA_face_r_dim < 32 ? 2 : 4;
constexpr uint32_t num_out_rows = 1;
for (uint32_t out_elem_i = 0; out_elem_i < nblocks; ++out_elem_i) {
const uint32_t curr_in_cb_id =
Expand All @@ -78,10 +78,10 @@ inline void reduce_h_fused(
in_scalar_cb_id,
num_tiles_for_reduction,
0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/,
num_faces_in_tile /* unpack 1 or 2 faces ) */,
num_faces_in_input_tile /* unpack 1 or 2 faces ) */,
unpA_face_r_dim);
for (uint32_t c_i = 0; c_i < num_tiles_for_reduction; ++c_i) {
reduce_tile_math(in_ntiles_c * out_elem_i + c_i, num_faces_in_tile /* reduce 1 or 2 faces */);
reduce_tile_math(in_ntiles_c * out_elem_i + c_i, num_faces_in_input_tile /* reduce 1 or 2 faces */);
}
cb_pop_front(curr_in_cb_id, 1);
}
Expand All @@ -104,6 +104,7 @@ void MAIN {

constexpr uint32_t nsticks_per_core_by_nblocks = get_compile_time_arg_val(13);
constexpr uint32_t in_c = get_compile_time_arg_val(14);
constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(16);
constexpr uint32_t num_output_tiles = out_ntiles_c * nblocks;

constexpr uint32_t in_cb_id = tt::CB::c_in0; // and tt::CB::c_in1 for split reader
Expand All @@ -114,9 +115,9 @@ void MAIN {

constexpr bool is_partial_tile = in_c < 32;
static_assert((!is_partial_tile || (in_c == 16)), "Partial tile must have c_dim 16");
constexpr uint32_t num_faces_in_tile = is_partial_tile ? 1 : 2;
constexpr uint32_t num_faces_in_input_tile = is_partial_tile ? 1 : max_rows_for_reduction < 32 ? 2 : 4;
constexpr uint32_t num_faces_in_output_tile = is_partial_tile ? 1 : 2;
constexpr uint32_t num_out_rows = 1;
constexpr uint32_t MAX_ROWS_FOR_REDUCTION = 16;
constexpr uint32_t MAX_TILES_PER_REDUCTION = 8;

constexpr uint32_t num_tiles_for_reduction =
Expand All @@ -132,10 +133,10 @@ void MAIN {
in_scalar_cb_id,
num_tiles_for_reduction,
interm_reduction_cb_id,
num_faces_in_tile,
MAX_ROWS_FOR_REDUCTION);
num_faces_in_input_tile,
max_rows_for_reduction);

uint32_t interm_reduction_chunks = window_size_hw / MAX_ROWS_FOR_REDUCTION;
uint32_t interm_reduction_chunks = window_size_hw / max_rows_for_reduction;
cb_wait_front(in_scalar_cb_id, 1);
cb_reserve_back(out_cb_id, 1);
for (uint32_t i = 0; i < nsticks_per_core_by_nblocks; ++i) {
Expand All @@ -144,8 +145,8 @@ void MAIN {
// TODO: subblocking to support this.
uint32_t out_write_idx = i * num_8_tiles_blocks + j;

pack_untilize_dst_init_short<num_tiles_for_reduction, num_output_tiles>(
interm_reduction_cb_id, num_out_rows, num_faces_in_tile);
pack_untilize_dst_init_short<num_tiles_for_reduction>(
interm_reduction_cb_id, num_out_rows, num_faces_in_output_tile);
cb_reserve_back(interm_reduction_cb_id, 1);
for (uint32_t h = 0; h <= interm_reduction_chunks; h++) {
tile_regs_acquire();
Expand All @@ -156,43 +157,43 @@ void MAIN {
num_tiles_for_reduction,
i,
interm_reduction_cb_id,
MAX_ROWS_FOR_REDUCTION);
max_rows_for_reduction);
tile_regs_commit();
tile_regs_wait();
pack_untilize_dst<num_tiles_for_reduction, num_output_tiles>(
pack_untilize_dst<num_tiles_for_reduction>(
interm_reduction_cb_id,
1 /*out_subblock_h*/,
h,
num_out_rows,
num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */
num_faces_in_output_tile); /* pack 1 row (1x16 or 1x32) */
tile_regs_release();
}
cb_push_back(interm_reduction_cb_id, 1);
pack_untilize_uninit(interm_reduction_cb_id);
cb_wait_front(interm_reduction_cb_id, 1);
pack_untilize_dst_init_short<num_tiles_for_reduction, num_output_tiles>(
out_cb_id, num_out_rows, num_faces_in_tile);
pack_untilize_dst_init_short<num_tiles_for_reduction>(
out_cb_id, num_out_rows, num_faces_in_output_tile);

tile_regs_acquire();
unpack_tilizeA_B_block(
interm_reduction_cb_id,
in_scalar_cb_id,
num_tiles_for_reduction,
0 /*tile idx for Src b is 0 because only 1 tile of constants is loaded*/,
num_faces_in_tile /* unpack 1 or 2 faces ) */,
MAX_ROWS_FOR_REDUCTION);
num_faces_in_input_tile /* unpack 1 or 2 faces ) */,
max_rows_for_reduction);
for (uint32_t c_i = 0; c_i < num_tiles_for_reduction; ++c_i) {
reduce_tile_math(c_i, num_faces_in_tile /* reduce 1 or 2 faces */);
reduce_tile_math(c_i, num_faces_in_input_tile /* reduce 1 or 2 faces */);
}

tile_regs_commit();
tile_regs_wait();
pack_untilize_dst<num_tiles_for_reduction, num_output_tiles>(
pack_untilize_dst<num_tiles_for_reduction>(
out_cb_id,
1 /*out_subblock_h*/,
out_write_idx,
num_out_rows,
num_faces_in_tile); /* pack 1 row (1x16 or 1x32) */
num_faces_in_output_tile); /* pack 1 row (1x16 or 1x32) */
tile_regs_release();
cb_pop_front(interm_reduction_cb_id, 1);
pack_untilize_uninit(out_cb_id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ void kernel_main() {
// value of 1 in bf16 in a uint32_t
constexpr uint32_t bf16_one_u32 = get_compile_time_arg_val(12);

constexpr uint32_t max_rows_for_reduction = get_compile_time_arg_val(14);

// static_assert(0 == reader_nindices%2, "reader_nindices must be multiple of 2");

constexpr uint32_t TILE_SIZE = 32 * 32;
constexpr uint32_t MAX_TILES_PER_REDUCTION = 8;
constexpr uint32_t MAX_ROWS_FOR_REDUCTION = 16;
constexpr uint32_t MAX_ELE_PER_REDUCTION = 512;

constexpr uint32_t in_cb_id = (reader_id == 1) ? tt::CB::c_in1 : tt::CB::c_in0;
Expand Down Expand Up @@ -112,7 +113,7 @@ void kernel_main() {
}
uint32_t counter = reader_id;
uint32_t total_elems_to_reduce = window_h * window_w;
uint32_t remaining_elems = total_elems_to_reduce % MAX_ROWS_FOR_REDUCTION;
uint32_t remaining_elems = total_elems_to_reduce % max_rows_for_reduction;
while (counter < reader_nindices) {
for (uint32_t j = 0; j < num_8_tile_blocks; j++) {
for (uint32_t i = 0; i < nblocks; ++i) {
Expand All @@ -122,6 +123,9 @@ void kernel_main() {
uint32_t out_l1_write_addr_base = get_write_ptr(in_cb_id);
uint32_t out_l1_write_addr = out_l1_write_addr_base;
cb_reserve_back(in_cb_id, npages_to_reserve);
// If next is last chunk, fill whole buffer with -inf.
if ((total_elems_to_reduce - processed_rows) < max_rows_for_reduction)
fill_with_val(out_l1_write_addr, TILE_SIZE * MAX_TILES_PER_REDUCTION, minus_inf);
for (uint32_t h = 0; h < window_h; ++h, h_multiples += in_w_padded) {
uint32_t stick_offset = top_left_local_index + h_multiples;
uint32_t read_offset =
Expand All @@ -131,14 +135,14 @@ void kernel_main() {
out_l1_write_addr += read_bytes;
read_offset += in_nbytes_c;
processed_rows++;
if ((processed_rows % MAX_ROWS_FOR_REDUCTION) == 0) {
if ((processed_rows % max_rows_for_reduction) == 0) {
noc_async_read_barrier();
cb_push_back(in_cb_id, npages_to_reserve);
out_l1_write_addr_base = get_write_ptr(in_cb_id);
out_l1_write_addr = out_l1_write_addr_base;
cb_reserve_back(in_cb_id, npages_to_reserve);
// If next is last chunk, fill whole buffer with -inf.
if ((total_elems_to_reduce - processed_rows) < MAX_ROWS_FOR_REDUCTION)
if ((total_elems_to_reduce - processed_rows) < max_rows_for_reduction)
fill_with_val(out_l1_write_addr, TILE_SIZE * MAX_TILES_PER_REDUCTION, minus_inf);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0




#include "impl/buffers/buffer_constants.hpp"
#include "max_pool2d_device_op.hpp"
// #include "max_pool2d_multi_core_program_factory.hpp"
#include "ttnn/operations/reduction/generic/device/reduce_op.hpp" // for reduce_op_utils

/**
Expand Down Expand Up @@ -272,6 +267,10 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
}
#endif

uint32_t max_rows_for_reduction = tt::constants::TILE_HEIGHT;
/* For GRAYSKULL, make reduction for 16 rows at a time.*/
if (device->arch() == tt::ARCH::GRAYSKULL)
max_rows_for_reduction /= 2;
/**
* Reader Kernel: input rows -> input cb
*/
Expand All @@ -292,7 +291,8 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
split_reader, // enable split reader
0, // split reader id
bf16_one_u32,
in_nblocks_c};
in_nblocks_c,
max_rows_for_reduction};

std::vector<uint32_t> reader1_ct_args = {
out_nhw_per_core,
Expand All @@ -308,7 +308,8 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
split_reader, // enable split reader
1, // split reader id
bf16_one_u32,
in_nblocks_c};
in_nblocks_c,
max_rows_for_reduction};

std::string reader_kernel_fname;
if (is_large_kernel) {
Expand Down Expand Up @@ -350,7 +351,8 @@ MaxPool2D::MultiCore::cached_program_t max_pool_2d_multi_core_sharded_with_halo_
split_reader, // enable split reader
out_nhw_per_core / nblocks, // loop count with blocks
input_shape[3] / num_shards_c,
in_nblocks_c};
in_nblocks_c,
max_rows_for_reduction};

auto reduce_op = tt::tt_metal::ReduceOpMath::MAX;
auto reduce_dim = tt::tt_metal::ReduceOpDim::H;
Expand Down

0 comments on commit 546ae62

Please sign in to comment.