[Operator] Add vstack op (#175)
yjl0101 authored Sep 20, 2024
1 parent da86496 commit f4b2495
Showing 5 changed files with 196 additions and 0 deletions.
18 changes: 18 additions & 0 deletions benchmark/test_special_perf.py
@@ -270,3 +270,21 @@ def cat_kwargs(dtype, batch, size):
        kwargs_func=cat_kwargs,
    )
    bench.run()


def test_perf_vstack():
    def vstack_args(dtype, batch, size):
        inp1 = torch.randn(size=(batch, size), dtype=dtype, device="cuda")
        inp2 = torch.randn(size=(batch + 1, size), dtype=dtype, device="cuda")
        inp3 = torch.randn(size=(batch + 2, size), dtype=dtype, device="cuda")
        return [[inp1, inp2, inp3]]

    bench = Benchmark(
        op_name="vstack",
        torch_op=torch.vstack,
        arg_func=vstack_args,
        dtypes=FLOAT_DTYPES,
        batch=(512),
        sizes=SIZES,
    )
    bench.run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -153,6 +153,7 @@ def enable(lib=aten_lib):
    lib.impl("hstack", hstack, "CUDA")
    lib.impl("cat", cat, "CUDA")
    lib.impl("repeat_interleave.self_int", repeat_interleave_self_int, "CUDA")
    lib.impl("vstack", vstack, "CUDA")


class use_gems:
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -100,6 +100,7 @@
from .unique import _unique2
from .var_mean import var_mean
from .vector_norm import vector_norm
from .vstack import vstack
from .where import where_scalar_other, where_scalar_self, where_self
from .zeros import zeros
from .zeros_like import zeros_like
@@ -237,4 +238,5 @@
    "hstack",
    "cat",
    "repeat_interleave_self_int",
    "vstack",
]
141 changes: 141 additions & 0 deletions src/flag_gems/ops/vstack.py
@@ -0,0 +1,141 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import libentry


@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": k}, num_warps=w)
        for w in [4, 8, 16, 32]
        for k in [512, 1024, 2048, 4096]
    ],
    key=[
        "max_tile_elems",
    ],
)
@triton.jit
def vstack_kernel(
    itensor_ptr0,
    itensor_ptr1,
    itensor_ptr2,
    itensor_ptr3,
    output_ptr,
    local_row0,
    local_row1,
    local_row2,
    local_row3,
    exc_row_offset0,
    exc_row_offset1,
    exc_row_offset2,
    exc_row_offset3,
    total_row_offset,
    row_stride,
    max_tile_elems,
    BLOCK_SIZE: tl.constexpr,
):
    pid_x = tl.program_id(axis=0)
    tensor_idx = tl.program_id(axis=1)
    col_idx = tl.arange(0, BLOCK_SIZE)

    # Grid axis 1 says which of the (up to) four inputs this program copies; the
    # chained tl.where picks the matching pointer, exclusive row offset, and row count.
    intensor_ptr = tl.where(tensor_idx == 0, itensor_ptr0, itensor_ptr1)
    intensor_ptr = tl.where(tensor_idx == 2, itensor_ptr2, intensor_ptr)
    intensor_ptr = tl.where(tensor_idx == 3, itensor_ptr3, intensor_ptr)
    base_exc_row_idx = tl.where(tensor_idx == 0, exc_row_offset0, exc_row_offset1)
    base_exc_row_idx = tl.where(tensor_idx == 2, exc_row_offset2, base_exc_row_idx)
    base_exc_row_idx = tl.where(tensor_idx == 3, exc_row_offset3, base_exc_row_idx)
    local_row = tl.where(tensor_idx == 0, local_row0, local_row1)
    local_row = tl.where(tensor_idx == 2, local_row2, local_row)
    local_row = tl.where(tensor_idx == 3, local_row3, local_row)

    # Copy one BLOCK_SIZE tile of the flattened input into its slot of the output,
    # masking off elements past the end of this input tensor.
    end_idx = local_row * row_stride.to(tl.int64)
    idx = (pid_x * BLOCK_SIZE + col_idx).to(tl.int64)
    offset_mask = idx < end_idx
    in_offset = intensor_ptr + idx
    row_stride_offset = (total_row_offset + base_exc_row_idx) * row_stride.to(tl.int64)
    out_offset = output_ptr + row_stride_offset + idx
    out = tl.load(in_offset, mask=offset_mask)
    tl.store(out_offset, out, mask=offset_mask)


def vstack(tensors: list[torch.Tensor]):
    logging.debug("GEMS VSTACK")

    tensors = torch.atleast_2d(tensors)
    num_tensors = len(tensors)
    assert num_tensors > 0

    # Ensure all tensors are on the same device and have the same dtype
    device = tensors[0].device
    dtype = tensors[0].dtype
    for tensor in tensors:
        assert (
            tensor.device == device
            and tensor.dtype == dtype
            and tensors[0].shape[1:] == tensor.shape[1:]
        )

    c_tensors = [t.contiguous() for t in tensors]
    # Calculate the output shape
    total_rows = sum(tensor.shape[0] for tensor in c_tensors)
    output_shape = list(c_tensors[0].shape)
    output_shape[0] = total_rows
    output = torch.empty(output_shape, device=device, dtype=dtype)
    row_stride = c_tensors[0].stride(0)

    # The kernel takes a fixed argument list, so inputs are processed in groups of
    # up to four tensors per launch; grid axis 1 indexes the tensor within a group.
    outer_iters = triton.cdiv(num_tensors, 4)
    total_row_offset = 0
    for i in range(outer_iters):
        max_rows = 1
        itensors = []
        exclusive_row = []
        local_row = []
        array_row_offset = 0
        scheduled_num_tensors = 0
        for j in range(4):
            tensor_idx = i * 4 + j
            if tensor_idx < num_tensors:
                scheduled_num_tensors += 1
                itensors.append(c_tensors[tensor_idx])
                local_row.append(c_tensors[tensor_idx].shape[0])
                exclusive_row.append(array_row_offset)
                array_row_offset += c_tensors[tensor_idx].shape[0]
                max_rows = max(max_rows, c_tensors[tensor_idx].shape[0])
            else:
                # Pad the argument list with empty placeholders so the kernel
                # always receives four tensor pointers; padded slots are never
                # launched because grid axis 1 is limited to scheduled_num_tensors.
                empty_tensor = torch.empty(
                    0, dtype=c_tensors[0].dtype, device=c_tensors[0].device
                )
                itensors.append(empty_tensor)
                local_row.append(local_row[-1])
                exclusive_row.append(exclusive_row[-1])
        max_tile_elems = max_rows * row_stride
        grid = lambda META: (
            triton.cdiv(max_tile_elems, META["BLOCK_SIZE"]),
            scheduled_num_tensors,
        )
        # Launch the kernel
        with torch.cuda.device(c_tensors[0].device):
            vstack_kernel[grid](
                itensors[0],
                itensors[1],
                itensors[2],
                itensors[3],
                output,
                local_row[0],
                local_row[1],
                local_row[2],
                local_row[3],
                exclusive_row[0],
                exclusive_row[1],
                exclusive_row[2],
                exclusive_row[3],
                total_row_offset,
                row_stride,
                max_tile_elems,
            )
        total_row_offset += array_row_offset
    return output
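
For illustration, a minimal usage sketch of the new op (tensor names and shapes are arbitrary examples): with the registration in src/flag_gems/__init__.py above, entering flag_gems.use_gems() routes torch.vstack on CUDA tensors to this Triton implementation, exactly as the accuracy test below does.

    import torch
    import flag_gems

    a = torch.randn(3, 33, device="cuda")
    b = torch.randn(7, 33, device="cuda")
    with flag_gems.use_gems():      # torch.vstack now dispatches to the GEMS kernel
        out = torch.vstack([a, b])  # out.shape == torch.Size([10, 33])

Note the design choice visible in the wrapper: inputs are handled in groups of up to four per kernel launch, with grid axis 1 selecting the tensor within a group, so an arbitrary number of inputs can be stacked with a fixed kernel signature.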
34 changes: 34 additions & 0 deletions tests/test_special_ops.py
@@ -652,3 +652,37 @@ def test_accuracy_cat(shape, dim, dtype):
    with flag_gems.use_gems():
        res_out = torch.cat(inp, dim)
    gems_assert_equal(res_out, ref_out)


VSTACK_SHAPES = [
    [(3,), (3,)],
    [(3, 33), (7, 33)],
    [(13, 3, 333), (17, 3, 333), (7, 3, 333)],
    [
        (13, 3, 64, 5, 2),
        (16, 3, 64, 5, 2),
        (7, 3, 64, 5, 2),
        (4, 3, 64, 5, 2),
        (1, 3, 64, 5, 2),
    ],
]


@pytest.mark.parametrize("shape", VSTACK_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
def test_accuracy_vstack(shape, dtype):
    if dtype in FLOAT_DTYPES:
        inp = [torch.randn(s, dtype=dtype, device="cuda") for s in shape]
    else:
        inp = [
            torch.randint(low=0, high=0x7FFF, size=s, dtype=dtype, device="cuda").to(
                dtype
            )
            for s in shape
        ]
    ref_inp = [to_reference(_) for _ in inp]
    ref_out = torch.vstack(ref_inp)

    with flag_gems.use_gems():
        res_out = torch.vstack(inp)
    gems_assert_equal(res_out, ref_out)
