[Operator] Add nonzero op [MooreThreads] #178

Merged 1 commit on Sep 2, 2024
22 changes: 22 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -3,6 +3,7 @@
from .performance_utils import (
BLAS_BATCH,
FLOAT_DTYPES,
INT_DTYPES,
REDUCTION_BATCH,
SIZES,
Benchmark,
@@ -98,6 +99,27 @@ def cumsum_args(dtype, batch, size):
bench.run()


def test_perf_nonzero():
def nonzero_args(dtype, batch, size):
if dtype == torch.bool:
inp = torch.randint(0, 2, [batch, size], dtype=torch.int, device="cuda").to(
torch.bool
)
else:
inp = torch.randint(0, 2, [batch, size], dtype=dtype, device="cuda")
return (inp,)

bench = Benchmark(
op_name="nonzero",
torch_op=torch.nonzero,
arg_func=nonzero_args,
dtypes=FLOAT_DTYPES + INT_DTYPES + [torch.bool],
batch=REDUCTION_BATCH,
sizes=SIZES,
)
bench.run()


def test_perf_groupnorm():
def group_norm_args(dtype, batch, size):
C = 16
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -129,6 +129,7 @@ def enable(lib=aten_lib):
lib.impl("index_select", index_select, "CUDA")
lib.impl("masked_fill", masked_fill, "CUDA")
lib.impl("_unique2", _unique2, "CUDA")
lib.impl("nonzero", nonzero, "CUDA")


class use_gems:
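For context, a hypothetical usage sketch (not part of the PR): once the CUDA registration above is enabled, calling `torch.nonzero` dispatches to the Triton implementation.

```python
import torch

import flag_gems

flag_gems.enable()  # installs the aten overrides, including "nonzero"

x = torch.randint(0, 2, (4, 5), dtype=torch.int32, device="cuda")
idx = torch.nonzero(x)  # routed to flag_gems.ops.nonzero
print(idx.shape)  # (number of nonzero elements, 2)
```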
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -51,6 +51,7 @@
from .mv import mv
from .ne import ne, ne_scalar
from .neg import neg
from .nonzero import nonzero
from .normal import (
normal_float_float,
normal_float_tensor,
@@ -201,4 +202,5 @@
"where_scalar_other",
"masked_fill",
"_unique2",
"nonzero",
]
222 changes: 177 additions & 45 deletions src/flag_gems/ops/cumsum.py
@@ -1,55 +1,185 @@
import logging
import math

import torch
import triton
import triton.language as tl

from ..utils import libentry


def heur_block_n(args):
return triton.next_power_of_2(args["N"])


@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 8}, num_warps=8),
triton.Config({"BLOCK_M": 16}, num_warps=8),
triton.Config({"BLOCK_M": 32}, num_warps=8),
],
key=[
"M",
"N",
],
)
@triton.heuristics(
{
"BLOCK_N": heur_block_n,
}
)
@triton.jit
def cumsum_kernel(

@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
inp,
out,
M,
N,
K,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
partial_sum,
n_elements,
part_num,
BLOCK_SIZE: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_k = tl.program_id(1)
m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
n_offset = tl.arange(0, BLOCK_N)
offset = m_offset[:, None, None] * N * K + n_offset[None, :, None] * K + pid_k
mask = m_offset[:, None, None] < M and n_offset[None, :, None] < N
pid = tl.program_id(0)
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < n_elements

inp_ptrs = inp + offset
inp_vals = tl.load(inp_ptrs, mask=mask).to(tl.float32)
result = tl.cumsum(inp_vals, axis=1)
inp_vals = tl.load(inp_ptrs, mask=mask)
if (
tl.constexpr(inp_vals.dtype.is_int64())
or tl.constexpr(inp_vals.dtype.is_uint64())
) or tl.constexpr(inp_vals.dtype.is_fp64()):
inp_vals = inp_vals
elif tl.constexpr(inp_vals.dtype.is_int()):
inp_vals = inp_vals.to(tl.int32)
else:
inp_vals = inp_vals.to(tl.float32)
result = tl.cumsum(inp_vals, axis=0)

part_sum_via_sum = tl.sum(inp_vals)

out_ptrs = out + offset
tl.store(out_ptrs, result, mask=mask)

partial_sum_ptrs = partial_sum + pid
tl.store(partial_sum_ptrs, part_sum_via_sum)


@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
out,
partial_sum,
n_elements,
part_num,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < n_elements

out_ptrs = out + offset
out_vals = tl.load(out_ptrs, mask=mask)

if pid > 0:
partial_sum_ptrs = partial_sum + pid - 1
last_part_sum_via_sum = tl.load(partial_sum_ptrs)

final_vals = out_vals + last_part_sum_via_sum
tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)


@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
inp,
out,
partial_sum,
B,
C,
part_num,
BLOCK_SIZE: tl.constexpr,
):
pid_a = tl.program_id(0)
pid_b = tl.program_id(1)
pid_c = tl.program_id(2)

a_idx = pid_a
b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
c_idx = pid_c

offset = a_idx * B * C + b_idx * C + c_idx
base_part_offset = a_idx * part_num * C + c_idx
part_offset = base_part_offset + pid_b * C

mask = b_idx < B
inp_ptrs = inp + offset
inp_vals = tl.load(inp_ptrs, mask=mask)
if (
tl.constexpr(inp_vals.dtype.is_int64())
or tl.constexpr(inp_vals.dtype.is_uint64())
) or tl.constexpr(inp_vals.dtype.is_fp64()):
inp_vals = inp_vals
elif tl.constexpr(inp_vals.dtype.is_int()):
inp_vals = inp_vals.to(tl.int32)
else:
inp_vals = inp_vals.to(tl.float32)
result = tl.cumsum(inp_vals, axis=0)

part_sum_via_sum = tl.sum(inp_vals)

out_ptrs = out + offset
tl.store(out_ptrs, result, mask=mask)

partial_sum_ptrs = partial_sum + part_offset
tl.store(partial_sum_ptrs, part_sum_via_sum)


@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
out,
partial_sum,
B,
C,
part_num,
BLOCK_SIZE: tl.constexpr,
):
pid_a = tl.program_id(0)
pid_b = tl.program_id(1)
pid_c = tl.program_id(2)

a_idx = pid_a
b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
c_idx = pid_c

base_offset = a_idx * B * C + c_idx
offset = base_offset + b_idx * C
base_part_offset = a_idx * part_num * C + c_idx
last_part_offset = base_part_offset + (pid_b - 1) * C

mask = b_idx < B
out_ptrs = out + offset
out_vals = tl.load(out_ptrs, mask=mask)

if pid_b > 0:
partial_sum_ptrs = partial_sum + last_part_offset
last_part_sum_via_sum = tl.load(partial_sum_ptrs)

final_vals = out_vals + last_part_sum_via_sum
tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)


def scan_then_fan_col(inp, out, n_ele, dtype):
# TODO(all): tune on target board
BLOCK_SIZE = 1024
if n_ele <= 1024 * 4:
BLOCK_SIZE = triton.next_power_of_2(n_ele)
part_num = math.ceil(n_ele / BLOCK_SIZE)
partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device)

grid = (part_num,)
with torch.cuda.device(inp.device):
scan_part_sum_kernel[grid](inp, out, partial_sum, n_ele, part_num, BLOCK_SIZE)

if part_num >= 2:
scan_then_fan_col(partial_sum, partial_sum, part_num, dtype)
with torch.cuda.device(inp.device):
add_base_sum_kernel[grid](out, partial_sum, n_ele, part_num, BLOCK_SIZE)


def scan_then_fan(inp, out, A, B, C, dtype):
# TODO(all): tune on target board
BLOCK_SIZE = 1024
if B <= 1024 * 4:
BLOCK_SIZE = triton.next_power_of_2(B)
part_num = math.ceil(B / BLOCK_SIZE)
partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)

grid = (A, part_num, C)
with torch.cuda.device(inp.device):
scan_part_sum_abc_kernel[grid](
inp, out, partial_sum, B, C, part_num, BLOCK_SIZE
)

if part_num >= 2:
scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype)
with torch.cuda.device(inp.device):
add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, BLOCK_SIZE)


def cumsum(inp, dim=1, *, dtype=None):
logging.debug("GEMS CUMSUM")
@@ -69,10 +199,12 @@ def cumsum(inp, dim=1, *, dtype=None):
dtype = torch.int64
out = torch.empty_like(inp, dtype=dtype)

grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_M"]),
K,
)
with torch.cuda.device(inp.device):
cumsum_kernel[grid](inp, out, M, N, K)
compute_dtype = out.dtype
if inp.dtype == torch.float16 or inp.dtype == torch.bfloat16:
compute_dtype = torch.float32

if M == 1 and K == 1:
scan_then_fan_col(inp, out, N, compute_dtype)
else:
scan_then_fan(inp, out, M, N, K, compute_dtype)
return out
79 changes: 79 additions & 0 deletions src/flag_gems/ops/nonzero.py
@@ -0,0 +1,79 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry


@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": k}, num_warps=w, num_stages=4)
for w in [4, 8, 16, 32]
for k in [256, 512, 1024, 2048, 4096, 8192]
],
key=[
"n_elements",
],
)
@triton.jit
def nonzero_kernel(
inp,
prefix_sum,
out,
n_elements,
shape,
ndim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)

offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < n_elements

inp_vals = tl.load(inp + offset, mask=mask)
out_offset = tl.load(prefix_sum + offset, mask=mask) - 1

nonzero_mask = mask and inp_vals == True # noqa

idx_flat = offset
for dim in range(ndim - 1, -1, -1):
dim_size = tl.load(shape + dim)
remainder = idx_flat % dim_size
idx_flat //= dim_size
tl.store(out + out_offset * ndim + dim, remainder, mask=nonzero_mask)


def nonzero(inp, *, as_tuple=False):
logging.debug("GEMS NONZERO")

inp_ndim = inp.ndim

inp = inp.contiguous()
n_elements = inp.numel()
inp_view = inp.view(n_elements)

shape = torch.tensor(inp.shape, dtype=torch.int32, device=inp.device)

inp_bool = inp_view
if inp_view.dtype != torch.bool:
inp_bool = inp_view != 0

prefix_sum = inp_bool.cumsum(axis=0)

num_nonzeros = n_elements
Collaborator:

Can it be " num_nonzeros = prefix_sum[-1] " here?

Contributor Author:

The original cumsum kernel uses only a single block along N, so performance drops sharply as the cases grow larger. Running the entire set of performance test cases with the original cumsum takes hours. Here is the performance of the first 4 cases for batch=1024 on an A100:

Operator cumsum Performance Test (torch.float16)
Size    Torch Latency (ms)     STF Latency (ms)   original Latency (ms)
----------------------------------------------------------------------
1024              0.077824            0.014336            0.018432
6144               0.15872            0.078848            0.270336
11264             0.282624            0.089088            0.733184
16384             0.406528            0.131072            0.804864
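
As a rough illustration of the scan-then-fan scheme used in the new cumsum above, here is a minimal host-side sketch with plain torch ops standing in for the Triton kernels; the block size and names are illustrative, and a 1-D input is assumed.

```python
import torch


def scan_then_fan_sketch(x: torch.Tensor, block: int = 1024) -> torch.Tensor:
    # Phase 1 ("scan"): split into independent blocks, scan each block
    # locally, and record every block's total.
    parts = torch.split(x, block)
    local_scans = [p.cumsum(0) for p in parts]
    block_totals = torch.stack([s[-1] for s in local_scans])

    # Scan the per-block totals to get each block's base offset.
    base = block_totals.cumsum(0)

    # Phase 2 ("fan"): add the preceding blocks' running total back onto
    # every block except the first.
    out = torch.cat(local_scans)
    n = x.numel()
    for i in range(1, len(parts)):
        out[i * block : min((i + 1) * block, n)] += base[i - 1]
    return out
```

In the kernels above the same idea is applied recursively to the tensor of partial sums whenever it spans more than one block.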

Contributor Author:

> Can it be " num_nonzeros = prefix_sum[-1] " here?

Yes, and the size should actually be `prefix_sum[-1]`. Here we are trading space for time: `prefix_sum[n_elements - 1].item()` triggers a device-to-host (d2h) copy, so we move it to the end so that the CUDA kernels are not interrupted. Here is the performance for batch=1024 on an A100:

Operator nonzero Performance Test (torch.bool)
Size        fun1 Latency (ms)   fun2 Latency (ms)
--------------------------------------------------
1024                  0.302656            0.304128
6144                  0.505504            0.530432
11264                 0.620224             0.64512
16384                 0.741536            0.769024
21504                 0.858496            0.884736
26624                  0.97568             1.00557
31744                  1.09635             1.12435
36864                  1.24224             1.26566
41984                   1.3944             1.40698
47104                  1.54192             1.56365
52224                  1.68931             1.70906
57344                  1.83971              1.8647
62464                  1.99014             2.01216
67584                  2.14531             2.16474
72704                  2.29331             2.31322
77824                  2.44122             2.46579
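
A minimal host-side sketch of the ordering described in this thread (illustrative only; `launch_nonzero_kernel` is a hypothetical stand-in for the Triton launch, and only the relative order of the allocation, the launch, and the single `.item()` sync matters):

```python
import torch


def nonzero_sketch(inp: torch.Tensor, launch_nonzero_kernel) -> torch.Tensor:
    flat = inp.contiguous().view(-1)
    inp_bool = flat if flat.dtype == torch.bool else flat != 0
    prefix_sum = inp_bool.cumsum(0)

    # Over-allocate n rows instead of reading prefix_sum[-1] now: calling
    # .item() here would force a device-to-host copy and stall the stream
    # before the kernel below could even be enqueued.
    n = flat.numel()
    out = torch.empty(n, inp.ndim, dtype=torch.int64, device=inp.device)

    launch_nonzero_kernel(inp_bool, prefix_sum, out)  # queued without a sync

    # The single device-to-host transfer happens last, after all kernels are
    # queued; the oversized buffer is then sliced to the exact count.
    num_nonzeros = prefix_sum[-1].item()
    return out[:num_nonzeros]
```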

out = torch.empty(num_nonzeros, inp_ndim, dtype=torch.int64, device=inp.device)

grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
with torch.cuda.device(inp.device):
nonzero_kernel[grid](inp_bool, prefix_sum, out, n_elements, shape, inp_ndim)

num_nonzeros = prefix_sum[n_elements - 1].item()
out = out[0:num_nonzeros]

if as_tuple:
return torch.unbind(out, dim=0)
else:
return out