Skip to content

Commit

Permalink
[Operator] Add randperm op
Browse files Browse the repository at this point in the history
  • Loading branch information
yjl0101 authored and yjl0101 committed Aug 29, 2024
1 parent 2db4271 commit c53b7d5
Show file tree
Hide file tree
Showing 6 changed files with 401 additions and 2 deletions.
16 changes: 16 additions & 0 deletions benchmark/test_tensor_constructor_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,19 @@ def full_kwargs(dtype, batch, size):
kwargs_func=full_kwargs,
)
bench.run()


def test_perf_randperm():
def randperm_args(dtype, batch, size):
return {"n": size, "dtype": dtype, "device": "cuda"}

bench = Benchmark(
op_name="randperm",
torch_op=torch.randperm,
arg_func=None,
dtypes=[torch.int32, torch.int64],
batch=POINTWISE_BATCH,
sizes=SIZES,
kwargs_func=randperm_args,
)
bench.run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def enable(lib=aten_lib):
lib.impl("index_select", index_select, "CUDA")
lib.impl("masked_fill", masked_fill, "CUDA")
lib.impl("_unique2", _unique2, "CUDA")
lib.impl("randperm", randperm, "CUDA")


class use_gems:
Expand Down
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from .rand_like import rand_like
from .randn import randn
from .randn_like import randn_like
from .randperm import randperm
from .reciprocal import reciprocal
from .relu import relu
from .resolve_conj import resolve_conj
Expand Down Expand Up @@ -157,6 +158,7 @@
"minimum",
"rand",
"randn",
"randperm",
"rand_like",
"randn_like",
"resolve_neg",
Expand Down
Loading

0 comments on commit c53b7d5

Please sign in to comment.