Skip to content

Commit

Permalink
test slice_scatter fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
tongxin committed Nov 4, 2024
1 parent 196b7a9 commit c37efae
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 7 deletions.

This file was deleted.

2 changes: 2 additions & 0 deletions tests/accuracy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def SkipTorchVersion(skip_pattern):
+ TRANSPOSED_SHAPE_STRIDES_3D
)

IRREGULAR_SHAPE_STRIDES = [((10, 10, 10, 10, 10), (1, 10000, 23, 399, 1024))]

UPSAMPLE_SHAPES = [
(32, 16, 128, 128),
(15, 37, 256, 256),
Expand Down
56 changes: 52 additions & 4 deletions tests/test_reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
CONTIGUOUS_SHAPE_STRIDES_2D,
FLOAT_DTYPES,
INT_DTYPES,
IRREGULAR_SHAPE_STRIDES,
REDUCTION_SHAPES,
REDUCTION_SMALL_SHAPES,
SHAPE_STRIDES,
Expand Down Expand Up @@ -36,14 +37,24 @@
if QUICK_MODE
else list(zip([1, 0.1, 0], REDUCTION_SHAPES))
)
DIM_SHAPE_STRIDE = (
DIM_SHAPE_STRIDES = (
[(1, *CONTIGUOUS_SHAPE_STRIDES_2D[1])]
if QUICK_MODE
else list(
(random.randint(0, len(shape) - 1), shape, stride)
for shape, stride in SHAPE_STRIDES
)
)
REGULAR_DIM_SHAPE_STRIDES = (
[(1, *CONTIGUOUS_SHAPE_STRIDES_2D[1])]
if QUICK_MODE
else list(
(random.randint(0, len(shape) - 1), shape, stride)
for shape, stride in CONTIGUOUS_SHAPE_STRIDES_2D
)
)
IRREGULAR_DIM_SHAPE_STRIDES = [(3, *IRREGULAR_SHAPE_STRIDES)]

THRESHOLD_SHAPE = (
[(0.3, REDUCTION_SHAPES[0])]
if QUICK_MODE
Expand Down Expand Up @@ -474,13 +485,12 @@ def test_accuracy_select_scatter(shape, dim, dtype):


@pytest.mark.slice_scatter
@pytest.mark.parametrize(("dim", "shape", "stride"), DIM_SHAPE_STRIDE)
@pytest.mark.parametrize(("dim", "shape", "stride"), DIM_SHAPE_STRIDES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("start", [16, 64])
@pytest.mark.parametrize("end", [1024, 256])
@pytest.mark.parametrize("step", [1, 2])
def test_accuracy_slice_scatter(shape, stride, dim, dtype, start, end, step):
# inp = torch.randn(shape, dtype=dtype, device="cuda")
def test_accuracy_slice_scatter_v2(shape, stride, dim, dtype, start, end, step):
inp = torch.empty_strided(shape, stride, dtype=dtype, device="cuda")
inp.copy_(1)

Expand Down Expand Up @@ -512,6 +522,44 @@ def test_accuracy_slice_scatter(shape, stride, dim, dtype, start, end, step):
gems_assert_equal(res_out, ref_out)


@pytest.mark.slice_scatter
@pytest.mark.parametrize(("dim", "shape", "stride"), REGULAR_DIM_SHAPE_STRIDES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("start", [16, 64])
@pytest.mark.parametrize("end", [1024, 256])
@pytest.mark.parametrize("step", [1, 2])
def test_accuracy_slice_scatter_fallback(shape, stride, dim, dtype, start, end, step):
inp = torch.empty_strided(shape, stride, dtype=dtype, device="cuda")
inp.copy_(1)

valid_shape = list(inp.shape)
size = valid_shape[dim]

start = start % size
end = end % (size + 1)

if end < start:
end, start = start, end
elif end == start:
end = size

valid_shape[dim] = (end - start + step - 1) // step

src = torch.rand(valid_shape, dtype=dtype, device="cuda")

ref_inp = to_reference(inp)
ref_src = to_reference(src)
ref_out = torch.slice_scatter(
ref_inp, dim=dim, src=ref_src, start=start, end=end, step=step
)

res_out = flag_gems.ops.slice_scatter(
inp, dim=dim, src=src, start=start, end=end, step=step
)

gems_assert_equal(res_out, ref_out)


# TODO: failed at (200, 40999, 3)
@pytest.mark.index_select
@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
Expand Down

0 comments on commit c37efae

Please sign in to comment.