diff --git a/.github/workflows/python-test.yaml b/.github/workflows/python-test.yaml index bd0e5e2f..d140206b 100644 --- a/.github/workflows/python-test.yaml +++ b/.github/workflows/python-test.yaml @@ -29,6 +29,8 @@ jobs: "CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_unary_pointwise_ops.py &" "CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_pointwise_type_promotion.py &" "CUDA_VISIBLE_DEVICES=1 pytest -s tests/test_binary_pointwise_ops.py &" + "CUDA_VISIBLE_DEVICES=1 pytest -s tests/test_tensor_constructor_ops.py &" + "CUDA_VISIBLE_DEVICES=1 pytest -s tests/test_distribution_ops.py &" "CUDA_VISIBLE_DEVICES=2 pytest -s tests/test_blas_ops.py &" "CUDA_VISIBLE_DEVICES=3 pytest -s tests/test_reduction_ops.py &" "CUDA_VISIBLE_DEVICES=4 pytest -s tests/test_special_ops.py &" diff --git a/benchmark/test_distribution_perf.py b/benchmark/test_distribution_perf.py new file mode 100644 index 00000000..3e7cd605 --- /dev/null +++ b/benchmark/test_distribution_perf.py @@ -0,0 +1,96 @@ +import torch + +from .performance_utils import ( + FLOAT_DTYPES, + POINTWISE_BATCH, + SIZES, + Benchmark, + unary_arg, +) + + +def test_perf_rand(): + def rand_kwargs(dtype, batch, size): + return {"size": (batch, size), "dtype": dtype, "device": "cuda"} + + bench = Benchmark( + op_name="rand", + torch_op=torch.rand, + arg_func=None, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + kwargs_func=rand_kwargs, + ) + bench.run() + + +def test_perf_randn(): + def randn_kwargs(dtype, batch, size): + return {"size": (batch, size), "dtype": dtype, "device": "cuda"} + + bench = Benchmark( + op_name="randn", + torch_op=torch.randn, + arg_func=None, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + kwargs_func=randn_kwargs, + ) + bench.run() + + +def test_perf_rand_like(): + bench = Benchmark( + op_name="rand_like", + torch_op=torch.rand_like, + arg_func=unary_arg, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +def test_perf_normal(): + def normal_arg(dtype, batch, size): + loc = torch.full(size=(size, batch), fill_value=3.0, dtype=dtype, device="cuda") + scale = torch.full( + size=(size, batch), fill_value=10.0, dtype=dtype, device="cuda" + ) + return loc, scale + + bench = Benchmark( + op_name="distributions.normal.Normal", + torch_op=torch.distributions.normal.Normal, + arg_func=normal_arg, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +def test_perf_uniform(): + bench = Benchmark( + op_name="uniform_", + torch_op=torch.Tensor.uniform_, + arg_func=unary_arg, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +def test_perf_exponential_(): + bench = Benchmark( + op_name="exponential_", + torch_op=torch.Tensor.exponential_, + arg_func=unary_arg, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() diff --git a/benchmark/test_special_perf.py b/benchmark/test_special_perf.py index fb175dc3..73b760aa 100644 --- a/benchmark/test_special_perf.py +++ b/benchmark/test_special_perf.py @@ -1,86 +1,58 @@ import torch -from .performance_utils import ( - FLOAT_DTYPES, - POINTWISE_BATCH, - SIZES, - Benchmark, - unary_arg, -) +from .performance_utils import POINTWISE_BATCH, SIZES, Benchmark -def test_perf_rand(): - def rand_kwargs(dtype, batch, size): - return {"size": (batch, size), "dtype": dtype, "device": "cuda"} - - bench = Benchmark( - op_name="rand", - torch_op=torch.rand, - arg_func=None, - dtypes=FLOAT_DTYPES, - batch=POINTWISE_BATCH, - sizes=SIZES, - 
kwargs_func=rand_kwargs, - ) - bench.run() - - -def test_perf_randn(): - def randn_kwargs(dtype, batch, size): - return {"size": (batch, size), "dtype": dtype, "device": "cuda"} +def test_perf_embedding(): + def embedding_kwargs(dtype, batch, size): + input = torch.randint(0, batch, (batch,), device="cuda") + weight = torch.randn((batch + 1, size), device="cuda", dtype=dtype) + return {"input": input, "weight": weight} bench = Benchmark( - op_name="randn", - torch_op=torch.randn, + op_name="embedding", + torch_op=torch.nn.functional.embedding, arg_func=None, - dtypes=FLOAT_DTYPES, + dtypes=[ + torch.float32, + torch.float16, + ], # Note(Zhengzekang): triton do not support bfloat16 atomic add which is used in embedding grad. batch=POINTWISE_BATCH, sizes=SIZES, - kwargs_func=randn_kwargs, + kwargs_func=embedding_kwargs, ) bench.run() -def test_perf_rand_like(): - bench = Benchmark( - op_name="rand_like", - torch_op=torch.rand_like, - arg_func=unary_arg, - dtypes=FLOAT_DTYPES, - batch=POINTWISE_BATCH, - sizes=SIZES, - ) - bench.run() - +def test_perf_resolve_neg(): + def resolve_neg_arg(dtype, batch, size): + x = torch.randn(size=(batch, size), dtype=dtype, device="cuda") + y = x.conj() + z = y.imag + return (z,) -def test_perf_exponential_(): bench = Benchmark( - op_name="exponential_", - torch_op=torch.Tensor.exponential_, - arg_func=unary_arg, - dtypes=FLOAT_DTYPES, + op_name="resolve_neg", + torch_op=torch.resolve_neg, + arg_func=resolve_neg_arg, + dtypes=[torch.cfloat], batch=POINTWISE_BATCH, sizes=SIZES, ) bench.run() -def test_perf_embedding(): - def embedding_kwargs(dtype, batch, size): - input = torch.randint(0, batch, (batch,), device="cuda") - weight = torch.randn((batch + 1, size), device="cuda", dtype=dtype) - return {"input": input, "weight": weight} +def test_perf_resolve_conj(): + def resolve_conj_arg(dtype, batch, size): + x = torch.randn(size=(size, batch), dtype=dtype, device="cuda") + return (x.conj(),) bench = Benchmark( - op_name="embedding", - torch_op=torch.nn.functional.embedding, - arg_func=None, - dtypes=[ - torch.float32, - torch.float16, - ], # Note(Zhengzekang): triton do not support bfloat16 atomic add which is used in embedding grad. 
+ op_name="resolve_conj", + torch_op=torch.resolve_conj, + arg_func=resolve_conj_arg, + dtypes=[torch.cfloat], batch=POINTWISE_BATCH, sizes=SIZES, - kwargs_func=embedding_kwargs, ) bench.run() diff --git a/benchmark/test_tensor_constructor_perf.py b/benchmark/test_tensor_constructor_perf.py new file mode 100644 index 00000000..e033c154 --- /dev/null +++ b/benchmark/test_tensor_constructor_perf.py @@ -0,0 +1,105 @@ +import torch + +from .performance_utils import ( + FLOAT_DTYPES, + POINTWISE_BATCH, + SIZES, + Benchmark, + unary_arg, +) + + +def test_perf_ones(): + def ones_kwargs(dtype, batch, size): + return {"size": (batch, size), "dtype": dtype, "device": "cuda"} + + bench = Benchmark( + op_name="ones", + torch_op=torch.ones, + arg_func=None, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + kwargs_func=ones_kwargs, + ) + bench.run() + + +def test_perf_zeros(): + def zeros_kwargs(dtype, batch, size): + return {"size": (batch, size), "dtype": dtype, "device": "cuda"} + + bench = Benchmark( + op_name="zeros", + torch_op=torch.zeros, + arg_func=None, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + kwargs_func=zeros_kwargs, + ) + bench.run() + + +def test_perf_full(): + def full_kwargs(dtype, batch, size): + return { + "size": (batch, size), + "fill_value": 3.1415926, + "dtype": dtype, + "device": "cuda", + } + + bench = Benchmark( + op_name="full", + torch_op=torch.full, + arg_func=None, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + kwargs_func=full_kwargs, + ) + bench.run() + + +def test_perf_ones_like(): + bench = Benchmark( + op_name="ones_like", + torch_op=torch.ones_like, + arg_func=unary_arg, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +def test_perf_zeros_like(): + bench = Benchmark( + op_name="zeros_like", + torch_op=torch.zeros_like, + arg_func=unary_arg, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + ) + bench.run() + + +def test_perf_full_like(): + def full_kwargs(dtype, batch, size): + return { + "input": torch.randn([batch, size], dtype=dtype, device="cuda"), + "fill_value": 3.1415926, + } + + bench = Benchmark( + op_name="full_like", + torch_op=torch.full_like, + arg_func=None, + dtypes=FLOAT_DTYPES, + batch=POINTWISE_BATCH, + sizes=SIZES, + kwargs_func=full_kwargs, + ) + bench.run() diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index 5996ff94..5e15f57a 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -50,6 +50,19 @@ def enable(lib=aten_lib): lib.impl("rand", rand, "CUDA") lib.impl("randn", randn, "CUDA") lib.impl("rand_like", rand_like, "CUDA") + lib.impl("zeros", zeros, "CUDA") + lib.impl("ones", ones, "CUDA") + lib.impl("full", full, "CUDA") + lib.impl("zeros_like", zeros_like, "CUDA") + lib.impl("ones_like", ones_like, "CUDA") + lib.impl("full_like", full_like, "CUDA") + lib.impl("resolve_neg", resolve_neg, "CUDA") + lib.impl("resolve_conj", resolve_conj, "CUDA") + lib.impl("normal.Tensor_float", normal_tensor_float, "CUDA") + lib.impl("normal.float_Tensor", normal_float_tensor, "CUDA") + lib.impl("normal.Tensor_Tensor", normal_tensor_tensor, "CUDA") + lib.impl("normal.float_float", normal_float_float, "CUDA") + lib.impl("uniform_", uniform_, "CUDA") lib.impl("mean", mean, "CUDA") lib.impl("mean.dim", mean_dim, "CUDA") lib.impl("mm", mm, "CUDA") diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index c03dde8a..3df8ac63 100644 --- a/src/flag_gems/ops/__init__.py +++ 
b/src/flag_gems/ops/__init__.py @@ -25,6 +25,8 @@ from .exp import exp from .exponential_ import exponential_ from .flip import flip +from .full import full +from .full_like import full_like from .ge import ge, ge_scalar from .gelu import gelu from .groupnorm import group_norm @@ -45,6 +47,14 @@ from .mv import mv from .ne import ne, ne_scalar from .neg import neg +from .normal import ( + normal_float_float, + normal_float_tensor, + normal_tensor_float, + normal_tensor_tensor, +) +from .ones import ones +from .ones_like import ones_like from .outer import outer from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor from .prod import prod, prod_dim @@ -53,6 +63,8 @@ from .randn import randn from .reciprocal import reciprocal from .relu import relu +from .resolve_conj import resolve_conj +from .resolve_neg import resolve_neg from .rms_norm import rms_norm from .rsqrt import rsqrt from .sigmoid import sigmoid @@ -63,9 +75,12 @@ from .sum import sum, sum_dim from .tanh import tanh from .triu import triu +from .uniform import uniform_ from .var_mean import var_mean from .vector_norm import vector_norm from .where import where_scalar_other, where_scalar_self, where_self +from .zeros import zeros +from .zeros_like import zeros_like __all__ = [ "all", @@ -91,6 +106,9 @@ "cos", "cumsum", "div", + "zeros", + "ones", + "full", "native_dropout", "erf", "embedding", @@ -99,6 +117,9 @@ "exp", "exponential_", "flip", + "ones_like", + "full_like", + "zeros_like", "ge", "ge_scalar", "gelu", @@ -121,6 +142,13 @@ "mul", "rand", "randn", + "resolve_neg", + "resolve_conj", + "normal_tensor_float", + "normal_float_tensor", + "normal_tensor_tensor", + "normal_float_float", + "uniform_", "rand_like", "mv", "ne", diff --git a/src/flag_gems/ops/full.py b/src/flag_gems/ops/full.py new file mode 100644 index 00000000..d9375f34 --- /dev/null +++ b/src/flag_gems/ops/full.py @@ -0,0 +1,36 @@ +import logging + +import torch +import triton +import triton.language as tl + +from flag_gems.utils.shape_utils import volume + + +@triton.jit(do_not_specialize=["fill_value"]) +def full_kernel( + output_ptr, + n_elements, + fill_value, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + tl.store(output_ptr + offsets, fill_value, mask=mask) + + +def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None): + logging.debug("GEMS FULL") + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device("cuda") + + out = torch.empty(size, device=device, dtype=dtype) + N = volume(size) + grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) + with torch.cuda.device(device): + full_kernel[grid_fn](out, N, fill_value, BLOCK_SIZE=1024) + return out diff --git a/src/flag_gems/ops/full_like.py b/src/flag_gems/ops/full_like.py new file mode 100644 index 00000000..819cdad1 --- /dev/null +++ b/src/flag_gems/ops/full_like.py @@ -0,0 +1,29 @@ +import logging + +import torch +import triton + +from .full import full_kernel + + +def full_like( + x, + fill_value, + *, + dtype=None, + layout=None, + device=None, + pin_memory=None, + memory_format=None, +): + logging.debug("GEMS FULL_LIKE") + if device is None: + device = x.device + if dtype is None: + dtype = x.dtype + out = torch.empty_like(x, device=device, dtype=dtype) + N = x.numel() + grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) + with torch.cuda.device(x.device): + 
full_kernel[grid_fn](out, N, fill_value, BLOCK_SIZE=1024) + return out diff --git a/src/flag_gems/ops/normal.py b/src/flag_gems/ops/normal.py new file mode 100644 index 00000000..f7e2e9cc --- /dev/null +++ b/src/flag_gems/ops/normal.py @@ -0,0 +1,80 @@ +import logging + +import torch +import triton + +from ..utils import pointwise_dynamic +from ..utils.random_utils import philox_cuda_seed_offset +from ..utils.shape_utils import broadcast_shapes, volume +from .randn import randn_kernel + +UNROLL = 4 + + +@pointwise_dynamic( + is_tensor=[True, True, True], promotion_methods=[(0, 1, 2, "DEFAULT")] +) +@triton.jit +def transform_func_tensor_tensor(val, std, mean): + return val * std + mean + + +@pointwise_dynamic( + is_tensor=[True, True, False], promotion_methods=[(0, 1, 2, "DEFAULT")] +) +@triton.jit +def transform_func_tensor_float(val, std, mean): + return val * std + mean + + +@pointwise_dynamic( + is_tensor=[True, False, True], promotion_methods=[(0, 1, 2, "DEFAULT")] +) +@triton.jit +def transform_func_float_tensor(val, std, mean): + return val * std + mean + + +@pointwise_dynamic( + is_tensor=[True, False, False], promotion_methods=[(0, 1, 2, "DEFAULT")] +) +@triton.jit +def transform_func_float_float(val, std, mean): + return val * std + mean + + +def normal_distribution(mean, std, *, generator=None): + shape = broadcast_shapes([mean.shape, std.shape]) + out = torch.empty(shape, device=mean.device, dtype=torch.float32) + N = volume(shape) + grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),) + + increment = triton.cdiv(N, UNROLL) + philox_seed, philox_offset = philox_cuda_seed_offset(increment) + with torch.cuda.device(mean.device): + randn_kernel[grid_fn](out, N, philox_seed, philox_offset) + return out + + +def normal_tensor_tensor(mean, std, *, generator=None): + logging.debug("GEMS NORMAL_TENSOR_TENSOR") + out = normal_distribution(mean, std) + return transform_func_tensor_tensor(out, std, mean) + + +def normal_tensor_float(mean, std, *, generator=None): + logging.debug("GEMS NORMAL_TENSOR_FLOAT") + out = normal_distribution(mean, std) + return transform_func_tensor_float(out, std, mean) + + +def normal_float_tensor(mean, std, *, generator=None): + logging.debug("GEMS NORMAL_FLOAT_TENSOR") + out = normal_distribution(mean, std) + return transform_func_float_tensor(out, std, mean) + + +def normal_float_float(mean, std, *, generator=None): + logging.debug("GEMS NORMAL_FLOAT_FLOAT") + out = normal_distribution(mean, std) + return transform_func_float_float(out, std, mean) diff --git a/src/flag_gems/ops/ones.py b/src/flag_gems/ops/ones.py new file mode 100644 index 00000000..96a33cdf --- /dev/null +++ b/src/flag_gems/ops/ones.py @@ -0,0 +1,35 @@ +import logging + +import torch +import triton +import triton.language as tl + +from flag_gems.utils.shape_utils import volume + + +@triton.jit +def ones_kernel( + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + tl.store(output_ptr + offsets, 1.0, mask=mask) + + +def ones(size, *, dtype=None, layout=None, device=None, pin_memory=None): + logging.debug("GEMS ONES") + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device("cuda") + + out = torch.empty(size, device=device, dtype=dtype) + N = volume(size) + grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) + with torch.cuda.device(device): + ones_kernel[grid_fn](out, N, 
BLOCK_SIZE=1024) + return out diff --git a/src/flag_gems/ops/ones_like.py b/src/flag_gems/ops/ones_like.py new file mode 100644 index 00000000..88c894ea --- /dev/null +++ b/src/flag_gems/ops/ones_like.py @@ -0,0 +1,22 @@ +import logging + +import torch +import triton + +from .ones import ones_kernel + + +def ones_like( + x, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None +): + logging.debug("GEMS ONES_LIKE") + if device is None: + device = x.device + if dtype is None: + dtype = x.dtype + out = torch.empty_like(x, device=device, dtype=dtype) + N = x.numel() + grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) + with torch.cuda.device(x.device): + ones_kernel[grid_fn](out, N, BLOCK_SIZE=1024) + return out diff --git a/src/flag_gems/ops/resolve_conj.py b/src/flag_gems/ops/resolve_conj.py new file mode 100644 index 00000000..5232f4c8 --- /dev/null +++ b/src/flag_gems/ops/resolve_conj.py @@ -0,0 +1,8 @@ +import logging + +import torch + + +def resolve_conj(A: torch.Tensor): + logging.debug("GEMS RESOLVE_CONJ") + return torch.complex(A.real, A.imag.neg()) if A.is_conj() else A diff --git a/src/flag_gems/ops/resolve_neg.py b/src/flag_gems/ops/resolve_neg.py new file mode 100644 index 00000000..1c090632 --- /dev/null +++ b/src/flag_gems/ops/resolve_neg.py @@ -0,0 +1,10 @@ +import logging + +import torch + +from flag_gems.ops.neg import neg_func + + +def resolve_neg(A: torch.Tensor): + logging.debug("GEMS RESOLVE_NEG") + return neg_func(A) if A.is_neg() else A diff --git a/src/flag_gems/ops/uniform.py b/src/flag_gems/ops/uniform.py new file mode 100644 index 00000000..0ade2a96 --- /dev/null +++ b/src/flag_gems/ops/uniform.py @@ -0,0 +1,77 @@ +import logging + +import torch +import triton +import triton.language as tl + +from flag_gems.utils.random_utils import philox_cuda_seed_offset, uint_to_uniform_float +from flag_gems.utils.shape_utils import volume + + +def heur_block(args): + if args["N"] <= 512: + return 512 + else: + return 1024 + + +def heur_num_warps(args): + if args["N"] <= 512: + return 4 + elif args["N"] <= 1024: + return 8 + else: + return 16 + + +@triton.heuristics( + { + "BLOCK": heur_block, + "num_warps": heur_num_warps, + } +) +@triton.jit(do_not_specialize=["philox_seed", "philox_offset"]) +def uniform_kernel( + out_ptr, + N, + philox_seed, + philox_offset, + from_, + to, + BLOCK: tl.constexpr, +): + philox_seed = philox_seed.to(tl.int64) + philox_offset = philox_offset.to(tl.int64) + c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32) + c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32) + i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + c0 += i4 + _O = c0 * 0 + r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O) + r0 = uint_to_uniform_float(r0) * (to - from_) + from_ + r1 = uint_to_uniform_float(r1) * (to - from_) + from_ + r2 = uint_to_uniform_float(r2) * (to - from_) + from_ + r3 = uint_to_uniform_float(r3) * (to - from_) + from_ + off_0 = tl.program_id(0) * BLOCK * 4 + tl.arange(0, BLOCK) + off_1 = off_0 + BLOCK + off_2 = off_1 + BLOCK + off_3 = off_2 + BLOCK + tl.store(out_ptr + off_0, r0, mask=off_0 < N, eviction_policy="evict_first") + tl.store(out_ptr + off_1, r1, mask=off_1 < N, eviction_policy="evict_first") + tl.store(out_ptr + off_2, r2, mask=off_2 < N, eviction_policy="evict_first") + tl.store(out_ptr + off_3, r3, mask=off_3 < N, eviction_policy="evict_first") + + +UNROLL = 4 + + +def uniform_(self, from_=0.0, to=1.0, *, generator=None): + logging.debug("GEMS UNIFORM") + N = volume(self.shape) + grid_fn = lambda meta: 
(triton.cdiv(N, meta["BLOCK"] * UNROLL),)
+
+    increment = triton.cdiv(N, UNROLL)
+    philox_seed, philox_offset = philox_cuda_seed_offset(increment)
+    with torch.cuda.device(self.device):
+        uniform_kernel[grid_fn](self, N, philox_seed, philox_offset, from_, to)
+    return self
diff --git a/src/flag_gems/ops/zeros.py b/src/flag_gems/ops/zeros.py
new file mode 100644
index 00000000..3063a6c0
--- /dev/null
+++ b/src/flag_gems/ops/zeros.py
@@ -0,0 +1,35 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.utils.shape_utils import volume
+
+
+@triton.jit
+def zeros_kernel(
+    output_ptr,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    tl.store(output_ptr + offsets, 0.0, mask=mask)
+
+
+def zeros(size, *, dtype=None, layout=None, device=None, pin_memory=None):
+    logging.debug("GEMS ZEROS")
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+    if device is None:
+        device = torch.device("cuda")
+
+    out = torch.empty(size, device=device, dtype=dtype)
+    N = volume(size)
+    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
+    with torch.cuda.device(device):
+        zeros_kernel[grid_fn](out, N, BLOCK_SIZE=1024)
+    return out
diff --git a/src/flag_gems/ops/zeros_like.py b/src/flag_gems/ops/zeros_like.py
new file mode 100644
index 00000000..264d7874
--- /dev/null
+++ b/src/flag_gems/ops/zeros_like.py
@@ -0,0 +1,22 @@
+import logging
+
+import torch
+import triton
+
+from .zeros import zeros_kernel
+
+
+def zeros_like(
+    x, *, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
+):
+    logging.debug("GEMS ZEROS_LIKE")
+    if device is None:
+        device = x.device
+    if dtype is None:
+        dtype = x.dtype
+    out = torch.empty_like(x, device=device, dtype=dtype)
+    N = x.numel()
+    grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
+    with torch.cuda.device(x.device):
+        zeros_kernel[grid_fn](out, N, BLOCK_SIZE=1024)
+    return out
diff --git a/tests/accuracy_utils.py b/tests/accuracy_utils.py
index 05239ed7..9b33b9d7 100644
--- a/tests/accuracy_utils.py
+++ b/tests/accuracy_utils.py
@@ -14,6 +14,7 @@
 }
 
 POINTWISE_SHAPES = [(1024, 1024), (16, 1024, 256), (16, 128, 64, 64), (20, 320, 15)]
+DISTRIBUTION_SHAPES = [(20, 320, 15)]
 REDUCTION_SHAPES = [(4096, 256 * i) for i in range(1, 10, 2)]
 MNK_SHAPES = [15, 160, 1024]
 
diff --git a/tests/ks_tests.py b/tests/ks_tests.py
new file mode 100644
index 00000000..ba91565a
--- /dev/null
+++ b/tests/ks_tests.py
@@ -0,0 +1,72 @@
+import numpy as np
+import pytest
+import scipy
+import torch
+
+import flag_gems
+
+from .accuracy_utils import DISTRIBUTION_SHAPES, FLOAT_DTYPES
+
+# The Kolmogorov-Smirnov test (K-S test or KS test) is performed on the
+# distribution operators. 
Because these tests are randomized, they are not run in CI.
+
+
+@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_normal_pvalue(shape, dtype):
+    loc = torch.full(size=shape, fill_value=3.0, dtype=dtype, device="cuda")
+    scale = torch.full(size=shape, fill_value=10.0, dtype=dtype, device="cuda")
+    with flag_gems.use_gems():
+        res_out = torch.distributions.normal.Normal(loc, scale).sample()
+    pvalue = scipy.stats.kstest(
+        res_out.cpu().numpy().flatten(),
+        lambda x: scipy.stats.norm.cdf(x, loc=3.0, scale=10.0),
+    ).pvalue
+    assert pvalue > 0.05
+
+
+@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_accuracy_uniform_pvalue(shape, dtype):
+    x = torch.randn(size=shape, dtype=dtype, device="cuda")
+    with flag_gems.use_gems():
+        x.uniform_(-3, 3)
+    pvalue = scipy.stats.kstest(
+        x.cpu().numpy().flatten(),
+        lambda x: scipy.stats.uniform.cdf(x, loc=-3.0, scale=6.0),
+    ).pvalue
+    assert pvalue > 0.05
+
+
+@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES)
+@pytest.mark.parametrize("dtype", (torch.float32,))
+@pytest.mark.parametrize("lambd", (0.01, 0.5, 100.0))
+def test_accuracy_exponential_pvalue(shape, dtype, lambd):
+    x = torch.empty(size=shape, dtype=dtype, device="cuda")
+    with flag_gems.use_gems():
+        x.exponential_(lambd=lambd)
+    expo_cdf = lambda x: np.where(x < 0, 0, 1.0 - np.exp(-lambd * x))
+    pvalue = scipy.stats.kstest(x.cpu().numpy().flatten(), expo_cdf).pvalue
+    assert pvalue > 0.05
+
+
+@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_accuracy_rand_pvalue(shape, dtype):
+    with flag_gems.use_gems():
+        res_out = torch.rand(shape, dtype=dtype, device="cuda")
+    pvalue = scipy.stats.kstest(
+        res_out.cpu().numpy().flatten(), lambda x: scipy.stats.uniform.cdf(x)
+    ).pvalue
+    assert pvalue > 0.05
+
+
+@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_accuracy_randn_pvalue(shape, dtype):
+    with flag_gems.use_gems():
+        res_out = torch.randn(shape, dtype=dtype, device="cuda")
+    pvalue = scipy.stats.kstest(
+        res_out.cpu().numpy().flatten(), lambda x: scipy.stats.norm.cdf(x)
+    ).pvalue
+    assert pvalue > 0.05
diff --git a/tests/test_distribution_ops.py b/tests/test_distribution_ops.py
new file mode 100644
index 00000000..4f5e9745
--- /dev/null
+++ b/tests/test_distribution_ops.py
@@ -0,0 +1,38 @@
+import pytest
+import torch
+
+import flag_gems
+
+from .accuracy_utils import DISTRIBUTION_SHAPES, FLOAT_DTYPES
+
+
+@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_normal(shape, dtype):
+    loc = torch.full(size=shape, fill_value=3.0, dtype=dtype, device="cuda")
+    scale = torch.full(size=shape, fill_value=10.0, dtype=dtype, device="cuda")
+    with flag_gems.use_gems():
+        res_out = torch.distributions.normal.Normal(loc, scale).sample()
+    mean = torch.mean(res_out)
+    std = torch.std(res_out)
+    assert torch.abs(mean - 3.0) < 0.1
+    assert torch.abs(std - 10.0) < 0.1
+
+
+@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+def test_accuracy_uniform(shape, dtype):
+    x = torch.randn(size=shape, dtype=dtype, device="cuda")
+    with flag_gems.use_gems():
+        x.uniform_(-3, 3)
+    assert (x <= 3.0).all()
+    assert (x >= -3.0).all()
+
+
+@pytest.mark.parametrize("shape", 
DISTRIBUTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_exponential_(shape, dtype): + x = torch.empty(size=shape, dtype=dtype, device="cuda") + with flag_gems.use_gems(): + x.exponential_() + assert x.min() > 0 diff --git a/tests/test_special_ops.py b/tests/test_special_ops.py index a42d2587..a155788e 100644 --- a/tests/test_special_ops.py +++ b/tests/test_special_ops.py @@ -1,8 +1,6 @@ from typing import Optional -import numpy as np import pytest -import scipy import torch import flag_gems @@ -197,51 +195,23 @@ def test_embedding(EmbeddingSize, Batch, M, N, padding_idx, scale_grad_by_freq, @pytest.mark.parametrize("shape", POINTWISE_SHAPES) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) -def test_accuracy_rand(shape, dtype): - with flag_gems.use_gems(): - res_out = torch.rand(shape, dtype=dtype, device="cuda") - assert (res_out <= 1.0).all() - assert (res_out >= 0.0).all() - - -@pytest.mark.parametrize("shape", POINTWISE_SHAPES) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) -def test_accuracy_randn(shape, dtype): - with flag_gems.use_gems(): - res_out = torch.randn(shape, dtype=dtype, device="cuda") - mean = torch.mean(res_out) - std = torch.std(res_out) - assert torch.abs(mean) < 0.01 - assert torch.abs(std - 1) < 0.01 - - -@pytest.mark.parametrize("shape", POINTWISE_SHAPES) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) -def test_accuracy_rand_like(shape, dtype): +@pytest.mark.parametrize("dtype", [torch.cfloat]) +def test_accuracy_resolve_neg(shape, dtype): x = torch.randn(size=shape, dtype=dtype, device="cuda") + y = x.conj() + z = y.imag + assert z.is_neg() with flag_gems.use_gems(): - res_out = torch.rand_like(x) - assert (res_out <= 1.0).all() - assert (res_out >= 0.0).all() + out = z.resolve_neg() + assert not out.is_neg() @pytest.mark.parametrize("shape", POINTWISE_SHAPES) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES) -def test_accuracy_exponential_(shape, dtype): - x = torch.empty(size=shape, dtype=dtype, device="cuda") - with flag_gems.use_gems(): - x.exponential_() - assert x.min() > 0 - - -@pytest.mark.parametrize("shape", POINTWISE_SHAPES[:1]) -@pytest.mark.parametrize("dtype", (torch.float32,)) -@pytest.mark.parametrize("lambd", (0.01, 0.5, 100.0)) -def test_accuracy_exponential_pvalue(shape, dtype, lambd): - x = torch.empty(size=shape, dtype=dtype, device="cuda") +@pytest.mark.parametrize("dtype", [torch.cfloat]) +def test_accuracy_resolve_conj(shape, dtype): + x = torch.randn(size=shape, dtype=dtype, device="cuda") + y = x.conj() + assert y.is_conj() with flag_gems.use_gems(): - x.exponential_(lambd=lambd) - expo_cdf = lambda x: np.where(x < 0, 0, 1.0 - np.exp(-lambd * x)) - pvalue = scipy.stats.kstest(x.cpu().numpy().flatten(), expo_cdf).pvalue - assert pvalue > 0.05 + z = y.resolve_conj() + assert not z.is_conj() diff --git a/tests/test_specific_ops.py b/tests/test_specific_ops.py index 7569dc8d..0000ab31 100644 --- a/tests/test_specific_ops.py +++ b/tests/test_specific_ops.py @@ -176,3 +176,4 @@ if exec_flag is False: logging.fatal(f"No op named {args.name} found! 
Check the name and list!") + exit(-1) diff --git a/tests/test_tensor_constructor_ops.py b/tests/test_tensor_constructor_ops.py new file mode 100644 index 00000000..07ff4662 --- /dev/null +++ b/tests/test_tensor_constructor_ops.py @@ -0,0 +1,92 @@ +import pytest +import torch + +import flag_gems + +from .accuracy_utils import ( + DISTRIBUTION_SHAPES, + FLOAT_DTYPES, + POINTWISE_SHAPES, + gems_assert_equal, +) + + +@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_accuracy_rand(shape, dtype): + with flag_gems.use_gems(): + res_out = torch.rand(shape, dtype=dtype, device="cuda") + assert (res_out <= 1.0).all() + assert (res_out >= 0.0).all() + + +@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_accuracy_randn(shape, dtype): + with flag_gems.use_gems(): + res_out = torch.randn(shape, dtype=dtype, device="cuda") + mean = torch.mean(res_out) + std = torch.std(res_out) + assert torch.abs(mean) < 0.01 + assert torch.abs(std - 1) < 0.01 + + +@pytest.mark.parametrize("shape", DISTRIBUTION_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_rand_like(shape, dtype): + x = torch.randn(size=shape, dtype=dtype, device="cuda") + with flag_gems.use_gems(): + res_out = torch.rand_like(x) + assert (res_out <= 1.0).all() + assert (res_out >= 0.0).all() + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_zeros(shape, dtype): + with flag_gems.use_gems(): + res_out = torch.zeros(shape, dtype=dtype, device="cuda") + gems_assert_equal(res_out, torch.zeros(shape, dtype=dtype, device="cuda")) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_ones(shape, dtype): + with flag_gems.use_gems(): + res_out = torch.ones(shape, dtype=dtype, device="cuda") + gems_assert_equal(res_out, torch.ones(shape, dtype=dtype, device="cuda")) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_full(shape, dtype): + with flag_gems.use_gems(): + res_out = torch.full(shape, 3.1415926, dtype=dtype, device="cuda") + gems_assert_equal(res_out, torch.full(shape, 3.1415926, dtype=dtype, device="cuda")) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_zeros_like(shape, dtype): + x = torch.empty(size=shape, dtype=dtype, device="cuda") + with flag_gems.use_gems(): + res_out = torch.zeros_like(x) + gems_assert_equal(res_out, torch.zeros_like(x)) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_ones_like(shape, dtype): + x = torch.empty(size=shape, dtype=dtype, device="cuda") + with flag_gems.use_gems(): + res_out = torch.ones_like(x) + gems_assert_equal(res_out, torch.ones_like(x)) + + +@pytest.mark.parametrize("shape", POINTWISE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_accuracy_full_like(shape, dtype): + x = torch.empty(size=shape, dtype=dtype, device="cuda") + with flag_gems.use_gems(): + res_out = torch.full_like(x, 3.1415926) + gems_assert_equal(res_out, torch.full_like(x, 3.1415926))