[Operator] Add nonzero & optimize cumsum
WuKe authored and wuke1993 committed Sep 2, 2024
1 parent adb2094 commit c055a2d
Showing 7 changed files with 322 additions and 49 deletions.
22 changes: 22 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -3,6 +3,7 @@
from .performance_utils import (
BLAS_BATCH,
FLOAT_DTYPES,
INT_DTYPES,
REDUCTION_BATCH,
SIZES,
Benchmark,
@@ -98,6 +99,27 @@ def cumsum_args(dtype, batch, size):
bench.run()


def test_perf_nonzero():
def nonzero_args(dtype, batch, size):
if dtype == torch.bool:
inp = torch.randint(0, 2, [batch, size], dtype=torch.int, device="cuda").to(
torch.bool
)
else:
inp = torch.randint(0, 2, [batch, size], dtype=dtype, device="cuda")
return (inp,)

bench = Benchmark(
op_name="nonzero",
torch_op=torch.nonzero,
arg_func=nonzero_args,
dtypes=FLOAT_DTYPES + INT_DTYPES + [torch.bool],
batch=REDUCTION_BATCH,
sizes=SIZES,
)
bench.run()


def test_perf_groupnorm():
def group_norm_args(dtype, batch, size):
C = 16
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -129,6 +129,7 @@ def enable(lib=aten_lib):
lib.impl("index_select", index_select, "CUDA")
lib.impl("masked_fill", masked_fill, "CUDA")
lib.impl("_unique2", _unique2, "CUDA")
lib.impl("nonzero", nonzero, "CUDA")


class use_gems:
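
For context, a minimal usage sketch (not part of this commit) of how the newly registered op is reached once the ATen overrides are in place; it assumes a CUDA device and the flag_gems.enable() / use_gems API shown above, with use_gems used as a context manager:

import torch
import flag_gems

flag_gems.enable()  # route supported aten ops, now including nonzero, to these Triton kernels

x = torch.randint(0, 2, (4, 5), dtype=torch.int32, device="cuda")
print(torch.nonzero(x))  # dispatches to src/flag_gems/ops/nonzero.py

with flag_gems.use_gems():  # or scope the override with the context manager
    y = torch.cumsum(torch.randn(8, 16, device="cuda"), dim=1)
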
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -51,6 +51,7 @@
from .mv import mv
from .ne import ne, ne_scalar
from .neg import neg
from .nonzero import nonzero
from .normal import (
normal_float_float,
normal_float_tensor,
@@ -201,4 +202,5 @@
"where_scalar_other",
"masked_fill",
"_unique2",
"nonzero",
]
222 changes: 177 additions & 45 deletions src/flag_gems/ops/cumsum.py
@@ -1,55 +1,185 @@
import logging
import math

import torch
import triton
import triton.language as tl

from ..utils import libentry


- def heur_block_n(args):
- return triton.next_power_of_2(args["N"])


- @libentry()
- @triton.autotune(
- configs=[
- triton.Config({"BLOCK_M": 8}, num_warps=8),
- triton.Config({"BLOCK_M": 16}, num_warps=8),
- triton.Config({"BLOCK_M": 32}, num_warps=8),
- ],
- key=[
- "M",
- "N",
- ],
- )
- @triton.heuristics(
- {
- "BLOCK_N": heur_block_n,
- }
- )
- @triton.jit
- def cumsum_kernel(

@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
inp,
out,
- M,
- N,
- K,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
partial_sum,
n_elements,
part_num,
BLOCK_SIZE: tl.constexpr,
):
- pid_m = tl.program_id(0)
- pid_k = tl.program_id(1)
- m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
- n_offset = tl.arange(0, BLOCK_N)
- offset = m_offset[:, None, None] * N * K + n_offset[None, :, None] * K + pid_k
- mask = m_offset[:, None, None] < M and n_offset[None, :, None] < N
pid = tl.program_id(0)
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < n_elements

inp_ptrs = inp + offset
- inp_vals = tl.load(inp_ptrs, mask=mask).to(tl.float32)
- result = tl.cumsum(inp_vals, axis=1)
inp_vals = tl.load(inp_ptrs, mask=mask)
if (
tl.constexpr(inp_vals.dtype.is_int64())
or tl.constexpr(inp_vals.dtype.is_uint64())
) or tl.constexpr(inp_vals.dtype.is_fp64()):
inp_vals = inp_vals
elif tl.constexpr(inp_vals.dtype.is_int()):
inp_vals = inp_vals.to(tl.int32)
else:
inp_vals = inp_vals.to(tl.float32)
result = tl.cumsum(inp_vals, axis=0)

part_sum_via_sum = tl.sum(inp_vals)

out_ptrs = out + offset
tl.store(out_ptrs, result, mask=mask)

partial_sum_ptrs = partial_sum + pid
tl.store(partial_sum_ptrs, part_sum_via_sum)


@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
out,
partial_sum,
n_elements,
part_num,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < n_elements

out_ptrs = out + offset
out_vals = tl.load(out_ptrs, mask=mask)

if pid > 0:
partial_sum_ptrs = partial_sum + pid - 1
last_part_sum_via_sum = tl.load(partial_sum_ptrs)

final_vals = out_vals + last_part_sum_via_sum
tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)


@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
inp,
out,
partial_sum,
B,
C,
part_num,
BLOCK_SIZE: tl.constexpr,
):
pid_a = tl.program_id(0)
pid_b = tl.program_id(1)
pid_c = tl.program_id(2)

a_idx = pid_a
b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
c_idx = pid_c

offset = a_idx * B * C + b_idx * C + c_idx
base_part_offset = a_idx * part_num * C + c_idx
part_offset = base_part_offset + pid_b * C

mask = b_idx < B
inp_ptrs = inp + offset
inp_vals = tl.load(inp_ptrs, mask=mask)
if (
tl.constexpr(inp_vals.dtype.is_int64())
or tl.constexpr(inp_vals.dtype.is_uint64())
) or tl.constexpr(inp_vals.dtype.is_fp64()):
inp_vals = inp_vals
elif tl.constexpr(inp_vals.dtype.is_int()):
inp_vals = inp_vals.to(tl.int32)
else:
inp_vals = inp_vals.to(tl.float32)
result = tl.cumsum(inp_vals, axis=0)

part_sum_via_sum = tl.sum(inp_vals)

out_ptrs = out + offset
tl.store(out_ptrs, result, mask=mask)

partial_sum_ptrs = partial_sum + part_offset
tl.store(partial_sum_ptrs, part_sum_via_sum)


@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
out,
partial_sum,
B,
C,
part_num,
BLOCK_SIZE: tl.constexpr,
):
pid_a = tl.program_id(0)
pid_b = tl.program_id(1)
pid_c = tl.program_id(2)

a_idx = pid_a
b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
c_idx = pid_c

base_offset = a_idx * B * C + c_idx
offset = base_offset + b_idx * C
base_part_offset = a_idx * part_num * C + c_idx
last_part_offset = base_part_offset + (pid_b - 1) * C

mask = b_idx < B
out_ptrs = out + offset
out_vals = tl.load(out_ptrs, mask=mask)

if pid_b > 0:
partial_sum_ptrs = partial_sum + last_part_offset
last_part_sum_via_sum = tl.load(partial_sum_ptrs)

final_vals = out_vals + last_part_sum_via_sum
tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)


def scan_then_fan_col(inp, out, n_ele, dtype):
# TODO(all): tune on target board
BLOCK_SIZE = 1024
if n_ele <= 1024 * 4:
BLOCK_SIZE = triton.next_power_of_2(n_ele)
part_num = math.ceil(n_ele / BLOCK_SIZE)
partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device)

grid = (part_num,)
with torch.cuda.device(inp.device):
scan_part_sum_kernel[grid](inp, out, partial_sum, n_ele, part_num, BLOCK_SIZE)

if part_num >= 2:
scan_then_fan_col(partial_sum, partial_sum, part_num, dtype)
with torch.cuda.device(inp.device):
add_base_sum_kernel[grid](out, partial_sum, n_ele, part_num, BLOCK_SIZE)


def scan_then_fan(inp, out, A, B, C, dtype):
# TODO(all): tune on target board
BLOCK_SIZE = 1024
if B <= 1024 * 4:
BLOCK_SIZE = triton.next_power_of_2(B)
part_num = math.ceil(B / BLOCK_SIZE)
partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)

grid = (A, part_num, C)
with torch.cuda.device(inp.device):
scan_part_sum_abc_kernel[grid](
inp, out, partial_sum, B, C, part_num, BLOCK_SIZE
)

if part_num >= 2:
scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype)
with torch.cuda.device(inp.device):
add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, BLOCK_SIZE)


def cumsum(inp, dim=1, *, dtype=None):
logging.debug("GEMS CUMSUM")
@@ -69,10 +199,12 @@ def cumsum(inp, dim=1, *, dtype=None):
dtype = torch.int64
out = torch.empty_like(inp, dtype=dtype)

- grid = lambda meta: (
- triton.cdiv(M, meta["BLOCK_M"]),
- K,
- )
- with torch.cuda.device(inp.device):
- cumsum_kernel[grid](inp, out, M, N, K)
compute_dtype = out.dtype
if inp.dtype == torch.float16 or inp.dtype == torch.bfloat16:
compute_dtype = torch.float32

if M == 1 and K == 1:
scan_then_fan_col(inp, out, N, compute_dtype)
else:
scan_then_fan(inp, out, M, N, K, compute_dtype)
return out
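
As an aside (not part of the commit), the new kernels implement a two-pass scan-then-fan: each block writes a local cumsum and its block total, the block totals are themselves scanned (recursively, in scan_then_fan_col), and every block after the first adds the carry from the preceding blocks. A minimal PyTorch sketch of the same idea, using a hypothetical scan_then_fan_reference helper and a single cumsum over the block totals in place of the recursive pass:

import torch

def scan_then_fan_reference(x: torch.Tensor, block: int = 4) -> torch.Tensor:
    # 1-D reference for scan_part_sum_kernel + add_base_sum_kernel
    out = torch.empty_like(x)
    part_sums = []
    for start in range(0, x.numel(), block):
        chunk = x[start:start + block]
        out[start:start + block] = torch.cumsum(chunk, dim=0)  # local scan per block
        part_sums.append(chunk.sum())
    carry = torch.cumsum(torch.stack(part_sums), dim=0)  # scan of the block totals
    for i, start in enumerate(range(0, x.numel(), block)):
        if i > 0:
            out[start:start + block] += carry[i - 1]  # fan the carry back in
    return out

x = torch.arange(10, dtype=torch.float32)
assert torch.allclose(scan_then_fan_reference(x), torch.cumsum(x, dim=0))
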
79 changes: 79 additions & 0 deletions src/flag_gems/ops/nonzero.py
@@ -0,0 +1,79 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry


@libentry()
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": k}, num_warps=w, num_stages=4)
for w in [4, 8, 16, 32]
for k in [256, 512, 1024, 2048, 4096, 8192]
],
key=[
"n_elements",
],
)
@triton.jit
def nonzero_kernel(
inp,
prefix_sum,
out,
n_elements,
shape,
ndim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)

offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < n_elements

inp_vals = tl.load(inp + offset, mask=mask)
out_offset = tl.load(prefix_sum + offset, mask=mask) - 1

nonzero_mask = mask and inp_vals == True # noqa

idx_flat = offset
for dim in range(ndim - 1, -1, -1):
dim_size = tl.load(shape + dim)
remainder = idx_flat % dim_size
idx_flat //= dim_size
tl.store(out + out_offset * ndim + dim, remainder, mask=nonzero_mask)


def nonzero(inp, *, as_tuple=False):
logging.debug("GEMS NONZERO")

inp_ndim = inp.ndim

inp = inp.contiguous()
n_elements = inp.numel()
inp_view = inp.view(n_elements)

shape = torch.tensor(inp.shape, dtype=torch.int32, device=inp.device)

inp_bool = inp_view
if inp_view.dtype != torch.bool:
inp_bool = inp_view != 0

prefix_sum = inp_bool.cumsum(axis=0)

num_nonzeros = n_elements
out = torch.empty(num_nonzeros, inp_ndim, dtype=torch.int64, device=inp.device)

grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
with torch.cuda.device(inp.device):
nonzero_kernel[grid](inp_bool, prefix_sum, out, n_elements, shape, inp_ndim)

num_nonzeros = prefix_sum[n_elements - 1].item()
out = out[0:num_nonzeros]

if as_tuple:
return torch.unbind(out, dim=1)  # one 1-D index tensor per input dimension, matching torch.nonzero(as_tuple=True)
else:
return out
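
As an aside (not part of the commit), the kernel's indexing scheme in plain PyTorch: the inclusive prefix sum over the boolean mask gives each nonzero element its 1-based output row, and the flat index is unravelled into one coordinate per dimension, innermost first. A hypothetical nonzero_reference helper illustrating this:

import torch

def nonzero_reference(inp: torch.Tensor) -> torch.Tensor:
    flat = inp.reshape(-1) != 0
    prefix_sum = flat.cumsum(dim=0)  # inclusive scan, as in the wrapper above
    out = torch.empty(int(prefix_sum[-1]), inp.ndim, dtype=torch.int64)
    for flat_idx in range(flat.numel()):
        if not flat[flat_idx]:
            continue
        row = int(prefix_sum[flat_idx]) - 1  # 1-based running count -> 0-based output row
        idx = flat_idx
        for dim in range(inp.ndim - 1, -1, -1):  # unravel flat index, innermost dim first
            out[row, dim] = idx % inp.shape[dim]
            idx //= inp.shape[dim]
    return out

x = torch.tensor([[0, 3], [5, 0]])
assert torch.equal(nonzero_reference(x), torch.nonzero(x))
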