diff --git a/tests/test_general_reduction_ops.py b/tests/test_general_reduction_ops.py index 8dcaa411..0f1ccdd2 100644 --- a/tests/test_general_reduction_ops.py +++ b/tests/test_general_reduction_ops.py @@ -127,7 +127,7 @@ def test_accuracy_max_without_dim(shape, dtype): @pytest.mark.parametrize("shape", REDUCTION_SHAPES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) def test_accuracy_max_without_dim_uncontiguous(shape, dtype): - inp = torch.randn(shape, dtype=dtype, device="cuda")[::2,] + inp = torch.randn(shape, dtype=dtype, device="cuda")[::2,::2] ref_inp = to_reference(inp) ref_out = torch.max(ref_inp)