Skip to content

Commit

Permalink
[CACHE] Verify that when preloading a kernel its name matches what we…
Browse files Browse the repository at this point in the history
… have in specialization_data (#3395)

Adding helpful error message on mismatching name during preloading a
kernel.
  • Loading branch information
pawelszczerbuk authored Mar 15, 2024
1 parent 55bb887 commit f08bdc1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
10 changes: 10 additions & 0 deletions python/test/unit/runtime/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,12 @@ def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr):
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr):
idx = tl.arange(0, N)
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx))

device = torch.cuda.current_device()

# get the serialized specialization data
Expand Down Expand Up @@ -343,3 +349,7 @@ def inc_counter(*args, **kwargs):
assert counter == 0
assert len(kernel_add.cache[device]) == 1
assert final_kernel.hash == hash

# test that we can't preload a mismatched kernel
with pytest.raises(RuntimeError, match="Specialization data is for"):
kernel_sub.preload(specialization_data)
11 changes: 7 additions & 4 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@ def __getitem__(self, grid) -> T:
# return cast(T, functools.partial(cast(Callable, self.run), grid=grid))


def serialize_specialization_data(signature, constants, attrs, options, key):
def serialize_specialization_data(name, signature, constants, attrs, options, key):
constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()}
import json
obj = {
'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': options.__dict__, 'key':
key
'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options':
options.__dict__, 'key': key
}
serialized_obj = json.dumps(obj)
return serialized_obj
Expand Down Expand Up @@ -329,7 +329,7 @@ def __init__(self, module, name, jit_function):
self.jit_function = jit_function
pass

specialization_data = serialize_specialization_data(signature, constants, configs[0], options, key)
specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key)

kwargs = dict(
signature=signature,
Expand Down Expand Up @@ -508,6 +508,9 @@ def preload(self, specialization_data):
import triton.language as tl
device = driver.active.get_current_device()
deserialized_obj = json.loads(specialization_data)
if deserialized_obj['name'] != self.fn.__name__:
raise RuntimeError(
f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}")
constants = {
int(key): tl.dtype(value) if tl.dtype.is_dtype(value) else value
for key, value in deserialized_obj['constants'].items()
Expand Down

0 comments on commit f08bdc1

Please sign in to comment.