diff --git a/benchmark/result_test_reduction_perf--level_core--record_log-k_test_slice_scatter_perf.log b/benchmark/result_test_reduction_perf--level_core--record_log-k_test_slice_scatter_perf.log deleted file mode 100644 index 79604efa..00000000 --- a/benchmark/result_test_reduction_perf--level_core--record_log-k_test_slice_scatter_perf.log +++ /dev/null @@ -1,3 +0,0 @@ -[INFO] {"op_name": "slice_scatter", "dtype": "torch.float16", "mode": "cuda", "level": "core", "result": [{"legacy_shape": null, "shape_detail": [[64, 64], [64, 32], 1, 0, 64, 2], "latency_base": 0.014751999638974667, "latency": 0.009600000455975533, "speedup": 1.5366665560721162, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": null, "shape_detail": [[256, 256], [256, 128], 1, 0, 256, 2], "latency_base": 0.012671999633312225, "latency": 0.01152000017464161, "speedup": 1.0999999514936167, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": 1024, "shape_detail": [[1024, 1024], [504, 1024], 0, 16, 1024, 2], "latency_base": 0.018271999433636665, "latency": 0.012384000234305859, "speedup": 1.4754521227333324, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": null, "shape_detail": [[4096, 4096], [4096, 504], 1, 16, 1024, 2], "latency_base": 0.07599999755620956, "latency": 0.23561599850654602, "speedup": 0.3225587313167874, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": 65536, "shape_detail": [[1024, 65536], [504, 65536], 0, 16, 1024, 2], "latency_base": 0.35094401240348816, "latency": 0.2149440050125122, "speedup": 1.632722961419925, "accuracy": null, "tflops": null, "utilization": null}]} -[INFO] {"op_name": "slice_scatter", "dtype": "torch.float32", "mode": "cuda", "level": "core", "result": [{"legacy_shape": null, "shape_detail": [[64, 64], [64, 32], 1, 0, 64, 2], "latency_base": 0.010912000201642513, "latency": 0.008511999621987343, "speedup": 1.2819549678380777, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": null, "shape_detail": [[256, 256], [256, 128], 1, 0, 256, 2], "latency_base": 0.01190400030463934, "latency": 0.009568000212311745, "speedup": 1.2441471614226887, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": 1024, "shape_detail": [[1024, 1024], [504, 1024], 0, 16, 1024, 2], "latency_base": 0.02191999927163124, "latency": 0.013311999849975109, "speedup": 1.6466345792268189, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": null, "shape_detail": [[4096, 4096], [504, 4096], 0, 16, 1024, 2], "latency_base": 0.12726399302482605, "latency": 0.11184000223875046, "speedup": 1.1379112167142953, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": 65536, "shape_detail": [[1024, 65536], [504, 65536], 0, 16, 1024, 2], "latency_base": 0.6122879981994629, "latency": 0.4086720049381256, "speedup": 1.498238173403058, "accuracy": null, "tflops": null, "utilization": null}]} -[INFO] {"op_name": "slice_scatter", "dtype": "torch.bfloat16", "mode": "cuda", "level": "core", "result": [{"legacy_shape": null, "shape_detail": [[64, 64], [64, 32], 1, 0, 64, 2], "latency_base": 0.012128000147640705, "latency": 0.008832000195980072, "speedup": 1.3731883920429286, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": null, "shape_detail": [[256, 256], [128, 256], 0, 0, 256, 2], "latency_base": 0.012768000364303589, "latency": 0.008415999822318554, "speedup": 1.5171103414764673, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": 1024, "shape_detail": [[1024, 1024], [504, 1024], 0, 16, 1024, 2], "latency_base": 0.018271999433636665, "latency": 0.01158399973064661, "speedup": 1.5773480540832796, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": null, "shape_detail": [[4096, 4096], [504, 4096], 0, 16, 1024, 2], "latency_base": 0.07478400319814682, "latency": 0.0629120022058487, "speedup": 1.1887080457788137, "accuracy": null, "tflops": null, "utilization": null}, {"legacy_shape": 65536, "shape_detail": [[1024, 65536], [504, 65536], 0, 16, 1024, 2], "latency_base": 0.35094401240348816, "latency": 0.21622399985790253, "speedup": 1.6230576283581866, "accuracy": null, "tflops": null, "utilization": null}]} diff --git a/tests/accuracy_utils.py b/tests/accuracy_utils.py index 38aea66e..0be6544a 100644 --- a/tests/accuracy_utils.py +++ b/tests/accuracy_utils.py @@ -95,6 +95,8 @@ def SkipTorchVersion(skip_pattern): + TRANSPOSED_SHAPE_STRIDES_3D ) +IRREGULAR_SHAPE_STRIDES = [((10, 10, 10, 10, 10), (1, 10000, 23, 399, 1024))] + UPSAMPLE_SHAPES = [ (32, 16, 128, 128), (15, 37, 256, 256), diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index 068d4916..4197ccc9 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -9,6 +9,7 @@ CONTIGUOUS_SHAPE_STRIDES_2D, FLOAT_DTYPES, INT_DTYPES, + IRREGULAR_SHAPE_STRIDES, REDUCTION_SHAPES, REDUCTION_SMALL_SHAPES, SHAPE_STRIDES, @@ -36,7 +37,7 @@ if QUICK_MODE else list(zip([1, 0.1, 0], REDUCTION_SHAPES)) ) -DIM_SHAPE_STRIDE = ( +DIM_SHAPE_STRIDES = ( [(1, *CONTIGUOUS_SHAPE_STRIDES_2D[1])] if QUICK_MODE else list( @@ -44,6 +45,16 @@ for shape, stride in SHAPE_STRIDES ) ) +REGULAR_DIM_SHAPE_STRIDES = ( + [(1, *CONTIGUOUS_SHAPE_STRIDES_2D[1])] + if QUICK_MODE + else list( + (random.randint(0, len(shape) - 1), shape, stride) + for shape, stride in CONTIGUOUS_SHAPE_STRIDES_2D + ) +) +IRREGULAR_DIM_SHAPE_STRIDES = [(3, *IRREGULAR_SHAPE_STRIDES)] + THRESHOLD_SHAPE = ( [(0.3, REDUCTION_SHAPES[0])] if QUICK_MODE @@ -474,13 +485,12 @@ def test_accuracy_select_scatter(shape, dim, dtype): @pytest.mark.slice_scatter -@pytest.mark.parametrize(("dim", "shape", "stride"), DIM_SHAPE_STRIDE) +@pytest.mark.parametrize(("dim", "shape", "stride"), DIM_SHAPE_STRIDES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) @pytest.mark.parametrize("start", [16, 64]) @pytest.mark.parametrize("end", [1024, 256]) @pytest.mark.parametrize("step", [1, 2]) -def test_accuracy_slice_scatter(shape, stride, dim, dtype, start, end, step): - # inp = torch.randn(shape, dtype=dtype, device="cuda") +def test_accuracy_slice_scatter_v2(shape, stride, dim, dtype, start, end, step): inp = torch.empty_strided(shape, stride, dtype=dtype, device="cuda") inp.copy_(1) @@ -512,6 +522,44 @@ def test_accuracy_slice_scatter(shape, stride, dim, dtype, start, end, step): gems_assert_equal(res_out, ref_out) +@pytest.mark.slice_scatter +@pytest.mark.parametrize(("dim", "shape", "stride"), REGULAR_DIM_SHAPE_STRIDES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +@pytest.mark.parametrize("start", [16, 64]) +@pytest.mark.parametrize("end", [1024, 256]) +@pytest.mark.parametrize("step", [1, 2]) +def test_accuracy_slice_scatter_fallback(shape, stride, dim, dtype, start, end, step): + inp = torch.empty_strided(shape, stride, dtype=dtype, device="cuda") + inp.copy_(1) + + valid_shape = list(inp.shape) + size = valid_shape[dim] + + start = start % size + end = end % (size + 1) + + if end < start: + end, start = start, end + elif end == start: + end = size + + valid_shape[dim] = (end - start + step - 1) // step + + src = torch.rand(valid_shape, dtype=dtype, device="cuda") + + ref_inp = to_reference(inp) + ref_src = to_reference(src) + ref_out = torch.slice_scatter( + ref_inp, dim=dim, src=ref_src, start=start, end=end, step=step + ) + + res_out = flag_gems.ops.slice_scatter( + inp, dim=dim, src=src, start=start, end=end, step=step + ) + + gems_assert_equal(res_out, ref_out) + + # TODO: failed at (200, 40999, 3) @pytest.mark.index_select @pytest.mark.parametrize("shape", REDUCTION_SHAPES)