Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operator] index_add #145

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open

[Operator] index_add #145

wants to merge 4 commits into from

Conversation

GwokHiujin
Copy link
Collaborator

@GwokHiujin GwokHiujin commented Jul 31, 2024

We have completed the development of the index_add operator. Specifically:

  • The corresponding aTen operator is index_add
  • Added accuracy test and perf test

@GwokHiujin GwokHiujin marked this pull request as ready for review July 31, 2024 04:47
@GwokHiujin GwokHiujin changed the title [Operator] init index_add [Operator] index_add Jul 31, 2024
Comment on lines 44 to 46
cur_inp = tl.load(inp + inp_off, mask=block_mask, other=0.0).to(tl.float32)
src_off = rows_offsets * N + cols_offsets[None, :]
cur_src = tl.load(src + src_off, mask=block_mask, other=0.0).to(tl.float32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly lose precision for fp64 src and inputs?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about just keep src and inp as-is without casting?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly lose precision for fp64 src and inputs?

I've encountered precision loss issues in some data types (like bf16 and float32). Ignoring casting might lead to problems. I'll implement the suggested changes below and see if they resolve the issue.

src = dim_compress(src, dim)
out = inp.clone()

grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The input & src is permuted into shapes
input: Shape(M, ...) where product(...) == inp_len
src: Shape(M, ...) where product(...) == N
and contiguous.

So we can view then as
input: Shape(M, inp_len)
src: Shape(M, N)
index: (N, )

Then the task is partitioned along the M dimension in tile size of BLOCK_M, while the N dimension is looped in tiles of size BLOCK_N.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though it is hard to figure out a general solution now, but permuting the tensor to make the inp_len & N dimensional to be contiguous is not always good.

For example,

input & src are both 2d tensors, now index_add along axis 0, then the permutations are actually not needed to make index_add easier.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is a key issue I constantly consider(Since it actually occurs in other operations, too). As a temporary solution, I set conditional judgments, such as: if the input dimension equals (self.ndim - 1), I don't perform the permutation. I'm uncertain if this approach is effective.

BTW Performance testing revealed that permutations can increase latency by about 7 times compared to Torch, making the reduction of unnecessary permutations crucial... ; (

Copy link
Collaborator

@iclementine iclementine left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some spaces for optimization, but LGTM.

Copy link
Collaborator

@iclementine iclementine left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some suggestions

  1. Ensure index is contiguous, or consider its stride;
  2. keep the data loaded from src as-is to avoid down-cast;
  3. (Maybe) use some heuristics to make a better task partitioning & avoid unnecessary data permutations.

* Use a 2D grid with the kernel
* Ensure index is contiguous
* Keep the data in kernel loaded from src
* Try to avoid some unnecessary permutations
@iclementine iclementine self-assigned this Aug 19, 2024
iclementine
iclementine previously approved these changes Aug 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants