From 8e784984cee29106a0fcf213a8a10f54773b2f9f Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 31 Mar 2024 05:54:36 -0500 Subject: [PATCH] Port sum_reduction_recursive to kernel API. --- .../examples/kernel/sum_reduction_recursive.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/numba_dpex/examples/kernel/sum_reduction_recursive.py b/numba_dpex/examples/kernel/sum_reduction_recursive.py index a6f55a65d9..08f831cfca 100644 --- a/numba_dpex/examples/kernel/sum_reduction_recursive.py +++ b/numba_dpex/examples/kernel/sum_reduction_recursive.py @@ -18,11 +18,11 @@ @ndpx.kernel def sum_reduction_kernel(nditem: kapi.NdItem, A, input_size, partial_sums, slm): - local_id = ndpx.get_local_id(0) - global_id = ndpx.get_global_id(0) - group_size = ndpx.get_local_size(0) - group_id = ndpx.get_group_id(0) - + local_id = nditem.get_local_id(0) + global_id = nditem.get_global_id(0) + group_size = nditem.get_local_range(0) + gr = nditem.get_group() + group_id = gr.get_group_id(0) slm[local_id] = 0 if global_id < input_size: @@ -32,7 +32,7 @@ def sum_reduction_kernel(nditem: kapi.NdItem, A, input_size, partial_sums, slm): stride = group_size // 2 while stride > 0: # Waiting for each 2x2 addition into given workgroup - ndpx.barrier(ndpx.LOCAL_MEM_FENCE) + kapi.group_barrier(gr) # Add elements 2 by 2 between local_id and local_id + stride if local_id < stride: