From 8a195cd0d35cfaee5faee41488666c59d7f94676 Mon Sep 17 00:00:00 2001
From: "hanhaowen@sensetime.com" <hanhaowen@sensetime.com>
Date: Mon, 4 Nov 2024 07:43:59 +0000
Subject: [PATCH] max: make input contiguous before the flat reduction

---
 src/flag_gems/ops/max.py            |  2 ++
 tests/test_general_reduction_ops.py | 12 ++++++++++++
 2 files changed, 14 insertions(+)

diff --git a/src/flag_gems/ops/max.py b/src/flag_gems/ops/max.py
index 36c24454..31924b8b 100644
--- a/src/flag_gems/ops/max.py
+++ b/src/flag_gems/ops/max.py
@@ -93,6 +93,8 @@ def max_kernel(
 
 def max(inp):
     logging.debug("GEMS MAX")
+    if not inp.is_contiguous():
+        inp = inp.contiguous()
     M = inp.numel()
     block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
     mid_size = triton.cdiv(M, block_size)
diff --git a/tests/test_general_reduction_ops.py b/tests/test_general_reduction_ops.py
index 7d4f7584..0662589d 100644
--- a/tests/test_general_reduction_ops.py
+++ b/tests/test_general_reduction_ops.py
@@ -122,6 +122,18 @@ def test_accuracy_max_without_dim(shape, dtype):
     gems_assert_equal(res_out, ref_out)
 
 
+@pytest.mark.max
+@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_max_without_dim_uncontiguous(shape, dtype):
+    inp = torch.randn(shape, dtype=dtype, device="cuda")[::2,]
+    ref_inp = to_reference(inp)
+
+    ref_out = torch.max(ref_inp)
+    with flag_gems.use_gems():
+        res_out = torch.max(inp)
+
+    gems_assert_equal(res_out, ref_out)
 
 # TODO: failed at (200, 40999, 3), while successed at this shape in mean_dim
 @pytest.mark.max
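
Reviewer note: a minimal sketch of why the contiguity guard matters. The GEMS
max kernel treats the input as a flat buffer, reading element i from the data
pointer for i in range(inp.numel()), so a strided view such as `x[::2,]` would
make it pick up storage slots that belong to skipped rows. The helper below
(`flat_max_like_the_kernel` is a hypothetical stand-in written in plain
PyTorch, not the real Triton kernel) emulates that access pattern under those
assumptions:

```python
import torch

def flat_max_like_the_kernel(inp: torch.Tensor) -> torch.Tensor:
    # Emulate the kernel's access pattern: read inp.numel() elements
    # linearly from the start of the tensor's underlying storage.
    flat = torch.as_strided(inp, (inp.numel(),), (1,))
    return flat.max()

x = torch.arange(8.0).reshape(4, 2)  # contiguous: storage order == logical order
v = x[::2, ]                         # rows 0 and 2 of x; strides (4, 1), non-contiguous

print(torch.max(v))                              # tensor(5.)  correct answer
print(flat_max_like_the_kernel(v))               # tensor(3.)  wrong: reads storage slots 0..3
print(flat_max_like_the_kernel(v.contiguous()))  # tensor(5.)  dense copy restores correctness
```

Calling `.contiguous()` materializes a dense copy whose storage order matches
the logical order, which is why the guard in `max()` fixes the test case. Note
that `Tensor.contiguous()` already returns the tensor itself when it is
contiguous, so the `is_contiguous()` check only skips a method call, not a
copy; `inp = inp.contiguous()` alone would behave the same.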