[Operator] Add tile op (#148)
Co-authored-by: Clement Chan <[email protected]>
zfu82 and iclementine authored Aug 23, 2024
1 parent 2c4625e commit 6404d38
Showing 6 changed files with 503 additions and 0 deletions.
16 changes: 16 additions & 0 deletions benchmark/test_pointwise_perf.py
@@ -614,3 +614,19 @@ def masked_fill_args(dtype, batch, size):
        sizes=SIZES,
    )
    bench.run()


def test_perf_tile():
    def tile_kwargs(dtype, batch, size):
        return {"dims": [2, 4]}

    bench = Benchmark(
        op_name="tile",
        torch_op=torch.tile,
        arg_func=unary_arg,
        dtypes=FLOAT_DTYPES,
        batch=POINTWISE_BATCH,
        sizes=SIZES,
        kwargs_func=tile_kwargs,
    )
    bench.run()
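
For reference, torch.tile repeats the input along each dimension by the corresponding factor in dims, which is what the dims=[2, 4] kwarg above exercises. A minimal standalone sketch of what the benchmarked call computes (the shape, dtype, and device below are illustrative, not taken from the benchmark config):

import torch

x = torch.randn(3, 5, device="cuda", dtype=torch.float16)
y = torch.tile(x, dims=(2, 4))  # repeat 2x along dim 0, 4x along dim 1
assert y.shape == (6, 20)
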
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -126,6 +126,7 @@ def enable(lib=aten_lib):
lib.impl("isclose", isclose, "CUDA")
lib.impl("allclose", allclose, "CUDA")
lib.impl("flip", flip, "CUDA")
lib.impl("tile", tile, "CUDA")
lib.impl("index_select", index_select, "CUDA")
lib.impl("masked_fill", masked_fill, "CUDA")
lib.impl("_unique2", _unique2, "CUDA")
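
Once flag_gems.enable() runs, the registration above routes aten::tile on CUDA tensors to flag_gems' implementation, so plain torch.tile calls pick it up transparently. An illustrative usage sketch (tensor shape and repeat factors are arbitrary, not part of this diff):

import torch
import flag_gems

flag_gems.enable()  # installs the CUDA implementations registered above

x = torch.randn(128, 64, device="cuda", dtype=torch.float16)
y = torch.tile(x, dims=(2, 3))  # dispatched to flag_gems' tile on CUDA
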
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -79,6 +79,7 @@
from .sub import sub
from .sum import sum, sum_dim
from .tanh import tanh
from .tile import tile
from .topk import topk
from .triu import triu
from .uniform import uniform_
@@ -179,6 +180,7 @@
"softmax",
"sub",
"tanh",
"tile",
"triu",
"topk",
"max",
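
Since tile is also re-exported from flag_gems.ops, it can be called directly without going through the ATen registration; a sketch assuming the function mirrors torch.tile's (input, dims) calling convention:

import torch
from flag_gems.ops import tile

x = torch.randn(4, 8, device="cuda")
y = tile(x, [3, 2])  # assumed to follow torch.tile semantics: result shape (12, 16)
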