[FRONTEND] Add _to_tensor for philox parameters (#3390) (#3396)
- Otherwise, the frontend crashes for non-tensor arguments.

Fixes #3390
lijinpei authored Mar 16, 2024
1 parent 4742a75 commit 1467514
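For context, here is a minimal reproducer sketch of the crash the message describes (hypothetical code, not part of this commit; assumes a CUDA device). Before this change, a seed passed as tl.constexpr reached philox() as a plain Python int, and the frontend crashed when tensor methods such as .to() were invoked on it:

import torch
import triton
import triton.language as tl

BLOCK = 1024


@triton.jit
def rand_kernel(X, N, seed: tl.constexpr):
    # seed is a constexpr, so philox() used to receive a plain int here
    pid = tl.program_id(0)
    offset = pid * BLOCK + tl.arange(0, BLOCK)
    rand = tl.rand(seed, offset)
    tl.store(X + offset, rand, mask=offset < N)


x = torch.empty(100000, dtype=torch.float32, device='cuda')
# Crashed in the frontend before this commit; works once philox()
# coerces its arguments with tl.to_tensor.
rand_kernel[(triton.cdiv(x.numel(), BLOCK), )](x, x.numel(), seed=42)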
Showing 3 changed files with 61 additions and 18 deletions.
69 changes: 51 additions & 18 deletions python/test/unit/language/test_random.py
@@ -114,11 +114,12 @@ def random_raw(self):
 # test generation of random uint32


-@pytest.mark.parametrize('size, seed, dtype', [(size, seed, dtype)
-                                               for size in ['10', '4,53', '400']
-                                               for seed in [0, 42, 124, 54, 0xffffffff, 0x0000000fcafeb0ba]
-                                               for dtype in ['int32', 'int64']])
-def test_randint(size, seed, device, dtype):
+@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed)
+                                                           for size in ['10', '4,53', '400']
+                                                           for seed in [0, 42, 124, 54, 0xffffffff, 0x0000000fcafeb0ba]
+                                                           for dtype in ['int32', 'int64']
+                                                           for const_seed in [True, False]])
+def test_randint(size, seed, device, dtype, const_seed):
     size = list(map(int, size.split(',')))
     torch_dtype = getattr(torch, dtype)
     numpy_dtype = getattr(np, f"u{dtype}")
@@ -131,11 +132,21 @@ def kernel(X, N, seed):
         rand = tl.randint(seed, offset)
         tl.store(X + offset, rand, mask=offset < N)

+    @triton.jit
+    def const_kernel(X, N, seed: tl.constexpr):
+        pid = tl.program_id(0).to(X.dtype.element_ty)
+        offset = pid * BLOCK + tl.arange(0, BLOCK)
+        rand = tl.randint(seed, offset)
+        tl.store(X + offset, rand, mask=offset < N)
+
     # triton result
     x = torch.empty(size, dtype=torch_dtype, device=device)
     N = x.numel()
     grid = (triton.cdiv(N, BLOCK), )
-    kernel[grid](x, N, seed)
+    if const_seed:
+        const_kernel[grid](x, N, seed=seed)
+    else:
+        kernel[grid](x, N, seed)
     out_tri = x.cpu().numpy().astype(numpy_dtype).flatten().tolist()
     # reference result
     gen = CustomPhilox4x(seed, config=config)
@@ -146,11 +157,12 @@ def kernel(X, N, seed):
 # test uniform PRNG


-@pytest.mark.parametrize('size, seed, dtype', [(size, seed, dtype)
-                                               for size in [100000]
-                                               for seed in [0, 42, 124, 54]
-                                               for dtype in ['int32', 'int64']])
-def test_rand(size, seed, dtype, device):
+@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed)
+                                                           for size in [100000]
+                                                           for seed in [0, 42, 124, 54]
+                                                           for dtype in ['int32', 'int64']
+                                                           for const_seed in [True, False]])
+def test_rand(size, seed, dtype, device, const_seed):

     @triton.jit
     def kernel(X, N, seed, dtype: tl.constexpr):
@@ -159,23 +171,34 @@ def kernel(X, N, seed, dtype: tl.constexpr):
         rand = tl.rand(seed, offset)
         tl.store(X + offset, rand, mask=offset < N)

+    @triton.jit
+    def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr):
+        pid = tl.program_id(0).to(dtype)
+        offset = pid * BLOCK + tl.arange(0, BLOCK)
+        rand = tl.rand(seed, offset)
+        tl.store(X + offset, rand, mask=offset < N)
+
     # triton result
     x = torch.empty(size, dtype=torch.float32, device=device)
     N = x.numel()
     grid = (triton.cdiv(N, BLOCK), )
-    kernel[grid](x, N, seed, dtype=getattr(tl, dtype))
+    if const_seed:
+        const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype))
+    else:
+        kernel[grid](x, N, seed, dtype=getattr(tl, dtype))
     assert all((x >= 0) & (x <= 1))
     assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01


 # test normal PRNG


-@pytest.mark.parametrize('size, seed, dtype', [(size, seed, dtype)
-                                               for size in [100000]
-                                               for seed in [0, 42, 124, 54]
-                                               for dtype in ['int32', 'int64']])
-def test_randn(size, seed, dtype, device):
+@pytest.mark.parametrize('size, seed, dtype, const_seed', [(size, seed, dtype, const_seed)
+                                                           for size in [100000]
+                                                           for seed in [0, 42, 124, 54]
+                                                           for dtype in ['int32', 'int64']
+                                                           for const_seed in [True, False]])
+def test_randn(size, seed, dtype, device, const_seed):

     @triton.jit
     def kernel(X, N, seed, dtype: tl.constexpr):
@@ -184,11 +207,21 @@ def kernel(X, N, seed, dtype: tl.constexpr):
         rand = tl.randn(seed, offset)
         tl.store(X + offset, rand, mask=offset < N)

+    @triton.jit
+    def const_kernel(X, N, seed: tl.constexpr, dtype: tl.constexpr):
+        pid = tl.program_id(0).to(dtype)
+        offset = pid * BLOCK + tl.arange(0, BLOCK)
+        rand = tl.randn(seed, offset)
+        tl.store(X + offset, rand, mask=offset < N)
+
     # triton result
     x = torch.empty(size, dtype=torch.float32, device=device)
     N = x.numel()
     grid = (triton.cdiv(N, BLOCK), )
-    kernel[grid](x, N, seed, dtype=getattr(tl, dtype))
+    if const_seed:
+        const_kernel[grid](x, N, seed=seed, dtype=getattr(tl, dtype))
+    else:
+        kernel[grid](x, N, seed, dtype=getattr(tl, dtype))
     assert abs(x.mean()) < 1e-2
     assert abs(x.std() - 1) < 1e-2

5 changes: 5 additions & 0 deletions python/triton/language/core.py
@@ -109,6 +109,11 @@ def is_builtin(fn) -> bool:
     return getattr(fn, TRITON_BUILTIN, False)


+@builtin
+def to_tensor(x, _builder=None):
+    return _to_tensor(x, _builder)
+
+
 def _to_tensor(x, builder):
     if isinstance(x, bool):
         return tensor(builder.get_int1(x), int1)
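The new tl.to_tensor builtin exposes the existing _to_tensor coercion to Triton programs. A small sketch of how it behaves inside a @triton.jit function (dtypes assumed from the _to_tensor branches visible above, e.g. a Python bool maps to int1):

import triton
import triton.language as tl


@triton.jit
def coerce_example():
    flag = tl.to_tensor(True)  # Python bool -> tl.int1 tensor
    n = tl.to_tensor(7)        # Python int  -> integer tensor
    # As tensors, both now support attributes and methods such as
    # .dtype and .to(), which plain Python scalars lack:
    n_wide = n.to(tl.uint64)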
5 changes: 5 additions & 0 deletions python/triton/language/random.py
@@ -44,6 +44,11 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):

 @jit
 def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
+    seed = tl.to_tensor(seed)
+    c0 = tl.to_tensor(c0)
+    c1 = tl.to_tensor(c1)
+    c2 = tl.to_tensor(c2)
+    c3 = tl.to_tensor(c3)
     seed = seed.to(tl.uint64)
     if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
         int_dtype = tl.uint32
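With these coercions in place, the philox-based generators (tl.rand, tl.randn, tl.randint) accept plain scalars for the seed and counter arguments instead of requiring tensors. A minimal sketch mirroring the new tests (hypothetical kernel; BLOCK is an assumed module-level constant, as in the test file):

import triton
import triton.language as tl

BLOCK = 1024


@triton.jit
def scalar_seed_kernel(X, N):
    offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # The literal 42 is a bare scalar; philox() now runs it through
    # tl.to_tensor before calling seed.to(tl.uint64) on it.
    rand = tl.rand(42, offset)
    tl.store(X + offset, rand, mask=offset < N)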
