[Operator] Add uniform, normal, resolve_neg & resolve_conj & zeros & ones & full Ops with UT & Bench (#139)
Bowen12992 committed Aug 6, 2024
1 parent 2bf0115 commit 48942a0
Showing 22 changed files with 847 additions and 103 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/python-test.yaml
@@ -29,6 +29,8 @@ jobs:
 "CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_unary_pointwise_ops.py &"
 "CUDA_VISIBLE_DEVICES=0 pytest -s tests/test_pointwise_type_promotion.py &"
 "CUDA_VISIBLE_DEVICES=1 pytest -s tests/test_binary_pointwise_ops.py &"
+"CUDA_VISIBLE_DEVICES=1 pytest -s tests/test_tensor_constructor_ops.py &"
+"CUDA_VISIBLE_DEVICES=1 pytest -s tests/test_distribution_ops.py &"
 "CUDA_VISIBLE_DEVICES=2 pytest -s tests/test_blas_ops.py &"
 "CUDA_VISIBLE_DEVICES=3 pytest -s tests/test_reduction_ops.py &"
 "CUDA_VISIBLE_DEVICES=4 pytest -s tests/test_special_ops.py &"
96 changes: 96 additions & 0 deletions benchmark/test_distribution_perf.py
@@ -0,0 +1,96 @@
import torch

from .performance_utils import (
    FLOAT_DTYPES,
    POINTWISE_BATCH,
    SIZES,
    Benchmark,
    unary_arg,
)


def test_perf_rand():
    def rand_kwargs(dtype, batch, size):
        return {"size": (batch, size), "dtype": dtype, "device": "cuda"}

    bench = Benchmark(
        op_name="rand",
        torch_op=torch.rand,
        arg_func=None,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
        kwargs_func=rand_kwargs,
    )
    bench.run()


def test_perf_randn():
    def randn_kwargs(dtype, batch, size):
        return {"size": (batch, size), "dtype": dtype, "device": "cuda"}

    bench = Benchmark(
        op_name="randn",
        torch_op=torch.randn,
        arg_func=None,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
        kwargs_func=randn_kwargs,
    )
    bench.run()


def test_perf_rand_like():
    bench = Benchmark(
        op_name="rand_like",
        torch_op=torch.rand_like,
        arg_func=unary_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    bench.run()


def test_perf_normal():
    def normal_arg(dtype, batch, size):
        loc = torch.full(size=(size, batch), fill_value=3.0, dtype=dtype, device="cuda")
        scale = torch.full(
            size=(size, batch), fill_value=10.0, dtype=dtype, device="cuda"
        )
        return loc, scale

    bench = Benchmark(
        op_name="distributions.normal.Normal",
        torch_op=torch.distributions.normal.Normal,
        arg_func=normal_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    bench.run()


def test_perf_uniform():
    bench = Benchmark(
        op_name="uniform_",
        torch_op=torch.Tensor.uniform_,
        arg_func=unary_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    bench.run()


def test_perf_exponential_():
    bench = Benchmark(
        op_name="exponential_",
        torch_op=torch.Tensor.exponential_,
        arg_func=unary_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    bench.run()
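
Note: judging from the call sites, each test builds a `Benchmark` that sweeps `POINTWISE_BATCH` x `SIZES` for every dtype, feeding the op either positional inputs (`arg_func`) or keyword inputs such as `size`/`dtype`/`device` (`kwargs_func`). A minimal sketch of the kind of comparison such a harness performs; the timing loop below is illustrative, not the actual `performance_utils` implementation:

import time

import torch
import flag_gems


def cuda_time(fn, iters=100):
    # Average wall-clock latency over `iters` calls, synchronized around the loop.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


kwargs = {"size": (1024, 1024), "dtype": torch.float16, "device": "cuda"}
torch_latency = cuda_time(lambda: torch.rand(**kwargs))  # eager baseline

flag_gems.enable()  # route the aten ops to the Triton kernels (see __init__.py below)
gems_latency = cuda_time(lambda: torch.rand(**kwargs))
print(f"torch: {torch_latency * 1e3:.3f} ms, flag_gems: {gems_latency * 1e3:.3f} ms")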
90 changes: 31 additions & 59 deletions benchmark/test_special_perf.py
@@ -1,86 +1,58 @@
 import torch
 
-from .performance_utils import (
-    FLOAT_DTYPES,
-    POINTWISE_BATCH,
-    SIZES,
-    Benchmark,
-    unary_arg,
-)
+from .performance_utils import POINTWISE_BATCH, SIZES, Benchmark
 
 
-def test_perf_rand():
-    def rand_kwargs(dtype, batch, size):
-        return {"size": (batch, size), "dtype": dtype, "device": "cuda"}
-
-    bench = Benchmark(
-        op_name="rand",
-        torch_op=torch.rand,
-        arg_func=None,
-        dtypes=FLOAT_DTYPES,
-        batch=POINTWISE_BATCH,
-        sizes=SIZES,
-        kwargs_func=rand_kwargs,
-    )
-    bench.run()
-
-
-def test_perf_randn():
-    def randn_kwargs(dtype, batch, size):
-        return {"size": (batch, size), "dtype": dtype, "device": "cuda"}
+def test_perf_embedding():
+    def embedding_kwargs(dtype, batch, size):
+        input = torch.randint(0, batch, (batch,), device="cuda")
+        weight = torch.randn((batch + 1, size), device="cuda", dtype=dtype)
+        return {"input": input, "weight": weight}
 
     bench = Benchmark(
-        op_name="randn",
-        torch_op=torch.randn,
+        op_name="embedding",
+        torch_op=torch.nn.functional.embedding,
         arg_func=None,
-        dtypes=FLOAT_DTYPES,
+        dtypes=[
+            torch.float32,
+            torch.float16,
+        ],  # Note(Zhengzekang): triton do not support bfloat16 atomic add which is used in embedding grad.
         batch=POINTWISE_BATCH,
         sizes=SIZES,
-        kwargs_func=randn_kwargs,
+        kwargs_func=embedding_kwargs,
     )
     bench.run()
 
 
-def test_perf_rand_like():
-    bench = Benchmark(
-        op_name="rand_like",
-        torch_op=torch.rand_like,
-        arg_func=unary_arg,
-        dtypes=FLOAT_DTYPES,
-        batch=POINTWISE_BATCH,
-        sizes=SIZES,
-    )
-    bench.run()
-
+def test_perf_resolve_neg():
+    def resolve_neg_arg(dtype, batch, size):
+        x = torch.randn(size=(batch, size), dtype=dtype, device="cuda")
+        y = x.conj()
+        z = y.imag
+        return (z,)
 
-def test_perf_exponential_():
     bench = Benchmark(
-        op_name="exponential_",
-        torch_op=torch.Tensor.exponential_,
-        arg_func=unary_arg,
-        dtypes=FLOAT_DTYPES,
+        op_name="resolve_neg",
+        torch_op=torch.resolve_neg,
+        arg_func=resolve_neg_arg,
+        dtypes=[torch.cfloat],
         batch=POINTWISE_BATCH,
         sizes=SIZES,
     )
     bench.run()
 
 
-def test_perf_embedding():
-    def embedding_kwargs(dtype, batch, size):
-        input = torch.randint(0, batch, (batch,), device="cuda")
-        weight = torch.randn((batch + 1, size), device="cuda", dtype=dtype)
-        return {"input": input, "weight": weight}
+def test_perf_resolve_conj():
+    def resolve_conj_arg(dtype, batch, size):
+        x = torch.randn(size=(size, batch), dtype=dtype, device="cuda")
+        return (x.conj(),)
 
     bench = Benchmark(
-        op_name="embedding",
-        torch_op=torch.nn.functional.embedding,
-        arg_func=None,
-        dtypes=[
-            torch.float32,
-            torch.float16,
-        ],  # Note(Zhengzekang): triton do not support bfloat16 atomic add which is used in embedding grad.
+        op_name="resolve_conj",
+        torch_op=torch.resolve_conj,
+        arg_func=resolve_conj_arg,
+        dtypes=[torch.cfloat],
         batch=POINTWISE_BATCH,
         sizes=SIZES,
-        kwargs_func=embedding_kwargs,
     )
     bench.run()
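
Note on the `resolve_neg_arg` helper added above: `x.conj()` sets a lazy conjugate bit instead of copying, and taking `.imag` of such a tensor yields a view with the lazy negation bit set, which is exactly what `torch.resolve_neg` materializes. This is standard PyTorch behavior, independent of this PR:

import torch

x = torch.randn(3, dtype=torch.cfloat)
y = x.conj()  # lazy: flips the conjugate bit, no data copied
z = y.imag    # imag view of a conjugated tensor carries the negation bit

print(z.is_neg())           # True: negation not yet materialized
out = z.resolve_neg()       # materialize the negation into a plain tensor
print(out.is_neg())         # False
print(torch.equal(out, z))  # True: the values agree once compared

Benchmarking `resolve_neg`/`resolve_conj` on such inputs therefore measures the cost of that materialization copy.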
105 changes: 105 additions & 0 deletions benchmark/test_tensor_constructor_perf.py
@@ -0,0 +1,105 @@
import torch

from .performance_utils import (
    FLOAT_DTYPES,
    POINTWISE_BATCH,
    SIZES,
    Benchmark,
    unary_arg,
)


def test_perf_ones():
    def ones_kwargs(dtype, batch, size):
        return {"size": (batch, size), "dtype": dtype, "device": "cuda"}

    bench = Benchmark(
        op_name="ones",
        torch_op=torch.ones,
        arg_func=None,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
        kwargs_func=ones_kwargs,
    )
    bench.run()


def test_perf_zeros():
    def zeros_kwargs(dtype, batch, size):
        return {"size": (batch, size), "dtype": dtype, "device": "cuda"}

    bench = Benchmark(
        op_name="zeros",
        torch_op=torch.zeros,
        arg_func=None,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
        kwargs_func=zeros_kwargs,
    )
    bench.run()


def test_perf_full():
    def full_kwargs(dtype, batch, size):
        return {
            "size": (batch, size),
            "fill_value": 3.1415926,
            "dtype": dtype,
            "device": "cuda",
        }

    bench = Benchmark(
        op_name="full",
        torch_op=torch.full,
        arg_func=None,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
        kwargs_func=full_kwargs,
    )
    bench.run()


def test_perf_ones_like():
    bench = Benchmark(
        op_name="ones_like",
        torch_op=torch.ones_like,
        arg_func=unary_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    bench.run()


def test_perf_zeros_like():
    bench = Benchmark(
        op_name="zeros_like",
        torch_op=torch.zeros_like,
        arg_func=unary_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
    )
    bench.run()


def test_perf_full_like():
    def full_kwargs(dtype, batch, size):
        return {
            "input": torch.randn([batch, size], dtype=dtype, device="cuda"),
            "fill_value": 3.1415926,
        }

    bench = Benchmark(
        op_name="full_like",
        torch_op=torch.full_like,
        arg_func=None,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
        kwargs_func=full_kwargs,
    )
    bench.run()
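
The `*_like` benchmarks above rely on `performance_utils.unary_arg`. Judging only from its call sites, it presumably builds a single `(batch, size)` input tensor per case, along these lines (an assumption, not the actual helper):

import torch


def unary_arg(dtype, batch, size):
    # Hypothetical reconstruction: one random input tensor per benchmark case.
    inp = torch.randn((batch, size), dtype=dtype, device="cuda")
    return (inp,)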
13 changes: 13 additions & 0 deletions src/flag_gems/__init__.py
@@ -50,6 +50,19 @@ def enable(lib=aten_lib):
     lib.impl("rand", rand, "CUDA")
     lib.impl("randn", randn, "CUDA")
     lib.impl("rand_like", rand_like, "CUDA")
+    lib.impl("zeros", zeros, "CUDA")
+    lib.impl("ones", ones, "CUDA")
+    lib.impl("full", full, "CUDA")
+    lib.impl("zeros_like", zeros_like, "CUDA")
+    lib.impl("ones_like", ones_like, "CUDA")
+    lib.impl("full_like", full_like, "CUDA")
+    lib.impl("resolve_neg", resolve_neg, "CUDA")
+    lib.impl("resolve_conj", resolve_conj, "CUDA")
+    lib.impl("normal.Tensor_float", normal_tensor_float, "CUDA")
+    lib.impl("normal.float_Tensor", normal_float_tensor, "CUDA")
+    lib.impl("normal.Tensor_Tensor", normal_tensor_tensor, "CUDA")
+    lib.impl("normal.float_float", normal_float_float, "CUDA")
+    lib.impl("uniform_", uniform_, "CUDA")
     lib.impl("mean", mean, "CUDA")
     lib.impl("mean.dim", mean_dim, "CUDA")
     lib.impl("mm", mm, "CUDA")
(17 more changed files in this commit are not shown.)
