Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
[email protected] committed Nov 4, 2024
1 parent 7ecdd12 commit 8a195cd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/flag_gems/ops/max.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def max_kernel(

def max(inp):
logging.debug("GEMS MAX")
if not inp.is_contiguous():
inp = inp.contiguous()
M = inp.numel()
block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
mid_size = triton.cdiv(M, block_size)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_general_reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ def test_accuracy_max_without_dim(shape, dtype):

gems_assert_equal(res_out, ref_out)

@pytest.mark.max
@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_max_without_dim_uncontiguous(shape, dtype):
    """Full-tensor max must match the reference even for non-contiguous input."""
    # Slicing with step 2 along dim 0 yields a strided (non-contiguous) view.
    base = torch.randn(shape, dtype=dtype, device="cuda")
    inp = base[::2]
    ref_inp = to_reference(inp)

    ref_out = torch.max(ref_inp)
    with flag_gems.use_gems():
        res_out = torch.max(inp)

    gems_assert_equal(res_out, ref_out)

# TODO: failed at (200, 40999, 3), while it succeeded at this shape in mean_dim
@pytest.mark.max
Expand Down

0 comments on commit 8a195cd

Please sign in to comment.