Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[INTERPRETER] Fix argument passing for internal parameters in function declarations #5169

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5581,7 +5581,7 @@ def matmul_kernel( #
stride_cm, stride_cn, #
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
low_precision_acc: tl.constexpr, #
num_pipeline_stages: tl.constexpr = 3 #
num_stages: tl.constexpr = 3 #
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
Expand All @@ -5593,7 +5593,7 @@ def matmul_kernel( #
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages):
for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages):
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc)
Expand Down Expand Up @@ -5632,7 +5632,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None
h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0),
C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps,
num_pipeline_stages=num_stages)
num_stages=num_stages)
torch_a = torch.from_numpy(A).to(device=device)
th_a = f8_to_f16(torch_a, in_type_str)
torch_b = torch.from_numpy(B).to(device=device)
Expand Down Expand Up @@ -5824,7 +5824,7 @@ def test_tl_range(device):
pgm = matmul_kernel[
1,
](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N,
BLOCK_K, 0, num_pipeline_stages=5)
BLOCK_K, 0, num_stages=5)
ref_out = torch.matmul(a, b).to(torch.float32)
if is_interpreter():
# GPU invokes tensor core for float16 matmul, which is not supported in interpreter.
Expand All @@ -5850,8 +5850,8 @@ def maxnreg_noinline2(X):
tl.store(X, 0)


@pytest.mark.interpreter
def test_maxnreg(device):
assert not is_interpreter(), "this test won't work with the interpreter"
if not is_cuda():
pytest.skip('maxnreg only works on CUDA')

Expand All @@ -5865,14 +5865,15 @@ def kernel(X):
X = torch.empty(1, dtype=torch.int32, device=device)
k = kernel[(1, )](X, maxnreg=42)

# Ensure that .maxnreg is set on the kernel function (marked with .entry)
# and not on either of the noinline functions (marked with .func).
try:
assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
except AssertionError:
print("Failing ptx:\n", k.asm["ptx"])
raise
if not is_interpreter():
# Ensure that .maxnreg is set on the kernel function (marked with .entry)
# and not on either of the noinline functions (marked with .func).
try:
assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
except AssertionError:
print("Failing ptx:\n", k.asm["ptx"])
raise


@pytest.mark.interpreter
Expand Down
7 changes: 5 additions & 2 deletions python/triton/runtime/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,10 +1077,13 @@ def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data)

def __call__(self, *args_dev, **kwargs):
# removes reserved keywords from kwargs
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
if kwargs.pop("warmup", False):
return
# Removes not used reserved keywords from kwargs
# Triton doesn't support keyword-only, variable positional or variable keyword arguments
# It's safe to inspect only positional or keyword arguments (i.e., argspec.args)
argspec = inspect.getfullargspec(self.fn)
kwargs = {k: v for k, v in kwargs.items() if k in argspec.args}
# copy arguments to the host
args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
# remaps core language functions to interpreted ones
Expand Down
Loading