Skip to content

Commit

Permalink
Port sum_reduction_recursive to kernel API.
Browse files (browse the repository at this point in the history)
  • Loading branch information
diptorupd committed Apr 1, 2024
1 parent ddeefba commit 8e78498
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions numba_dpex/examples/kernel/sum_reduction_recursive.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,11 +18,11 @@

@ndpx.kernel
def sum_reduction_kernel(nditem: kapi.NdItem, A, input_size, partial_sums, slm):
local_id = ndpx.get_local_id(0)
global_id = ndpx.get_global_id(0)
group_size = ndpx.get_local_size(0)
group_id = ndpx.get_group_id(0)

local_id = nditem.get_local_id(0)
global_id = nditem.get_global_id(0)
group_size = nditem.get_local_range(0)
gr = nditem.get_group()
group_id = gr.get_group_id(0)
slm[local_id] = 0

if global_id < input_size:
Expand All @@ -32,7 +32,7 @@ def sum_reduction_kernel(nditem: kapi.NdItem, A, input_size, partial_sums, slm):
stride = group_size // 2
while stride > 0:
# Wait until every work-item in the work-group has finished its pairwise addition
ndpx.barrier(ndpx.LOCAL_MEM_FENCE)
kapi.group_barrier(gr)

# Add elements 2 by 2 between local_id and local_id + stride
if local_id < stride:
Expand Down

0 comments on commit 8e78498

Please sign in to comment.