Skip to content

Commit

Permalink
[Operator] Add randperm op
Browse files Browse the repository at this point in the history
  • Loading branch information
yjl0101 committed Oct 24, 2024
1 parent 2bc31cd commit fdc1bab
Show file tree
Hide file tree
Showing 7 changed files with 452 additions and 2 deletions.
16 changes: 16 additions & 0 deletions benchmark/test_tensor_constructor_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,19 @@ def full_kwargs(dtype, batch, size):
kwargs_func=full_kwargs,
)
bench.run()


def test_perf_randperm():
def randperm_args(dtype, batch, size):
return {"n": size, "dtype": dtype, "device": "cuda"}

bench = Benchmark(
op_name="randperm",
torch_op=torch.randperm,
arg_func=None,
dtypes=[torch.int32, torch.int64],
batch=POINTWISE_BATCH,
sizes=SIZES,
kwargs_func=randperm_args,
)
bench.run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def enable(lib=aten_lib):
lib.impl("repeat_interleave.self_int", repeat_interleave_self_int, "CUDA")
lib.impl("vstack", vstack, "CUDA")
lib.impl("repeat_interleave.Tensor", repeat_interleave_tensor, "CUDA")
lib.impl("randperm", randperm, "CUDA")


class use_gems:
Expand Down
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
from .rand_like import rand_like
from .randn import randn
from .randn_like import randn_like
from .randperm import randperm
from .reciprocal import reciprocal
from .relu import relu
from .repeat import repeat
Expand Down Expand Up @@ -185,6 +186,7 @@
"minimum",
"rand",
"randn",
"randperm",
"rand_like",
"randn_like",
"resolve_neg",
Expand Down
Loading

0 comments on commit fdc1bab

Please sign in to comment.