[INTERPRETER] Correct None tensor check logic (#5049)
In interpreter mode, we cannot use `not tensor` to check whether `tensor` is None, because the interpreter evaluates the tensor directly: `not` queries the tensor's truth value (which raises for multi-element tensors and reflects the data for scalars) instead of testing for None. Check `mask is None` instead.

Also consolidates the test cases for `tl.store`.
Jokeren authored Nov 2, 2024
1 parent 56584c4 commit 530efbb
Showing 2 changed files with 18 additions and 31 deletions.
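
To see why `not tensor` misbehaves here: the interpreter backs `tl.tensor` with concrete array data, so `not mask` invokes array truthiness instead of a None test. A minimal sketch of the two failure modes, using NumPy directly (illustrative only, not Triton's actual classes):

```python
import numpy as np

# Failure mode 1: a multi-element mask refuses to collapse to one bool.
mask = np.array([True, False, True])
try:
    if not mask:  # the buggy check
        print("treated as 'no mask'")
except ValueError as err:
    print(err)  # "The truth value of an array with more than one element is ambiguous..."

# Failure mode 2: a scalar all-False mask is silently truthy under `not`,
# so a store that should write nothing would instead write every element.
mask = np.array(False)
if not mask:
    print("wrongly treated as 'no mask'")

# The fix tests only for the genuinely-missing-mask case:
mask = None
if mask is None:
    print("no mask supplied; emit an unmasked store")
```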
python/test/unit/language/test_core.py — 16 additions, 29 deletions

```diff
@@ -1740,47 +1740,34 @@ def kernel(X, Y, Z, N: tl.constexpr):
 
 @pytest.mark.interpreter
 @pytest.mark.parametrize("dtype_str", list(torch_dtypes))
+@pytest.mark.parametrize("constant_field", ["value", "mask"])
 @pytest.mark.parametrize("num_ctas", num_ctas_list)
-def test_store_constant(dtype_str, num_ctas, device):
+def test_store_constant(num_ctas, dtype_str, constant_field, device):
     check_type_supported(dtype_str, device)
     """Tests that boolean True is stored as 1"""
 
     @triton.jit
-    def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+    def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl.constexpr):
         offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        mask = offsets < n_elements
-        output = GENERATE_TEST_HERE
+        if CONSTANT_FIELD == "value":
+            value = 1
+            output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype)
+            mask = offsets < n_elements
+        elif CONSTANT_FIELD == "mask":
+            output = offsets < n_elements
+            mask = False
         tl.store(output_ptr + offsets, output, mask=mask)
 
-    triton_dtype_str = 'uint8' if dtype_str == 'bool' else dtype_str
-    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.zeros([BLOCK_SIZE], dtype=tl.{triton_dtype_str}) + 1'})
     block_size = 128
     ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device)
     output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device)
-    kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas)
-
-    assert torch.all(output == ref)
-
-
-@pytest.mark.interpreter
-@pytest.mark.parametrize("num_ctas", num_ctas_list)
-def test_store_constant_default_dtype(num_ctas, device):
-    """Tests that boolean True is stored as 1"""
-
-    @triton.jit
-    def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        mask = offsets < n_elements
-        value = 1
-        output = tl.full([BLOCK_SIZE], value=value, dtype=value.dtype)
-        tl.store(output_ptr + offsets, output, mask=mask)
-
-    block_size = 128
-    ref = torch.ones([block_size], dtype=getattr(torch, 'int32'), device=device)
-    output = torch.zeros([block_size], dtype=getattr(torch, 'int32'), device=device)
-    kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas)
-
-    assert torch.all(output == ref)
+    kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas, CONSTANT_FIELD=constant_field)
 
+    if constant_field == "value":
+        print(output, ref)
+        assert torch.all(output == ref)
+    else:
+        assert torch.all(output == 0)
 
 
 def test_load_store_same_ptr(device):
```
python/triton/language/semantic.py — 2 additions, 2 deletions

```diff
@@ -1253,7 +1253,7 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder):
     val = cast(val, elt_ty, builder)
 
     # Build IR
-    if not mask:
+    if mask is None:
         return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void)
     if not mask.type.scalar.is_bool():
         raise ValueError("Mask must have boolean scalar type")
@@ -1308,7 +1308,7 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
     if val is not None:
         val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
         val = cast(val, ptr.type.scalar.element_ty, builder)
-    if not mask:
+    if mask is None:
         mask_ir = builder.get_int1(True)
         mask_ty = tl.int1
     if ptr.type.is_block():
```
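
With `mask is None` in place, a literal `mask=False` store under the interpreter writes nothing rather than falling into the unmasked-store path. A hedged end-to-end sketch (the kernel and variable names are invented for illustration; assumes Triton is installed, and `TRITON_INTERPRET=1` must be set before `triton` is imported):

```python
import os
os.environ["TRITON_INTERPRET"] = "1"  # force interpreter mode; set before importing triton

import torch
import triton
import triton.language as tl


@triton.jit
def store_ones(out_ptr, BLOCK: tl.constexpr, USE_FALSE_MASK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    ones = tl.full([BLOCK], value=1, dtype=tl.int32)
    if USE_FALSE_MASK:
        tl.store(out_ptr + offs, ones, mask=False)  # all-False mask: nothing is written
    else:
        tl.store(out_ptr + offs, ones)  # mask is None: unmasked store


out = torch.zeros(16, dtype=torch.int32)
store_ones[(1, )](out, BLOCK=16, USE_FALSE_MASK=True)
assert int(out.sum()) == 0  # pre-fix, `not mask` misread False as "no mask" and wrote 1s
store_ones[(1, )](out, BLOCK=16, USE_FALSE_MASK=False)
assert int(out.sum()) == 16
```

The consolidated `test_store_constant` exercises the same two paths across all supported dtypes, e.g. via `TRITON_INTERPRET=1 pytest python/test/unit/language/test_core.py -k test_store_constant`.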
