diff --git a/benchmark/test_special_perf.py b/benchmark/test_special_perf.py index 8c3c7310..572b03d8 100644 --- a/benchmark/test_special_perf.py +++ b/benchmark/test_special_perf.py @@ -244,3 +244,32 @@ def upsample_nearest2d_input_fn(shape, dtype, device): dtypes=FLOAT_DTYPES, ) bench.run() + + +def test_perf_repeat_interleave_self_tensor(): + def repeat_interleave_self_tensor_arg(dtype, batch, size): + inp = torch.randn([batch, size], dtype=dtype, device="cuda") + repeats = torch.randint( + low=0, + high=0x2F, + size=[ + batch, + ], + device="cuda", + ) + dim = 0 + return ( + inp, + repeats, + dim, + ) + + bench = Benchmark( + op_name="repeat_interleave_self_tensor", + torch_op=torch.repeat_interleave, + arg_func=repeat_interleave_self_tensor_arg, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 7f6901ce..7ffa629e 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -162,6 +162,7 @@ def enable(lib=aten_lib): lib.impl("conv2d", conv2d, "CUDA") lib.impl("conv1d", conv1d, "CUDA") lib.impl("_conv_depthwise2d", _conv_depthwise2d, "CUDA") + lib.impl("repeat_interleave.self_Tensor", repeat_interleave_self_tensor, "CUDA") class use_gems: diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 9d998f9f..84eef565 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -77,7 +77,11 @@ from .reciprocal import reciprocal from .relu import relu from .repeat import repeat -from .repeat_interleave import repeat_interleave_self_int, repeat_interleave_tensor +from .repeat_interleave import ( + repeat_interleave_self_int, + repeat_interleave_self_tensor, + repeat_interleave_tensor, +) from .resolve_conj import resolve_conj from .resolve_neg import resolve_neg from .rms_norm import rms_norm @@ -249,4 +253,5 @@ "conv2d", "conv1d", "_conv_depthwise2d", + "repeat_interleave_self_tensor", ] diff --git 
a/src/flag_gems/ops/index_select.py b/src/flag_gems/ops/index_select.py index b78a764c..e2b13265 100644 --- a/src/flag_gems/ops/index_select.py +++ b/src/flag_gems/ops/index_select.py @@ -29,16 +29,14 @@ def index_select_kernel( rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] rows_mask = rows_offsets < M cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N) - cols_mask = cols_offsets < N - block_mask = rows_mask and cols_mask out_mask = rows_mask and (cols_offsets < index_len) indices = tl.load(index + cols_offsets, mask=(cols_offsets < index_len), other=0) inp_off = rows_offsets * N + indices[None, :] out_off = rows_offsets * index_len + cols_offsets[None, :] - selected = tl.load(inp + inp_off, mask=block_mask, other=0.0) + selected = tl.load(inp + inp_off, mask=rows_mask, other=0.0) tl.store(out + out_off, selected, mask=out_mask) diff --git a/src/flag_gems/ops/repeat_interleave.py b/src/flag_gems/ops/repeat_interleave.py index 1c170c81..d18017c8 100644 --- a/src/flag_gems/ops/repeat_interleave.py +++ b/src/flag_gems/ops/repeat_interleave.py @@ -102,3 +102,42 @@ def repeat_interleave_tensor(repeats, *, output_size=None): num_warps=1, ) return out + + +def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None): + logging.debug("GEMS REPEAT_INTERLEAVE_SELF_TENSOR") + + if dim is None: + inp = inp.flatten() + dim = 0 + else: + if (dim < -inp.ndim) or (dim >= inp.ndim): + raise IndexError( + "Dimension out of range (expected to be in range of [{}, {}], but got {})".format( + -inp.ndim, inp.ndim - 1, dim + ) + ) + + if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1): + return repeat_interleave_self_int( + inp, repeats.item(), dim=dim, output_size=output_size + ) + elif repeats.ndim > 1: + raise RuntimeError("repeats must be 0-dim or 1-dim tensor") + + inp_shape = list(inp.shape) + if dim < 0: + dim = dim + len(inp_shape) + + if repeats.size(0) != inp_shape[dim]: + raise RuntimeError( + "repeats must have the same 
size as input along dim, but got " + "repeats.size(0) = {} and input.size({}) = {}".format( repeats.size(0), dim, inp_shape[dim] ) ) + + indices = repeat_interleave_tensor(repeats) + res = torch.index_select(inp, dim, indices) + + return res diff --git a/tests/test_special_ops.py b/tests/test_special_ops.py index 13468a4f..0e558d65 100644 --- a/tests/test_special_ops.py +++ b/tests/test_special_ops.py @@ -796,3 +796,19 @@ def test_accuracy_repeat_interleave_tensor(shape, dtype): with flag_gems.use_gems(): res_out = torch.repeat_interleave(repeats) gems_assert_equal(res_out, ref_out) + + +@pytest.mark.repeat_interleave +@pytest.mark.parametrize("shape", REPEAT_INTERLEAVE_SHAPES) +@pytest.mark.parametrize("dim", [-1, 0, 1]) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_repeat_interleave_self_tensor(shape, dim, dtype): + inp = torch.randn(shape, dtype=dtype, device="cuda") + repeats = torch.randint(0, 30, (shape[dim],), device="cuda") + ref_inp = to_reference(inp) + ref_repeats = to_reference(repeats) + + ref_out = torch.repeat_interleave(ref_inp, ref_repeats, dim) + with flag_gems.use_gems(): + res_out = torch.repeat_interleave(inp, repeats, dim) + gems_assert_equal(res_out, ref_out)