diff --git a/benchmark/test_tensor_constructor_perf.py b/benchmark/test_tensor_constructor_perf.py
index 57162005..b8dbeeba 100644
--- a/benchmark/test_tensor_constructor_perf.py
+++ b/benchmark/test_tensor_constructor_perf.py
@@ -159,3 +159,19 @@ def full_kwargs(dtype, batch, size):
         kwargs_func=full_kwargs,
     )
     bench.run()
+
+
+def test_perf_randperm():
+    def randperm_kwargs(dtype, batch, size):
+        return {"n": size, "dtype": dtype, "device": "cuda"}
+
+    bench = Benchmark(
+        op_name="randperm",
+        torch_op=torch.randperm,
+        arg_func=None,
+        dtypes=[torch.int32, torch.int64],
+        batch=POINTWISE_BATCH,
+        sizes=SIZES,
+        kwargs_func=randperm_kwargs,
+    )
+    bench.run()
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index 0babf7e4..0ce1afe4 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -131,6 +131,7 @@ def enable(lib=aten_lib):
     lib.impl("index_select", index_select, "CUDA")
     lib.impl("masked_fill", masked_fill, "CUDA")
     lib.impl("_unique2", _unique2, "CUDA")
+    lib.impl("randperm", randperm, "CUDA")


 class use_gems:
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 34a82f24..41bec057 100644
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -67,6 +67,7 @@
 from .rand_like import rand_like
 from .randn import randn
 from .randn_like import randn_like
+from .randperm import randperm
 from .reciprocal import reciprocal
 from .relu import relu
 from .resolve_conj import resolve_conj
@@ -157,6 +158,7 @@
     "minimum",
     "rand",
     "randn",
+    "randperm",
     "rand_like",
     "randn_like",
     "resolve_neg",
diff --git a/src/flag_gems/ops/randperm.py b/src/flag_gems/ops/randperm.py
new file mode 100644
index 00000000..f0ccf12d
--- /dev/null
+++ b/src/flag_gems/ops/randperm.py
@@ -0,0 +1,362 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.utils.random_utils import philox_cuda_seed_offset
+
+from ..utils import libentry
+from .topk import argsort
+
+_MIN_INT32_VAL: tl.constexpr = torch.iinfo(torch.int32).min
+_MAX_INT32_VAL: tl.constexpr = torch.iinfo(torch.int32).max
+_MIN_INT64_VAL: tl.constexpr = torch.iinfo(torch.int64).min
+_MAX_INT64_VAL: tl.constexpr = torch.iinfo(torch.int64).max
+_MAX_UINT32_VAL: tl.constexpr = (1 << 32) - 1
+_MIN_UINT32_VAL: tl.constexpr = 0
+
+
+@triton.jit
+def _get_iinfo_val(
+    dtype,
+    return_max,
+):
+    # Return the maximum (or minimum) representable value of an integer
+    # dtype; used to pad out-of-range lanes so they sort to the end.
+    if dtype is tl.int64:
+        if return_max:
+            return _MAX_INT64_VAL
+        else:
+            return _MIN_INT64_VAL
+    elif dtype is tl.int32:
+        if return_max:
+            return _MAX_INT32_VAL
+        else:
+            return _MIN_INT32_VAL
+    elif dtype is tl.uint32:
+        if return_max:
+            return _MAX_UINT32_VAL
+        else:
+            return _MIN_UINT32_VAL
+    else:
+        raise ValueError("Unknown dtype")
+
+
+@libentry()
+@triton.jit
+def bitonic_sortbykey_kernel(
+    y_ptr,
+    index_ptr,
+    chunk_x,
+    chunk_index,
+    N: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    DESCENDING: tl.constexpr,
+):
+    # Each program bitonic-sorts one batch of N keys, padded to BLOCK_SIZE.
+    cur_batch = tl.program_id(0)
+    chunk_x += cur_batch * N
+    chunk_index += cur_batch * N
+    index_ptr += cur_batch * N
+    y_ptr += cur_batch * N
+
+    cols = tl.arange(0, BLOCK_SIZE)
+    mask = cols < N
+
+    mask_val = _get_iinfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
+
+    chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val)
+    chunk_index_val = tl.load(chunk_index + cols, mask=mask)
+
+    sorted_chunk_x, sorted_chunk_index = argsort(
+        chunk_x_val, chunk_index_val, 0, descending=DESCENDING
+    )
+    tl.store(y_ptr + cols, sorted_chunk_x, mask=mask)
+    tl.store(index_ptr + cols, sorted_chunk_index, mask=mask)
+
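+
+# Radix-sort path, used for large inputs (the next three kernels). Signed
+# keys are first mapped to an order-preserving unsigned encoding (a sign-bit
+# flip), then sorted digit by digit: digit_hist_kernel counts digit
+# occurrences per block, and radix_sortbykey_scatter_kernel places keys and
+# values at their global offsets, combining per-block partial prefix sums
+# via a decoupled look-back scheme.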
+@triton.jit
+def radix_type_convert(k):
+    # Flip the sign bit so signed keys compare correctly as unsigned digits.
+    if tl.constexpr(k.dtype == tl.int32):
+        ik = k.to(tl.int32, bitcast=True)
+        mask = (ik >> 31) & 0x1
+        o = tl.where(mask, ik & 0x7FFFFFFF, ik | 0x80000000)
+    elif tl.constexpr(k.dtype == tl.int64):
+        ik = k.to(tl.int64, bitcast=True)
+        mask = (ik >> 63) & 0x1
+        o = tl.where(mask, ik & 0x7FFFFFFFFFFFFFFF, ik | 0x8000000000000000)
+    else:
+        o = k
+    return o
+
+
+@libentry()
+@triton.jit
+def digit_hist_kernel(
+    digit_hist,
+    d_lookback,
+    key,
+    n_elements,
+    bits_per_pass,
+    bins,
+    passes,
+    bit_mask,
+    bins_segment,
+    BLOCK_SIZE: tl.constexpr,
+):
+    bin_segid = tl.program_id(0)
+    pid1 = tl.program_id(1)
+    grid1 = tl.num_programs(1)
+
+    key_offset = pid1 * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    key_mask = key_offset < n_elements
+    key_data = tl.load(key + key_offset, mask=key_mask)
+    ikey_data = radix_type_convert(key_data)
+    bit_offset = 0
+    for p in range(passes):
+        key_digit = (ikey_data >> bit_offset) & bit_mask
+        blk_bin_start = bin_segid * bins_segment
+        for s in range(bins_segment):
+            bin_id = s + blk_bin_start
+            digit_mask = tl.where((key_digit == bin_id) & key_mask, 1, 0)
+            digit_sum = tl.sum(digit_mask)
+            # write counts at bin_id + 1 so the later cumsum yields an
+            # exclusive prefix sum
+            bin_offset = p * (bins + 1) * grid1 + (bin_id + 1) * grid1 + pid1
+            # store per-block partial counts and reduce them on the host,
+            # which is cheaper than contended global atomics
+            tl.store(digit_hist + bin_offset, digit_sum)
+            # zero-initialize the decoupled look-back buffer
+            tl.store(d_lookback + (p * bins + bin_id) * grid1 + pid1, 0)
+        tl.store(digit_hist + p * (bins + 1) * grid1 + pid1, 0, mask=bin_segid == 0)
+        bit_offset += bits_per_pass
+
+
+@libentry()
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=w) for w in [4, 8, 16]],
+    key=["n_elements"],
+)
+@triton.jit
+def radix_sortbykey_scatter_kernel(
+    key_out,
+    value_out,
+    key_in,
+    value_in,
+    digit_hist,
+    d_lookback,
+    n_elements,
+    bit_offset,
+    p,
+    bit_mask,
+    bins_segment,
+    bins: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # look-back flags: bit 30 marks a partial (block-local) count, bit 31 a
+    # global (inclusive-prefix) count; the low bits carry the value
+    LOOKBACK_PARTIAL_MASK = 1 << 30
+    LOOKBACK_GLOBAL_MASK = 1 << 31
+    LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK
+    LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK
+
+    pid1 = tl.program_id(1)
+    grid1 = tl.num_programs(1)
+
+    key_offset = pid1 * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    key_mask = key_offset < n_elements
+    value_data = tl.load(value_in + key_offset, mask=key_mask)
+    key_data = tl.load(key_in + key_offset, mask=key_mask)
+
+    ikey_data = radix_type_convert(key_data)
+    key_digit = (ikey_data >> bit_offset) & bit_mask
+
+    blk_bin_start = tl.program_id(0) * bins_segment
+    for s in range(bins_segment):
+        bin_id = s + blk_bin_start
+        key_digit_mask = (key_digit == bin_id) & key_mask
+        key_elem_mask = tl.where(key_digit_mask, 1, 0)
+        key_block_rank = tl.cumsum(key_elem_mask)
+        key_block_rank = tl.where(key_digit_mask, key_block_rank - 1, 0)
+        bin_of_bucket = tl.sum(key_elem_mask)
+        partial_counter = bin_of_bucket | LOOKBACK_PARTIAL_MASK
+        tl.store(
+            d_lookback + p * grid1 * bins + pid1 * bins + bin_id,
+            partial_counter,
+            cache_modifier=".wt",
+        )
+        bin_offset = p * (bins + 1) + bin_id
+        prefix_offsets = tl.load(digit_hist + bin_offset)
+        # decoupled look-back: accumulate predecessors' published counters
+        bk = pid1 - 1
+        inc_sum = bin_of_bucket
+        while bk >= 0:
+            rd_lbk_offset = p * grid1 * bins + bk * bins + bin_id
+            partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True)
+            # spin until the predecessor publishes its counter
+            while partial_prefix == 0:
+                partial_prefix = tl.load(d_lookback + rd_lbk_offset, volatile=True)
+            inc_sum += (partial_prefix & LOOKBACK_VALUE_MASK).to(tl.int32)
+            if partial_prefix & LOOKBACK_GLOBAL_MASK:
+                # global prefix found; end the look-back (Triton has no break)
+                bk = -1
+            else:
+                bk -= 1
+        global_counter = inc_sum | LOOKBACK_GLOBAL_MASK
+        tl.store(
+            d_lookback + p * grid1 * bins + pid1 * bins + bin_id,
+            global_counter,
+            cache_modifier=".wt",
+        )
+        global_offsets = (prefix_offsets + inc_sum - bin_of_bucket) + key_block_rank
+        tl.store(key_out + global_offsets, key_data, mask=key_digit_mask)
+        tl.store(value_out + global_offsets, value_data, mask=key_digit_mask)
+
+
+# For parallelism, randomly shuffle each whole block of values, rather than
+# only runs of equal keys as the PyTorch CUDA backend does.
+@libentry()
+@triton.jit(do_not_specialize=["philox_seed", "philox_offset"])
+def duplicate_keys_shuffle_kernel(
+    value_in, n_elements, philox_seed, philox_offset, BLOCK_SIZE: tl.constexpr
+):
+    pid0 = tl.program_id(0)
+    offset_range = tl.arange(0, BLOCK_SIZE)
+    value_offset = pid0 * BLOCK_SIZE + offset_range
+    value_mask = value_offset < n_elements
+    value_data = tl.load(value_in + value_offset, mask=value_mask)
+
+    philox_seed = philox_seed.to(tl.int64)
+    philox_offset = philox_offset.to(tl.int64)
+    # split the 64-bit philox offset into two 32-bit counter words and give
+    # every element its own counter lane
+    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
+    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)
+    c0 += value_offset
+    _O = c0 * 0  # zero-filled remaining counter lanes
+    r0, _, _, _ = tl.philox(philox_seed, c0, c1, _O, _O)
+
+    _block_size = BLOCK_SIZE
+    r1 = r0 % _block_size.to(tl.uint32)
+    mask_val = _get_iinfo_val(tl.uint32, True)
+    r1 = tl.where(value_mask, r1, mask_val)
+    _, sorted_chunk_index = argsort(r1, offset_range, 0, descending=False)
+    store_offset = pid0 * BLOCK_SIZE + sorted_chunk_index
+    tl.store(value_in + store_offset, value_data, mask=store_offset < n_elements)
+
+
+def sort_by_key(key, value):
+    n_elements = key.numel()
+    if n_elements > 2 * 1024:
+        # large inputs: multi-pass LSD radix sort by key
+        BLOCK_SIZE = 1024
+        bits_per_pass = 4
+        bits_per_segment = 3
+        passes = triton.cdiv(key.element_size() * 8, bits_per_pass)
+        bins = 2**bits_per_pass
+        bins_per_segment = 2**bits_per_segment
+        bit_mask = bins - 1
+
+        grid = (bins // bins_per_segment, triton.cdiv(n_elements, BLOCK_SIZE))
+        digit_hist = torch.empty(
+            (passes, bins + 1, grid[1]), dtype=torch.int32, device=key.device
+        )
+        d_lookback = torch.empty(
+            bins * grid[1] * passes, dtype=torch.int32, device=key.device
+        )
+
+        key_out_p = torch.empty_like(key)
+        key_out_q = torch.empty_like(key)
+        value_out_p = torch.empty_like(value)
+        value_out_q = torch.empty_like(value)
+
+        # step 1: per-block digit histograms for all passes in one sweep
+        with torch.cuda.device(key.device):
+            digit_hist_kernel[grid](
+                digit_hist,
+                d_lookback,
+                key,
+                n_elements,
+                bits_per_pass,
+                bins,
+                passes,
+                bit_mask,
+                bins_per_segment,
+                BLOCK_SIZE,
+            )
+        digit_hist = torch.sum(digit_hist, dim=2, keepdim=False)
+        # step 2: exclusive prefix sums over the bins
+        digit_hist = digit_hist.cumsum(dim=1)
+
+        bit_offset = 0
+        for p in range(passes):
+            k_in = (key if p == 0 else key_out_p) if p % 2 == 0 else key_out_q
+            v_in = (value if p == 0 else value_out_p) if p % 2 == 0 else value_out_q
+            k_out = key_out_q if p % 2 == 0 else key_out_p
+            v_out = value_out_q if p % 2 == 0 else value_out_p
+            # step 3: scatter keys/values by the current digit (ping-pong buffers)
+            with torch.cuda.device(key.device):
+                radix_sortbykey_scatter_kernel[grid](
+                    k_out,
+                    v_out,
+                    k_in,
+                    v_in,
+                    digit_hist,
+                    d_lookback,
+                    n_elements,
+                    bit_offset,
+                    p,
+                    bit_mask,
+                    bins_per_segment,
+                    bins,
+                    BLOCK_SIZE,
+                )
+            bit_offset += bits_per_pass
+
+        # final step: shuffle values within each block so that duplicate
+        # random keys do not bias the permutation
+        BLOCK_SIZE_SHUFFLE = 512
+        grid_shuffle = (triton.cdiv(n_elements, BLOCK_SIZE_SHUFFLE),)
+        philox_seed, philox_offset = philox_cuda_seed_offset(n_elements)
+        with torch.cuda.device(key.device):
+            duplicate_keys_shuffle_kernel[grid_shuffle](
+                v_out,
+                n_elements,
+                philox_seed,
+                philox_offset,
+                BLOCK_SIZE_SHUFFLE,
+                num_warps=4,
+            )
+        return v_out
+    else:
+        # small inputs: a single-block bitonic sort is enough
+        BLOCK_SIZE = triton.next_power_of_2(n_elements)
+        grid = (1,)
+        k_out = torch.empty_like(key)
+        v_out = torch.empty_like(value)
+        with torch.cuda.device(key.device):
+            bitonic_sortbykey_kernel[grid](
+                k_out, v_out, key, value, n_elements, BLOCK_SIZE, False
+            )
+        return v_out
+
+
+def randperm(
+    n,
+    *,
+    generator=None,
+    out=None,
+    dtype=torch.int64,
+    layout=torch.strided,
+    device=None,
+    requires_grad=False,
+    pin_memory=False,
+):
+    # generator, out, layout, requires_grad and pin_memory are accepted for
+    # signature compatibility with aten but are currently ignored
+    logging.debug("GEMS RANDPERM")
+
+    assert dtype in (torch.int16, torch.int32, torch.int64)
+
+    if device is None:
+        device = torch.device("cuda")
+    in_range = torch.arange(n, dtype=dtype, device=device)
+    i32max = torch.iinfo(torch.int32).max
+    i64max = torch.iinfo(torch.int64).max
+    # use 32-bit random keys while n fits; fall back to 64-bit keys otherwise
+    if n < i32max:
+        rand_key = torch.randint(
+            low=0, high=i32max, size=[n], dtype=torch.int32, device=device
+        )
+    else:
+        rand_key = torch.randint(
+            low=0, high=i64max, size=[n], dtype=torch.int64, device=device
+        )
+    perm_range = sort_by_key(rand_key, in_range)
+    return perm_range
diff --git a/src/flag_gems/ops/topk.py b/src/flag_gems/ops/topk.py
index 53d6350d..b1607b01 100644
--- a/src/flag_gems/ops/topk.py
+++ b/src/flag_gems/ops/topk.py
@@ -116,8 +116,12 @@ def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
     left = core.reshape(left, x.shape)
     right = core.reshape(right, x.shape)

-    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape)
-    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape)
+    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
+        ids.dtype
+    )
+    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
+        ids.dtype
+    )
     left_idx = core.reshape(left_idx, ids.shape)
     right_idx = core.reshape(right_idx, ids.shape)
diff --git a/tests/test_tensor_constructor_ops.py b/tests/test_tensor_constructor_ops.py
index 88829e71..92858e73 100644
--- a/tests/test_tensor_constructor_ops.py
+++ b/tests/test_tensor_constructor_ops.py
@@ -4,6 +4,7 @@ import flag_gems

 from .accuracy_utils import (
+    ALL_INT_DTYPES,
     DISTRIBUTION_SHAPES,
     FLOAT_DTYPES,
     POINTWISE_SHAPES,
@@ -102,3 +103,16 @@ def test_accuracy_full_like(shape, dtype):
     with flag_gems.use_gems():
         res_out = torch.full_like(x, 3.1415926)
     gems_assert_equal(res_out, torch.full_like(x, 3.1415926))
+
+
+@pytest.mark.parametrize("n", [123, 12345, 123456])
+@pytest.mark.parametrize("dtype", ALL_INT_DTYPES)
+def test_accuracy_randperm(n, dtype):
+    if dtype == torch.int16 and n > torch.iinfo(torch.int16).max:
+        pytest.skip("n is out of range for int16")
+    ref_out = torch.randperm(n, dtype=dtype, device="cuda")
+    with flag_gems.use_gems():
+        res_out = torch.randperm(n, dtype=dtype, device="cuda")
+    sorted_ref, _ = torch.sort(ref_out)
+    sorted_res, _ = torch.sort(res_out)
+    gems_assert_equal(sorted_ref, sorted_res)
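As a quick end-to-end smoke test of the patch (a sketch, not part of the diff; it assumes a CUDA build with flag_gems installed and uses only APIs visible above):

    import torch
    import flag_gems

    flag_gems.enable()  # routes aten randperm on CUDA to the Triton kernel
    n = 10_000
    perm = torch.randperm(n, device="cuda")
    # a valid permutation of [0, n) sorts back to arange(n)
    assert torch.equal(torch.sort(perm).values, torch.arange(n, device="cuda"))

Because the kernels only guarantee that the output is some uniformly shuffled permutation, the accuracy test above likewise compares sorted outputs rather than element-wise values.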