Skip to content

Commit

Permalink
Merge branch 'master' into dev_xcoresigma_jiangbin_conv
Browse files Browse the repository at this point in the history
  • Loading branch information
FatJhon authored Nov 4, 2024
2 parents ce12a38 + 7ecdd12 commit 4a0527d
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 4 deletions.
29 changes: 29 additions & 0 deletions benchmark/test_special_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,32 @@ def upsample_nearest2d_input_fn(shape, dtype, device):
dtypes=FLOAT_DTYPES,
)
bench.run()


def test_perf_repeat_interleave_self_tensor():
    """Benchmark torch.repeat_interleave with a per-element repeats tensor."""

    def make_args(dtype, batch, size):
        # 2-D input of shape (batch, size) on the GPU.
        input_tensor = torch.randn([batch, size], dtype=dtype, device="cuda")
        # One non-negative repeat count per row, drawn from [0, 0x2F).
        repeat_counts = torch.randint(
            low=0,
            high=0x2F,
            size=[
                batch,
            ],
            device="cuda",
        )
        # Positional args for torch.repeat_interleave(input, repeats, dim).
        return (input_tensor, repeat_counts, 0)

    bench = Benchmark(
        op_name="repeat_interleave_self_tensor",
        torch_op=torch.repeat_interleave,
        arg_func=make_args,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    bench.run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def enable(lib=aten_lib):
lib.impl("conv2d", conv2d, "CUDA")
lib.impl("conv1d", conv1d, "CUDA")
lib.impl("_conv_depthwise2d", _conv_depthwise2d, "CUDA")
lib.impl("repeat_interleave.self_Tensor", repeat_interleave_self_tensor, "CUDA")


class use_gems:
Expand Down
7 changes: 6 additions & 1 deletion src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@
from .reciprocal import reciprocal
from .relu import relu
from .repeat import repeat
from .repeat_interleave import repeat_interleave_self_int, repeat_interleave_tensor
from .repeat_interleave import (
repeat_interleave_self_int,
repeat_interleave_self_tensor,
repeat_interleave_tensor,
)
from .resolve_conj import resolve_conj
from .resolve_neg import resolve_neg
from .rms_norm import rms_norm
Expand Down Expand Up @@ -249,4 +253,5 @@
"conv2d",
"conv1d",
"_conv_depthwise2d",
"repeat_interleave_self_tensor",
]
4 changes: 1 addition & 3 deletions src/flag_gems/ops/index_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,14 @@ def index_select_kernel(
rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
rows_mask = rows_offsets < M
cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)
cols_mask = cols_offsets < N

block_mask = rows_mask and cols_mask
out_mask = rows_mask and (cols_offsets < index_len)

indices = tl.load(index + cols_offsets, mask=(cols_offsets < index_len), other=0)
inp_off = rows_offsets * N + indices[None, :]
out_off = rows_offsets * index_len + cols_offsets[None, :]

selected = tl.load(inp + inp_off, mask=block_mask, other=0.0)
selected = tl.load(inp + inp_off, mask=rows_mask, other=0.0)
tl.store(out + out_off, selected, mask=out_mask)


Expand Down
39 changes: 39 additions & 0 deletions src/flag_gems/ops/repeat_interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,42 @@ def repeat_interleave_tensor(repeats, *, output_size=None):
num_warps=1,
)
return out


def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
    """Repeat each slice of `inp` along `dim` by the matching count in `repeats`.

    Mirrors ``torch.repeat_interleave(self, repeats, dim)`` for a tensor
    `repeats` argument.

    Args:
        inp: input tensor.
        repeats: 0-dim or 1-dim integer tensor of repeat counts. A 0-dim or
            single-element tensor broadcasts one count to every slice;
            otherwise its length must equal ``inp.size(dim)``.
        dim: dimension to repeat along. ``None`` flattens `inp` first and
            uses dim 0 (torch semantics).
        output_size: accepted for API compatibility; only forwarded on the
            scalar-repeats path, unused on the tensor path.

    Returns:
        Tensor with ``inp.size(dim)`` replaced by ``repeats.sum()``.

    Raises:
        IndexError: if `dim` is outside ``[-inp.ndim, inp.ndim - 1]``.
        RuntimeError: if `repeats` has more than one dimension, or its length
            does not match ``inp.size(dim)``.
    """
    logging.debug("GEMS REPEAT_INTERLEAVE_SELF_TENSOR")

    if dim is None:
        # torch semantics: no dim means operate on the flattened input.
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1):
        # A scalar count broadcasts to every slice; reuse the int fast path.
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    elif repeats.ndim > 1:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    inp_shape = list(inp.shape)
    if dim < 0:
        dim = dim + len(inp_shape)

    if repeats.size(0) != inp_shape[dim]:
        # Implicit string concatenation instead of a backslash continuation
        # inside the literal, which spliced source-layout whitespace into the
        # runtime error message.
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got "
            "repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, inp_shape[dim]
            )
        )

    # Expand counts into gather indices (e.g. [2, 1] -> [0, 0, 1]) and gather.
    indices = repeat_interleave_tensor(repeats)
    res = torch.index_select(inp, dim, indices)

    return res
16 changes: 16 additions & 0 deletions tests/test_special_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,3 +796,19 @@ def test_accuracy_repeat_interleave_tensor(shape, dtype):
with flag_gems.use_gems():
res_out = torch.repeat_interleave(repeats)
gems_assert_equal(res_out, ref_out)


@pytest.mark.repeat_interleave
@pytest.mark.parametrize("shape", REPEAT_INTERLEAVE_SHAPES)
@pytest.mark.parametrize("dim", [-1, 0, 1])
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_repeat_interleave_self_tensor(shape, dim, dtype):
    """Compare the gems repeat_interleave (tensor repeats) against torch."""
    input_tensor = torch.randn(shape, dtype=dtype, device="cuda")
    # One repeat count per slice along `dim`, in [0, 30).
    repeat_counts = torch.randint(0, 30, (shape[dim],), device="cuda")

    reference_input = to_reference(input_tensor)
    reference_repeats = to_reference(repeat_counts)
    ref_out = torch.repeat_interleave(reference_input, reference_repeats, dim)

    with flag_gems.use_gems():
        res_out = torch.repeat_interleave(input_tensor, repeat_counts, dim)

    gems_assert_equal(res_out, ref_out)

0 comments on commit 4a0527d

Please sign in to comment.