From 9f0ccf4bf05b80016ffbc97c6cff25e5c55ca748 Mon Sep 17 00:00:00 2001 From: WuKe Date: Fri, 23 Aug 2024 15:41:23 +0800 Subject: [PATCH] [Operator] Add nonzero op --- benchmark/test_reduction_perf.py | 21 ++++ src/flag_gems/__init__.py | 1 + src/flag_gems/ops/__init__.py | 2 + src/flag_gems/ops/cumsum.py | 194 ++++++++++++++++++++++++------- src/flag_gems/ops/nonzero.py | 79 +++++++++++++ tests/accuracy_utils.py | 2 + tests/test_reduction_ops.py | 16 +++ 7 files changed, 272 insertions(+), 43 deletions(-) create mode 100644 src/flag_gems/ops/nonzero.py diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py index c38c5235..46a961b0 100644 --- a/benchmark/test_reduction_perf.py +++ b/benchmark/test_reduction_perf.py @@ -98,6 +98,27 @@ def cumsum_args(dtype, batch, size): bench.run() +def test_perf_nonzero(): + def nonzero_args(dtype, batch, size): + if dtype == torch.bool: + inp = torch.randint(0, 2, [batch, size], dtype=torch.int, device="cuda").to( + torch.bool + ) + else: + inp = torch.randint(0, 2, [batch, size], dtype=dtype, device="cuda") + return (inp,) + + bench = Benchmark( + op_name="nonzero", + torch_op=torch.nonzero, + arg_func=nonzero_args, + dtypes=FLOAT_DTYPES + [torch.bool], + batch=REDUCTION_BATCH, + sizes=SIZES, + ) + bench.run() + + def test_perf_groupnorm(): def group_norm_args(dtype, batch, size): C = 16 diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 9bd0c698..71cd47f7 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -129,6 +129,7 @@ def enable(lib=aten_lib): lib.impl("index_select", index_select, "CUDA") lib.impl("masked_fill", masked_fill, "CUDA") lib.impl("_unique2", _unique2, "CUDA") + lib.impl("nonzero", nonzero, "CUDA") class use_gems: diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 208208b0..4d944044 100644 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -51,6 +51,7 @@ from .mv import mv from .ne import ne, ne_scalar from .neg import neg +from .nonzero import nonzero from .normal import ( normal_float_float, normal_float_tensor, @@ -201,4 +202,5 @@ "where_scalar_other", "masked_fill", "_unique2", + "nonzero", ] diff --git a/src/flag_gems/ops/cumsum.py b/src/flag_gems/ops/cumsum.py index 73ae4f06..f76efb43 100644 --- a/src/flag_gems/ops/cumsum.py +++ b/src/flag_gems/ops/cumsum.py @@ -1,55 +1,165 @@ import logging +import math import torch import triton import triton.language as tl -from ..utils import libentry - - -def heur_block_n(args): - return triton.next_power_of_2(args["N"]) - - -@libentry() -@triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 8}, num_warps=8), - triton.Config({"BLOCK_M": 16}, num_warps=8), - triton.Config({"BLOCK_M": 32}, num_warps=8), - ], - key=[ - "M", - "N", - ], -) -@triton.heuristics( - { - "BLOCK_N": heur_block_n, - } -) + +@triton.jit +def scan_part_sum_kernel( + inp, + out, + partial_sum, + n_elements, + part_num, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offset < n_elements + + inp_ptrs = inp + offset + inp_vals = tl.load(inp_ptrs, mask=mask).to(tl.float32) + result = tl.cumsum(inp_vals, axis=0) + + part_sum_via_sum = tl.sum(inp_vals) + + out_ptrs = out + offset + tl.store(out_ptrs, result, mask=mask) + + partial_sum_ptrs = partial_sum + pid + tl.store(partial_sum_ptrs, part_sum_via_sum) + + @triton.jit -def cumsum_kernel( +def add_base_sum_kernel( + out, + partial_sum, + n_elements, + part_num, + BLOCK_SIZE: 
tl.constexpr, +): + pid = tl.program_id(0) + offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offset < n_elements + + out_ptrs = out + offset + out_vals = tl.load(out_ptrs, mask=mask) + + if pid > 0: + partial_sum_ptrs = partial_sum + pid - 1 + last_part_sum_via_sum = tl.load(partial_sum_ptrs) + + out_vals += last_part_sum_via_sum + tl.store(out_ptrs, out_vals, mask=mask) + + +@triton.jit +def scan_part_sum_abc_kernel( inp, out, - M, - N, - K, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, + partial_sum, + B, + C, + part_num, + BLOCK_SIZE: tl.constexpr, ): - pid_m = tl.program_id(0) - pid_k = tl.program_id(1) - m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - n_offset = tl.arange(0, BLOCK_N) - offset = m_offset[:, None, None] * N * K + n_offset[None, :, None] * K + pid_k - mask = m_offset[:, None, None] < M and n_offset[None, :, None] < N + pid_a = tl.program_id(0) + pid_b = tl.program_id(1) + pid_c = tl.program_id(2) + + a_idx = pid_a + b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + c_idx = pid_c + + offset = a_idx * B * C + b_idx * C + c_idx + base_part_offset = a_idx * part_num * C + c_idx + part_offset = base_part_offset + pid_b * C + + mask = b_idx < B inp_ptrs = inp + offset inp_vals = tl.load(inp_ptrs, mask=mask).to(tl.float32) - result = tl.cumsum(inp_vals, axis=1) + result = tl.cumsum(inp_vals, axis=0) + + part_sum_via_sum = tl.sum(inp_vals) + out_ptrs = out + offset tl.store(out_ptrs, result, mask=mask) + partial_sum_ptrs = partial_sum + part_offset + tl.store(partial_sum_ptrs, part_sum_via_sum) + + +@triton.jit +def add_base_sum_abc_kernel( + out, + partial_sum, + B, + C, + part_num, + BLOCK_SIZE: tl.constexpr, +): + pid_a = tl.program_id(0) + pid_b = tl.program_id(1) + pid_c = tl.program_id(2) + + a_idx = pid_a + b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + c_idx = pid_c + + base_offset = a_idx * B * C + c_idx + offset = base_offset + b_idx * C + base_part_offset = a_idx * part_num * C + c_idx + last_part_offset = base_part_offset + (pid_b - 1) * C + + mask = b_idx < B + out_ptrs = out + offset + out_vals = tl.load(out_ptrs, mask=mask) + + if pid_b > 0: + partial_sum_ptrs = partial_sum + last_part_offset + last_part_sum_via_sum = tl.load(partial_sum_ptrs) + + out_vals += last_part_sum_via_sum + tl.store(out_ptrs, out_vals, mask=mask) + + +def scan_then_fan_col(inp, out, n_ele, dtype): + # TODO(all): tune on target board + BLOCK_SIZE = 1024 + if n_ele <= 1024 * 4: + BLOCK_SIZE = triton.next_power_of_2(n_ele) + part_num = math.ceil(n_ele / BLOCK_SIZE) + partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device) + + grid = (part_num,) + with torch.cuda.device(inp.device): + scan_part_sum_kernel[grid](inp, out, partial_sum, n_ele, part_num, BLOCK_SIZE) + + if part_num >= 2: + scan_then_fan_col(partial_sum, partial_sum, part_num, dtype) + with torch.cuda.device(inp.device): + add_base_sum_kernel[grid](out, partial_sum, n_ele, part_num, BLOCK_SIZE) + + +def scan_then_fan(inp, out, A, B, C, dtype): + # TODO(all): tune on target board + BLOCK_SIZE = 1024 + part_num = math.ceil(B / BLOCK_SIZE) + partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device) + + grid = (A, part_num, C) + with torch.cuda.device(inp.device): + scan_part_sum_abc_kernel[grid]( + inp, out, partial_sum, B, C, part_num, BLOCK_SIZE + ) + + if part_num >= 2: + scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype) + with torch.cuda.device(inp.device): + add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, BLOCK_SIZE) + def cumsum(inp, 
dim=1, *, dtype=None):
     logging.debug("GEMS CUMSUM")
 
@@ -69,10 +179,8 @@ def cumsum(inp, dim=1, *, dtype=None):
             dtype = torch.int64
     out = torch.empty_like(inp, dtype=dtype)
 
-    grid = lambda meta: (
-        triton.cdiv(M, meta["BLOCK_M"]),
-        K,
-    )
-    with torch.cuda.device(inp.device):
-        cumsum_kernel[grid](inp, out, M, N, K)
+    if M == 1 and K == 1:
+        scan_then_fan_col(inp, out, N, out.dtype)
+    else:
+        scan_then_fan(inp, out, M, N, K, out.dtype)
     return out
diff --git a/src/flag_gems/ops/nonzero.py b/src/flag_gems/ops/nonzero.py
new file mode 100644
index 00000000..9a3cbc71
--- /dev/null
+++ b/src/flag_gems/ops/nonzero.py
@@ -0,0 +1,79 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from ..utils import libentry
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_SIZE": k}, num_warps=w, num_stages=4)
+        for w in [4, 8, 16, 32]
+        for k in [256, 512, 1024, 2048, 4096, 8192]
+    ],
+    key=[
+        "n_elements",
+    ],
+)
+@triton.jit
+def nonzero_kernel(
+    inp,
+    prefix_sum,
+    out,
+    n_elements,
+    shape,
+    ndim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offset < n_elements
+
+    inp_vals = tl.load(inp + offset, mask=mask)
+    out_offset = tl.load(prefix_sum + offset, mask=mask) - 1  # 0-based output row
+
+    nonzero_mask = mask and inp_vals == True  # noqa
+
+    idx_flat = offset  # unravel flat index into per-dim indices (row-major)
+    for dim in range(ndim - 1, -1, -1):
+        dim_size = tl.load(shape + dim)
+        remainder = idx_flat % dim_size
+        idx_flat //= dim_size
+        tl.store(out + out_offset * ndim + dim, remainder, mask=nonzero_mask)
+
+
+def nonzero(inp, *, as_tuple=False):
+    logging.debug("GEMS NONZERO")
+
+    inp_ndim = inp.ndim
+
+    inp = inp.contiguous()
+    n_elements = inp.numel()
+    inp_view = inp.view(n_elements)
+
+    shape = torch.tensor(inp.shape, dtype=torch.int32, device=inp.device)
+
+    inp_bool = inp_view
+    if inp_view.dtype != torch.bool:
+        inp_bool = inp_view != 0
+
+    prefix_sum = inp_bool.cumsum(axis=0)
+
+    num_nonzeros = n_elements  # allocate for the worst case; trimmed below
+    out = torch.empty(num_nonzeros, inp_ndim, dtype=torch.int64, device=inp.device)
+
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+    with torch.cuda.device(inp.device):
+        nonzero_kernel[grid](inp_bool, prefix_sum, out, n_elements, shape, inp_ndim)
+
+    num_nonzeros = prefix_sum[n_elements - 1].item()
+    out = out[0:num_nonzeros]
+
+    if as_tuple:
+        return torch.unbind(out, dim=1)  # one index tensor per input dimension
+    else:
+        return out
diff --git a/tests/accuracy_utils.py b/tests/accuracy_utils.py
index 9b33b9d7..970f2f18 100644
--- a/tests/accuracy_utils.py
+++ b/tests/accuracy_utils.py
@@ -52,6 +52,8 @@ def gems_assert_close(a, b, dtype, equal_nan=False, reduce_dim=1):
     a = a.to("cpu")
     b = b.to(dtype)
     atol = 1e-4 * reduce_dim
+    if dtype == torch.bfloat16:
+        atol = 1e-3 * reduce_dim
     rtol = RESOLUTION[dtype]
     torch.testing.assert_close(a, b, atol=atol, rtol=rtol, equal_nan=equal_nan)
 
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index 6c9c7ee7..080b6b84 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -250,6 +250,22 @@ def test_accuracy_cumsum(shape, dtype):
     gems_assert_close(res_out, ref_out, dtype, reduce_dim=shape[dim])
 
 
+@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES + [torch.bool])
+def test_accuracy_nonzero(shape, dtype):
+    if dtype == torch.bool:
+        inp = torch.randint(0, 2, shape, dtype=torch.int, device="cuda").to(torch.bool)
+    else:
+        inp = torch.randint(0, 2, shape, dtype=dtype, device="cuda")
+    ref_inp = to_reference(inp, False)
+
+    ref_out = torch.nonzero(ref_inp)
+    with flag_gems.use_gems():
+        res_out = torch.nonzero(inp)
+
+    gems_assert_equal(res_out, ref_out)
+
+
 @pytest.mark.parametrize(
     "N, C, H, W, num_groups",
     [