[Operator] index_add #145

Open · wants to merge 4 commits into master (showing changes from 1 commit)
25 changes: 25 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -286,3 +286,28 @@ def test_perf_vector_norm():
        sizes=SIZES,
    )
    bench.run()


def test_perf_index_add():
    def index_add_args(dtype, batch, size):
        import random

        inp = torch.randn([batch, size], dtype=dtype, device="cuda")
        dim = random.choice([0, 1])
        src_shape = list(inp.shape)
        index_max = src_shape[dim]
        index_len = index_max // 2
        index = torch.randint(0, index_max, (index_len,), device="cuda")
        src_shape[dim] = index_len
        src = torch.randn(src_shape, dtype=dtype, device="cuda")
        return (inp, dim, index, src)

    bench = Benchmark(
        op_name="index_add",
        torch_op=torch.index_add,
        arg_func=index_add_args,
        dtypes=FLOAT_DTYPES,
        batch=REDUCTION_BATCH,
        sizes=SIZES,
    )
    bench.run()
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -98,6 +98,7 @@ def enable(lib=aten_lib):
lib.impl("isclose", isclose, "CUDA")
lib.impl("allclose", allclose, "CUDA")
lib.impl("flip", flip, "CUDA")
lib.impl("index_add", index_add, "CUDA")


class use_gems:
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -28,6 +28,7 @@
from .gelu import gelu
from .groupnorm import group_norm
from .gt import gt, gt_scalar
from .index_add import index_add
from .isclose import allclose, isclose
from .isfinite import isfinite
from .isinf import isinf
@@ -155,4 +156,5 @@
"where_self",
"where_scalar_self",
"where_scalar_other",
"index_add",
]
83 changes: 83 additions & 0 deletions src/flag_gems/ops/index_add.py
@@ -0,0 +1,83 @@
import logging

import torch
import triton
import triton.language as tl

from ..utils import dim_compress, libentry


def cfggen():
    block_m = [1, 2, 4, 8]
    configs = [
        triton.Config({"BLOCK_M": m, "BLOCK_N": 1024}, num_warps=4) for m in block_m
    ]
    return configs


@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def index_add_kernel(
    inp,
    out,
    index,
    src,
    M,
    N,
    alpha,
    inp_len,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid = tl.program_id(0)
    rows_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    rows_mask = rows_offsets < M

    for off in range(0, N, BLOCK_N):
        cols_offsets = off + tl.arange(0, BLOCK_N)
        index_mask = cols_offsets < N
        block_mask = rows_mask and index_mask

        cur_indices = tl.load(index + cols_offsets, mask=index_mask, other=0)
        inp_off = rows_offsets * inp_len + cur_indices[None, :]
        cur_inp = tl.load(inp + inp_off, mask=block_mask, other=0.0).to(tl.float32)
        src_off = rows_offsets * N + cols_offsets[None, :]
        cur_src = tl.load(src + src_off, mask=block_mask, other=0.0).to(tl.float32)
Review thread (on the float32 casts above):

Contributor: Possibly lose precision for fp64 src and inputs?

Collaborator: What about just keeping src and inp as-is, without casting?

Collaborator (author): I've run into precision-loss issues with some data types (like bf16 and float32), so skipping the cast entirely might cause problems. I'll try the suggested change and see whether it resolves the issue.
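One way to read that suggestion, sketched at the wrapper level rather than inside the Triton kernel (a hypothetical helper, not part of this diff): upcast only the reduced-precision dtypes to float32 for the accumulation and keep float32/float64 in their native precision, so fp64 inputs are not truncated.

    import torch

    _UPCAST_DTYPES = (torch.float16, torch.bfloat16)

    def index_add_upcast_reference(inp, dim, index, src, alpha=1):
        # Accumulate fp16/bf16 in float32; fp32 and fp64 stay in their own dtype.
        acc_dtype = torch.float32 if inp.dtype in _UPCAST_DTYPES else inp.dtype
        out = inp.to(acc_dtype).index_add(dim, index, src.to(acc_dtype), alpha=alpha)
        return out.to(inp.dtype)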

        cur_inp += alpha * cur_src

        tl.store(out + inp_off, cur_inp, mask=block_mask)


def index_add(inp, dim, index, src, alpha=1):
    logging.debug("GEMS INDEX ADD")
    assert ((0 <= index) * (index < inp.size(dim))).equal(
        torch.ones(tuple(index.shape), dtype=torch.bool, device="cuda")
    ), "0 <= index < self.size(dim)"
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.numel() == src.size(
        dim
    ), "The dim-th dimension of source must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    assert all(
        (inp.size(i) == src.size(i)) or i == dim for i in range(0, inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    dim = dim % inp.ndim
    inp_len = inp.size(dim)
    N = index.numel()
    M = src.numel() // N
    inp = dim_compress(inp, dim)
    src = dim_compress(src, dim)
    out = inp.clone()

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
Review thread (on the launch-grid line above):

Collaborator: The input & src are permuted into shapes
    input: Shape(M, ...) where product(...) == inp_len
    src:   Shape(M, ...) where product(...) == N
and made contiguous, so we can view them as
    input: Shape(M, inp_len)
    src:   Shape(M, N)
    index: Shape(N,)
The task is then partitioned along the M dimension in tiles of size BLOCK_M, while the N dimension is looped over in tiles of size BLOCK_N.
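To make that layout concrete, here is a rough eager-mode picture of what the kernel computes on the compressed views (illustrative only; it glosses over how duplicate indices inside a block are handled):

    import torch

    def compressed_index_add(inp2d, src2d, index, alpha=1.0):
        # inp2d: (M, inp_len), src2d: (M, N), index: (N,) -- the shapes seen after
        # dim_compress; column j of src2d is added into column index[j] of inp2d.
        out = inp2d.clone()
        out.index_add_(1, index, src2d, alpha=alpha)
        return out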

Collaborator: Though it is hard to come up with a general solution right now, permuting the tensors to make the inp_len & N dimensions contiguous is not always a good idea. For example, if input & src are both 2D tensors and we index_add along axis 0, the permutations are not actually needed to make index_add easier.

Collaborator (author): Yes, this is a key issue I keep running into (it shows up in other operators too). As a temporary workaround I added conditional checks, e.g. skipping the permutation when dim equals self.ndim - 1; I'm not sure yet whether that approach is effective.

BTW, performance testing showed that the permutations can increase latency by about 7x compared to Torch, so eliminating unnecessary permutations is crucial... ;(
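A minimal sketch of that conditional skip (hypothetical, and assuming dim_compress(t, dim) simply moves dim to the last axis and returns a contiguous tensor):

    def maybe_compress(t, dim):
        # When dim is already the last axis of a contiguous tensor, dim_compress
        # would not change the layout, so the permutation (and the permute-back
        # after the kernel) can be skipped entirely.
        if dim == t.ndim - 1 and t.is_contiguous():
            return t
        return dim_compress(t, dim)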

    index_add_kernel[grid](inp, out, index, src, M, N, alpha, inp_len)
    if dim != out.ndim - 1:
        order = [i for i in range(out.ndim - 1)]
        order.insert(dim, inp.ndim - 1)
        return out.permute(order).contiguous()
    else:
        return out
22 changes: 22 additions & 0 deletions tests/test_reduction_ops.py
@@ -670,3 +670,25 @@ def test_accuracy_vectornorm(shape, ord, dim, keepdim, dtype):
    res_out = torch.linalg.vector_norm(inp, ord, dim, keepdim)

    gems_assert_close(res_out, ref_out, dtype)


@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
@pytest.mark.parametrize("dim", DIM_LIST)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_index_add(shape, dim, dtype):
inp = torch.randn(shape, dtype=dtype, device="cuda")

src_shape = list(inp.shape)
index_max = src_shape[dim]
index_len = index_max // 2
index = torch.randperm(index_len, device="cuda")
src_shape[dim] = index_len
src = torch.randn(src_shape, dtype=dtype, device="cuda")
alpha = 2

ref_inp = to_reference(inp)
ref_out = torch.index_add(ref_inp, dim, index, src, alpha=alpha)
with flag_gems.use_gems():
res_out = torch.index_add(inp, dim, index, src, alpha=alpha)

gems_assert_close(res_out, ref_out, dtype=dtype, reduce_dim=dim)