shared memory accumulation #298

Answered by daedalus5
toaster-robotics asked this question in Q&A
import warp as wp
import numpy as np

device = "cuda:0"

snippet = """
    __shared__ int sum[256];

    int index = j * 16 + i;

    sum[index] = arr[index];
    __syncthreads();

    for (int stride = 128; stride > 0; stride >>= 1) {
        if (index < stride) {
            sum[index] += sum[index + stride];
        }
        __syncthreads();
    }

    if (index == 0) {
        out[0] = sum[0];
    }
    """

# Bind the CUDA snippet as a native Warp function; the Python definition
# only supplies the type signature, the body is the snippet above.
@wp.func_native(snippet)
def reduce(arr: wp.array2d(dtype=int), out: wp.array(dtype=int), i: int, j: int):
    ...

@wp.kernel
def reduce_kernel(arr: wp.array2d(dtype=int), out: wp.array(dtype=int)):
    i, j = wp.tid()  # 2D thread index over the launch grid
    reduce(arr, out, i, j)

N = 16
row = np.arange(N, dtype=…
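
The driver code is cut off at the end of the question. A minimal sketch of how it might continue, assuming the 16x16 input is built by tiling `row` and the kernel is launched over an (N, N) grid; the array construction, dtype, and verification step are assumptions, only `N` and `row` appear in the original:

# Hypothetical completion -- the original is truncated after "row = np.arange(N, dtype=".
# (On older Warp versions, wp.init() may be required before creating arrays.)
row = np.arange(N, dtype=np.int32)                              # assumed dtype
arr = wp.array(np.tile(row, (N, 1)), dtype=int, device=device)  # assumed 16x16 input
out = wp.zeros(1, dtype=int, device=device)

# dim=(N, N) flattens to 256 threads; with Warp's default block size of 256 they
# all land in one CUDA block, which the snippet's __syncthreads() requires.
wp.launch(reduce_kernel, dim=(N, N), inputs=[arr, out], device=device)

print(out.numpy()[0])                   # sum computed by the shared-memory reduction
print(int(np.tile(row, (N, 1)).sum()))  # NumPy reference for comparison

However wp.tid() linearizes (i, j), `index = j * 16 + i` maps the 256 threads one-to-one onto [0, 256), so each element is read exactly once and the final sum is unaffected.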

Replies: 1 comment, 1 reply

Answer selected by shi-eric