Skip to content

Commit

Permalink
fix slice_scatter error on 1d inputs.
Browse files: browse the repository at this point in the history
  • Loading branch information
tongxin committed Nov 2, 2024
1 parent ca60fa7 commit 196b7a9
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 30 deletions.
15 changes: 0 additions & 15 deletions .vscode/launch.json

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[INFO] {"op_name": "slice_scatter", "dtype": "torch.float16", "mode": "cuda", "level": "core", "result": [{"legacy_shape": null, "shape_detail": [[64, 64], [64, 32], 1, 0, 64, 2], "latency_base": 0.014751999638974667, "latency": 0.009600000455975533, "speedup": 1.5366665560721162, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": null, "shape_detail": [[256, 256], [256, 128], 1, 0, 256, 2], "latency_base": 0.012671999633312225, "latency": 0.01152000017464161, "speedup": 1.0999999514936167, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": 1024, "shape_detail": [[1024, 1024], [504, 1024], 0, 16, 1024, 2], "latency_base": 0.018271999433636665, "latency": 0.012384000234305859, "speedup": 1.4754521227333324, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": null, "shape_detail": [[4096, 4096], [4096, 504], 1, 16, 1024, 2], "latency_base": 0.07599999755620956, "latency": 0.23561599850654602, "speedup": 0.3225587313167874, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": 65536, "shape_detail": [[1024, 65536], [504, 65536], 0, 16, 1024, 2], "latency_base": 0.35094401240348816, "latency": 0.2149440050125122, "speedup": 1.632722961419925, "accuracy": null, "tflops": null, "utilization": null}]}
[INFO] {"op_name": "slice_scatter", "dtype": "torch.float32", "mode": "cuda", "level": "core", "result": [{"legacy_shape": null, "shape_detail": [[64, 64], [64, 32], 1, 0, 64, 2], "latency_base": 0.010912000201642513, "latency": 0.008511999621987343, "speedup": 1.2819549678380777, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": null, "shape_detail": [[256, 256], [256, 128], 1, 0, 256, 2], "latency_base": 0.01190400030463934, "latency": 0.009568000212311745, "speedup": 1.2441471614226887, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": 1024, "shape_detail": [[1024, 1024], [504, 1024], 0, 16, 1024, 2], "latency_base": 0.02191999927163124, "latency": 0.013311999849975109, "speedup": 1.6466345792268189, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": null, "shape_detail": [[4096, 4096], [504, 4096], 0, 16, 1024, 2], "latency_base": 0.12726399302482605, "latency": 0.11184000223875046, "speedup": 1.1379112167142953, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": 65536, "shape_detail": [[1024, 65536], [504, 65536], 0, 16, 1024, 2], "latency_base": 0.6122879981994629, "latency": 0.4086720049381256, "speedup": 1.498238173403058, "accuracy": null, "tflops": null, "utilization": null}]}
[INFO] {"op_name": "slice_scatter", "dtype": "torch.bfloat16", "mode": "cuda", "level": "core", "result": [{"legacy_shape": null, "shape_detail": [[64, 64], [64, 32], 1, 0, 64, 2], "latency_base": 0.012128000147640705, "latency": 0.008832000195980072, "speedup": 1.3731883920429286, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": null, "shape_detail": [[256, 256], [128, 256], 0, 0, 256, 2], "latency_base": 0.012768000364303589, "latency": 0.008415999822318554, "speedup": 1.5171103414764673, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": 1024, "shape_detail": [[1024, 1024], [504, 1024], 0, 16, 1024, 2], "latency_base": 0.018271999433636665, "latency": 0.01158399973064661, "speedup": 1.5773480540832796, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": null, "shape_detail": [[4096, 4096], [504, 4096], 0, 16, 1024, 2], "latency_base": 0.07478400319814682, "latency": 0.0629120022058487, "speedup": 1.1887080457788137, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": 65536, "shape_detail": [[1024, 65536], [504, 65536], 0, 16, 1024, 2], "latency_base": 0.35094401240348816, "latency": 0.21622399985790253, "speedup": 1.6230576283581866, "accuracy": null, "tflops": null, "utilization": null}]}
7 changes: 7 additions & 0 deletions src/flag_gems/ops/slice_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def simplify(x, retain_dim, ordered_dims=None):
# and/or one inner dim.
ordered_dims = ordered_dims or sorted(range(x.ndim), key=lambda i: x.stride(i))
assert x.ndim == len(ordered_dims)
if len(ordered_dims) == 1:
return x, ordered_dims

size_list = [x.size(dim) for dim in ordered_dims]
stride_list = [x.stride(dim) for dim in ordered_dims]

Expand Down Expand Up @@ -312,6 +315,7 @@ def scatter_3d_mid_kernel(
@libentry()
@triton.autotune(
configs=[
triton.Config(kwargs={"R": 1, "C": 512}, num_warps=4),
triton.Config(kwargs={"R": 32, "C": 32}, num_warps=4),
triton.Config(kwargs={"R": 64, "C": 64}, num_warps=4),
triton.Config(kwargs={"R": 4, "C": 512}, num_warps=4),
Expand Down Expand Up @@ -561,6 +565,9 @@ def slice_scatter_v2(inp, src, dim=0, start=None, end=None, step=1):

if new_out is not None and new_src is not None:
if dim == ordered_dims[0]:
if len(ordered_dims) == 1:
new_out = new_out.unsqueeze(0)
new_src = new_src.unsqueeze(0)
# slices on inner dim
scatter_2d_inner(inp, new_src, new_out, start, end, step)
elif new_src.stride(-1) == new_out.stride(-1) == 1 and new_src.size(-1) >= 128:
Expand Down
35 changes: 34 additions & 1 deletion tests/accuracy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,40 @@ def SkipTorchVersion(skip_pattern):
[(16, 256), (16, 256)],
[(20, 320, 15), (20, 320, 15), (20, 320, 15)],
]

# Test fixtures: every entry is a (shape, stride) pair, grouped below by rank
# and memory layout. SHAPE_STRIDES at the bottom collects all of them.

# 1-D tensors with unit stride.
CONTIGUOUS_SHAPE_STRIDES_1D = [
    ((numel,), (1,)) for numel in (1, 1024, 1000000)
]
# 1-D tensors with stride 2 (same sizes as above, non-unit step).
DILATED_SHAPE_STRIDES_1D = [
    ((numel,), (2,)) for numel in (1, 1024, 1000000)
]

# 2-D: row-major layouts, followed by the same buffers with both the shape
# and the stride tuples reversed.
CONTIGUOUS_SHAPE_STRIDES_2D = [((1, 1024), (1024, 1)), ((10000, 128), (128, 1))]
TRANSPOSED_SHAPE_STRIDES_2D = [((1024, 1), (1, 1024)), ((128, 10000), (1, 128))]

# 3-D: row-major layouts, followed by permuted views of them (first entry
# swaps the two leading dims, second entry reverses all three).
CONTIGUOUS_SHAPE_STRIDES_3D = [
    ((20, 320, 15), (4800, 15, 1)),
    ((200, 40999, 3), (122997, 3, 1)),
]
TRANSPOSED_SHAPE_STRIDES_3D = [
    ((320, 20, 15), (15, 4800, 1)),
    ((3, 40999, 200), (1, 3, 122997)),
]

# Flat list of every pair above, ordered 1-D -> 2-D -> 3-D.
SHAPE_STRIDES = [
    *CONTIGUOUS_SHAPE_STRIDES_1D,
    *DILATED_SHAPE_STRIDES_1D,
    *CONTIGUOUS_SHAPE_STRIDES_2D,
    *TRANSPOSED_SHAPE_STRIDES_2D,
    *CONTIGUOUS_SHAPE_STRIDES_3D,
    *TRANSPOSED_SHAPE_STRIDES_3D,
]

UPSAMPLE_SHAPES = [
(32, 16, 128, 128),
Expand Down
39 changes: 25 additions & 14 deletions tests/test_reduction_ops.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import random

import pytest
import torch

import flag_gems

from .accuracy_utils import (
CONTIGUOUS_SHAPE_STRIDES_2D,
FLOAT_DTYPES,
INT_DTYPES,
REDUCTION_SHAPES,
REDUCTION_SMALL_SHAPES,
SHAPE_STRIDES,
gems_assert_close,
gems_assert_equal,
to_reference,
Expand All @@ -32,10 +36,13 @@
if QUICK_MODE
else list(zip([1, 0.1, 0], REDUCTION_SHAPES))
)
DIM_SHAPE = (
[(1, REDUCTION_SMALL_SHAPES[0])]
DIM_SHAPE_STRIDE = (
[(1, *CONTIGUOUS_SHAPE_STRIDES_2D[1])]
if QUICK_MODE
else list(zip([0, 1, 1], REDUCTION_SMALL_SHAPES))
else list(
(random.randint(0, len(shape) - 1), shape, stride)
for shape, stride in SHAPE_STRIDES
)
)
THRESHOLD_SHAPE = (
[(0.3, REDUCTION_SHAPES[0])]
Expand Down Expand Up @@ -467,26 +474,30 @@ def test_accuracy_select_scatter(shape, dim, dtype):


@pytest.mark.slice_scatter
@pytest.mark.parametrize(("dim", "shape"), DIM_SHAPE)
@pytest.mark.parametrize(("dim", "shape", "stride"), DIM_SHAPE_STRIDE)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("start", [16, 64])
@pytest.mark.parametrize("end", [1024, 256])
@pytest.mark.parametrize("step", [1, 2])
def test_accuracy_slice_scatter(shape, dim, dtype, start, end, step):
inp = torch.randn(shape, dtype=dtype, device="cuda")
def test_accuracy_slice_scatter(shape, stride, dim, dtype, start, end, step):
# inp = torch.randn(shape, dtype=dtype, device="cuda")
inp = torch.empty_strided(shape, stride, dtype=dtype, device="cuda")
inp.copy_(1)

range = end - start
valid_shape = list(inp.shape)
size = valid_shape[dim]

start = start % size
end = end % (size + 1)

if end < start:
range = 0
elif (end - start) > valid_shape[dim]:
range = valid_shape[dim]
start = 0
end = valid_shape[dim]
end, start = start, end
elif end == start:
end = size

valid_shape[dim] = (range + (step - 1)) // step
valid_shape[dim] = (end - start + step - 1) // step

src = torch.randn(valid_shape, dtype=dtype, device="cuda")
src = torch.rand(valid_shape, dtype=dtype, device="cuda")

ref_inp = to_reference(inp)
ref_src = to_reference(src)
Expand Down

0 comments on commit 196b7a9

Please sign in to comment.