diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py
index c38c5235..10b60a22 100644
--- a/benchmark/test_reduction_perf.py
+++ b/benchmark/test_reduction_perf.py
@@ -3,6 +3,7 @@
 from .performance_utils import (
     BLAS_BATCH,
     FLOAT_DTYPES,
+    INT_DTYPES,
     REDUCTION_BATCH,
     SIZES,
     Benchmark,
@@ -98,6 +99,27 @@ def cumsum_args(dtype, batch, size):
     bench.run()
 
 
+def test_perf_nonzero():
+    def nonzero_args(dtype, batch, size):
+        if dtype == torch.bool:
+            inp = torch.randint(0, 2, [batch, size], dtype=torch.int, device="cuda").to(
+                torch.bool
+            )
+        else:
+            inp = torch.randint(0, 2, [batch, size], dtype=dtype, device="cuda")
+        return (inp,)
+
+    bench = Benchmark(
+        op_name="nonzero",
+        torch_op=torch.nonzero,
+        arg_func=nonzero_args,
+        dtypes=FLOAT_DTYPES + INT_DTYPES + [torch.bool],
+        batch=REDUCTION_BATCH,
+        sizes=SIZES,
+    )
+    bench.run()
+
+
 def test_perf_groupnorm():
     def group_norm_args(dtype, batch, size):
         C = 16
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index 9bd0c698..71cd47f7 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -129,6 +129,7 @@ def enable(lib=aten_lib):
     lib.impl("index_select", index_select, "CUDA")
     lib.impl("masked_fill", masked_fill, "CUDA")
     lib.impl("_unique2", _unique2, "CUDA")
+    lib.impl("nonzero", nonzero, "CUDA")
 
 
 class use_gems:
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 208208b0..4d944044 100644
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -51,6 +51,7 @@
 from .mv import mv
 from .ne import ne, ne_scalar
 from .neg import neg
+from .nonzero import nonzero
 from .normal import (
     normal_float_float,
     normal_float_tensor,
@@ -201,4 +202,5 @@
     "where_scalar_other",
     "masked_fill",
     "_unique2",
+    "nonzero",
 ]
diff --git a/src/flag_gems/ops/cumsum.py b/src/flag_gems/ops/cumsum.py
index 73ae4f06..ea407da7 100644
--- a/src/flag_gems/ops/cumsum.py
+++ b/src/flag_gems/ops/cumsum.py
@@ -1,55 +1,185 @@
 import logging
+import math
 
 import torch
 import triton
 import triton.language as tl
 
-from ..utils import libentry
-
-
-def heur_block_n(args):
-    return triton.next_power_of_2(args["N"])
-
-
-@libentry()
-@triton.autotune(
-    configs=[
-        triton.Config({"BLOCK_M": 8}, num_warps=8),
-        triton.Config({"BLOCK_M": 16}, num_warps=8),
-        triton.Config({"BLOCK_M": 32}, num_warps=8),
-    ],
-    key=[
-        "M",
-        "N",
-    ],
-)
-@triton.heuristics(
-    {
-        "BLOCK_N": heur_block_n,
-    }
-)
-@triton.jit
-def cumsum_kernel(
+
+@triton.jit(do_not_specialize=["n_elements", "part_num"])
+def scan_part_sum_kernel(
     inp,
     out,
-    M,
-    N,
-    K,
-    BLOCK_M: tl.constexpr,
-    BLOCK_N: tl.constexpr,
+    partial_sum,
+    n_elements,
+    part_num,
+    BLOCK_SIZE: tl.constexpr,
 ):
-    pid_m = tl.program_id(0)
-    pid_k = tl.program_id(1)
-    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    n_offset = tl.arange(0, BLOCK_N)
-    offset = m_offset[:, None, None] * N * K + n_offset[None, :, None] * K + pid_k
-    mask = m_offset[:, None, None] < M and n_offset[None, :, None] < N
+    pid = tl.program_id(0)
+    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offset < n_elements
+
     inp_ptrs = inp + offset
-    inp_vals = tl.load(inp_ptrs, mask=mask).to(tl.float32)
-    result = tl.cumsum(inp_vals, axis=1)
+    inp_vals = tl.load(inp_ptrs, mask=mask)
+    if (
+        tl.constexpr(inp_vals.dtype.is_int64())
+        or tl.constexpr(inp_vals.dtype.is_uint64())
+    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
+        inp_vals = inp_vals
+    elif tl.constexpr(inp_vals.dtype.is_int()):
+        inp_vals = inp_vals.to(tl.int32)
+    else:
+        inp_vals = inp_vals.to(tl.float32)
+    result = tl.cumsum(inp_vals, axis=0)
+
+    part_sum_via_sum = tl.sum(inp_vals)
+
     out_ptrs = out + offset
     tl.store(out_ptrs, result, mask=mask)
 
+    partial_sum_ptrs = partial_sum + pid
+    tl.store(partial_sum_ptrs, part_sum_via_sum)
+
+
+@triton.jit(do_not_specialize=["n_elements", "part_num"])
+def add_base_sum_kernel(
+    out,
+    partial_sum,
+    n_elements,
+    part_num,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offset < n_elements
+
+    out_ptrs = out + offset
+    out_vals = tl.load(out_ptrs, mask=mask)
+
+    if pid > 0:
+        partial_sum_ptrs = partial_sum + pid - 1
+        last_part_sum_via_sum = tl.load(partial_sum_ptrs)
+
+        final_vals = out_vals + last_part_sum_via_sum
+        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
+
+
+@triton.jit(do_not_specialize=["part_num"])
+def scan_part_sum_abc_kernel(
+    inp,
+    out,
+    partial_sum,
+    B,
+    C,
+    part_num,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_a = tl.program_id(0)
+    pid_b = tl.program_id(1)
+    pid_c = tl.program_id(2)
+
+    a_idx = pid_a
+    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    c_idx = pid_c
+
+    offset = a_idx * B * C + b_idx * C + c_idx
+    base_part_offset = a_idx * part_num * C + c_idx
+    part_offset = base_part_offset + pid_b * C
+
+    mask = b_idx < B
+    inp_ptrs = inp + offset
+    inp_vals = tl.load(inp_ptrs, mask=mask)
+    if (
+        tl.constexpr(inp_vals.dtype.is_int64())
+        or tl.constexpr(inp_vals.dtype.is_uint64())
+    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
+        inp_vals = inp_vals
+    elif tl.constexpr(inp_vals.dtype.is_int()):
+        inp_vals = inp_vals.to(tl.int32)
+    else:
+        inp_vals = inp_vals.to(tl.float32)
+    result = tl.cumsum(inp_vals, axis=0)
+
+    part_sum_via_sum = tl.sum(inp_vals)
+
+    out_ptrs = out + offset
+    tl.store(out_ptrs, result, mask=mask)
+
+    partial_sum_ptrs = partial_sum + part_offset
+    tl.store(partial_sum_ptrs, part_sum_via_sum)
+
+
+@triton.jit(do_not_specialize=["part_num"])
+def add_base_sum_abc_kernel(
+    out,
+    partial_sum,
+    B,
+    C,
+    part_num,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_a = tl.program_id(0)
+    pid_b = tl.program_id(1)
+    pid_c = tl.program_id(2)
+
+    a_idx = pid_a
+    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    c_idx = pid_c
+
+    base_offset = a_idx * B * C + c_idx
+    offset = base_offset + b_idx * C
+    base_part_offset = a_idx * part_num * C + c_idx
+    last_part_offset = base_part_offset + (pid_b - 1) * C
+
+    mask = b_idx < B
+    out_ptrs = out + offset
+    out_vals = tl.load(out_ptrs, mask=mask)
+
+    if pid_b > 0:
+        partial_sum_ptrs = partial_sum + last_part_offset
+        last_part_sum_via_sum = tl.load(partial_sum_ptrs)
+
+        final_vals = out_vals + last_part_sum_via_sum
+        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
+
+
+def scan_then_fan_col(inp, out, n_ele, dtype):
+    # TODO(all): tune on target board
+    BLOCK_SIZE = 1024
+    if n_ele <= 1024 * 4:
+        BLOCK_SIZE = triton.next_power_of_2(n_ele)
+    part_num = math.ceil(n_ele / BLOCK_SIZE)
+    partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device)
+
+    grid = (part_num,)
+    with torch.cuda.device(inp.device):
+        scan_part_sum_kernel[grid](inp, out, partial_sum, n_ele, part_num, BLOCK_SIZE)
+
+    if part_num >= 2:
+        scan_then_fan_col(partial_sum, partial_sum, part_num, dtype)
+        with torch.cuda.device(inp.device):
+            add_base_sum_kernel[grid](out, partial_sum, n_ele, part_num, BLOCK_SIZE)
+
+
+def scan_then_fan(inp, out, A, B, C, dtype):
+    # TODO(all): tune on target board
+    BLOCK_SIZE = 1024
+    if B <= 1024 * 4:
+        BLOCK_SIZE = triton.next_power_of_2(B)
+    part_num = math.ceil(B / BLOCK_SIZE)
+    partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)
+
+    grid = (A, part_num, C)
+    with torch.cuda.device(inp.device):
+        scan_part_sum_abc_kernel[grid](
+            inp, out, partial_sum, B, C, part_num, BLOCK_SIZE
+        )
+
+    if part_num >= 2:
+        scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype)
+        with torch.cuda.device(inp.device):
+            add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, BLOCK_SIZE)
+
 
 def cumsum(inp, dim=1, *, dtype=None):
     logging.debug("GEMS CUMSUM")
@@ -69,10 +199,12 @@ def cumsum(inp, dim=1, *, dtype=None):
         dtype = torch.int64
     out = torch.empty_like(inp, dtype=dtype)
 
-    grid = lambda meta: (
-        triton.cdiv(M, meta["BLOCK_M"]),
-        K,
-    )
-    with torch.cuda.device(inp.device):
-        cumsum_kernel[grid](inp, out, M, N, K)
+    compute_dtype = out.dtype
+    if inp.dtype == torch.float16 or inp.dtype == torch.bfloat16:
+        compute_dtype = torch.float32
+
+    if M == 1 and K == 1:
+        scan_then_fan_col(inp, out, N, compute_dtype)
+    else:
+        scan_then_fan(inp, out, M, N, K, compute_dtype)
     return out
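Editor's note: as a reading aid for the cumsum rewrite above, the snippet below is a minimal host-side sketch of the scan-then-fan scheme that scan_part_sum_kernel and add_base_sum_kernel implement: each block is scanned independently while its total is recorded, the block totals are scanned (recursively, as scan_then_fan_col does), and a final fan-out pass adds each block's base offset. The helper name scan_then_fan_reference and its block_size argument are illustrative only and are not part of this patch.

import math

import torch


def scan_then_fan_reference(x, block_size=1024):
    # Pass 1: independent inclusive scan per block, plus one total per block.
    n = x.numel()
    part_num = math.ceil(n / block_size)
    out = torch.empty_like(x)
    partial = torch.empty(part_num, dtype=x.dtype, device=x.device)
    for p in range(part_num):
        blk = x[p * block_size : min((p + 1) * block_size, n)]
        out[p * block_size : p * block_size + blk.numel()] = blk.cumsum(0)
        partial[p] = blk.sum()

    if part_num >= 2:
        # Pass 2: scan the block totals (recursively, mirroring scan_then_fan_col).
        partial = scan_then_fan_reference(partial, block_size)
        # Pass 3: fan out -- block p adds the scanned total of blocks 0..p-1.
        for p in range(1, part_num):
            start = p * block_size
            out[start : min(start + block_size, n)] += partial[p - 1]
    return out


x = torch.randn(10_000, dtype=torch.float64)
assert torch.allclose(scan_then_fan_reference(x, block_size=256), x.cumsum(0))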
diff --git a/src/flag_gems/ops/nonzero.py b/src/flag_gems/ops/nonzero.py
new file mode 100644
index 00000000..9a3cbc71
--- /dev/null
+++ b/src/flag_gems/ops/nonzero.py
@@ -0,0 +1,79 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from ..utils import libentry
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_SIZE": k}, num_warps=w, num_stages=4)
+        for w in [4, 8, 16, 32]
+        for k in [256, 512, 1024, 2048, 4096, 8192]
+    ],
+    key=[
+        "n_elements",
+    ],
+)
+@triton.jit
+def nonzero_kernel(
+    inp,
+    prefix_sum,
+    out,
+    n_elements,
+    shape,
+    ndim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offset < n_elements
+
+    inp_vals = tl.load(inp + offset, mask=mask)
+    out_offset = tl.load(prefix_sum + offset, mask=mask) - 1
+
+    nonzero_mask = mask and inp_vals == True  # noqa
+
+    idx_flat = offset
+    for dim in range(ndim - 1, -1, -1):
+        dim_size = tl.load(shape + dim)
+        remainder = idx_flat % dim_size
+        idx_flat //= dim_size
+        tl.store(out + out_offset * ndim + dim, remainder, mask=nonzero_mask)
+
+
+def nonzero(inp, *, as_tuple=False):
+    logging.debug("GEMS NONZERO")
+
+    inp_ndim = inp.ndim
+
+    inp = inp.contiguous()
+    n_elements = inp.numel()
+    inp_view = inp.view(n_elements)
+
+    shape = torch.tensor(inp.shape, dtype=torch.int32, device=inp.device)
+
+    inp_bool = inp_view
+    if inp_view.dtype != torch.bool:
+        inp_bool = inp_view != 0
+
+    prefix_sum = inp_bool.cumsum(axis=0)
+
+    num_nonzeros = n_elements
+    out = torch.empty(num_nonzeros, inp_ndim, dtype=torch.int64, device=inp.device)
+
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+    with torch.cuda.device(inp.device):
+        nonzero_kernel[grid](inp_bool, prefix_sum, out, n_elements, shape, inp_ndim)
+
+    num_nonzeros = prefix_sum[n_elements - 1].item()
+    out = out[0:num_nonzeros]
+
+    if as_tuple:
+        return torch.unbind(out, dim=0)
+    else:
+        return out
diff --git a/tests/accuracy_utils.py b/tests/accuracy_utils.py
index 9b33b9d7..0627ffa4 100644
--- a/tests/accuracy_utils.py
+++ b/tests/accuracy_utils.py
@@ -8,6 +8,9 @@
 
 
 RESOLUTION = {
+    torch.bool: 0,
+    torch.int16: 0,
+    torch.int32: 0,
     torch.float16: 1e-3,
     torch.float32: 1.3e-6,
     torch.bfloat16: 0.016,
@@ -17,6 +20,8 @@
 DISTRIBUTION_SHAPES = [(20, 320, 15)]
 REDUCTION_SHAPES = [(4096, 256 * i) for i in range(1, 10, 2)]
 MNK_SHAPES = [15, 160, 1024]
+REDUCTION_MNK_SHAPES = [(15, 160, 1024), (16, 1025, 255)]
+ONE_DIM_SHAPES = [(256 * i + 7,) for i in range(1, 10, 2)]
 
 DIM_POINTWISE_SHAPES = [
     (1024, 1024, 1),
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index 6c9c7ee7..4e628bf6 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -7,6 +7,9 @@
     DIM_LIST,
     DIMS_LIST,
     FLOAT_DTYPES,
+    INT_DTYPES,
+    ONE_DIM_SHAPES,
+    REDUCTION_MNK_SHAPES,
     REDUCTION_SHAPES,
     gems_assert_close,
     gems_assert_equal,
@@ -236,11 +239,20 @@ def test_accuracy_cross_entropy_loss_probabilities(
     gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=shape[dim])
 
 
-@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
-@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+@pytest.mark.parametrize(
+    "shape", REDUCTION_SHAPES + ONE_DIM_SHAPES + REDUCTION_MNK_SHAPES
+)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
 def test_accuracy_cumsum(shape, dtype):
-    dim = 1
-    inp = torch.randn(shape, dtype=dtype, device="cuda")
+    if shape in REDUCTION_MNK_SHAPES:
+        dim = 1
+    else:
+        dim = -1
+
+    if dtype in INT_DTYPES:
+        inp = torch.randint(-3, 3, shape, device="cuda").to(dtype)
+    else:
+        inp = torch.randn(shape, dtype=dtype, device="cuda")
     ref_inp = to_reference(inp, True)
 
     ref_out = torch.cumsum(ref_inp, dim=dim)
@@ -250,6 +262,26 @@ def test_accuracy_cumsum(shape, dtype):
     gems_assert_close(res_out, ref_out, dtype, reduce_dim=shape[dim])
 
 
+@pytest.mark.parametrize(
+    "shape", REDUCTION_SHAPES + ONE_DIM_SHAPES + REDUCTION_MNK_SHAPES
+)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES + [torch.bool])
+def test_accuracy_nonzero(shape, dtype):
+    if dtype == torch.bool:
+        inp = torch.randint(0, 2, shape, dtype=torch.int, device="cuda").to(dtype)
+    elif dtype in INT_DTYPES:
+        inp = torch.randint(-3, 3, shape, device="cuda").to(dtype)
+    else:
+        inp = torch.randn(shape, dtype=dtype, device="cuda")
+    ref_inp = to_reference(inp, False)
+
+    ref_out = torch.nonzero(ref_inp)
+    with flag_gems.use_gems():
+        res_out = torch.nonzero(inp)
+
+    gems_assert_equal(res_out, ref_out)
+
+
 @pytest.mark.parametrize(
     "N, C, H, W, num_groups",
     [
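Editor's note: the new nonzero op added in this patch compacts indices with a prefix sum over the boolean mask (each hit's output row is prefix_sum - 1) and then decomposes the flat row-major index into per-dimension coordinates. The sketch below mirrors that logic in plain PyTorch on the CPU so it can be checked against torch.nonzero; the helper name nonzero_reference is illustrative and not part of the patch.

import torch


def nonzero_reference(inp):
    # Flatten and build the boolean mask, as flag_gems.nonzero does on device.
    flat = inp.reshape(-1)
    mask = flat != 0 if flat.dtype != torch.bool else flat

    # prefix_sum[i] - 1 is the output row of element i whenever mask[i] is True.
    prefix_sum = mask.to(torch.int64).cumsum(0)
    num_nonzeros = int(prefix_sum[-1]) if flat.numel() > 0 else 0

    out = torch.empty(num_nonzeros, inp.ndim, dtype=torch.int64)
    flat_idx = torch.arange(flat.numel())[mask]
    rows = prefix_sum[mask] - 1

    # Decompose the flat (row-major) index into coordinates, last dimension first.
    for dim in range(inp.ndim - 1, -1, -1):
        out[rows, dim] = flat_idx % inp.shape[dim]
        flat_idx = flat_idx // inp.shape[dim]
    return out


x = torch.tensor([[0, 3, 0], [5, 0, 7]])
assert torch.equal(nonzero_reference(x), torch.nonzero(x))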