[Operator] Add randperm op [MooreThreads] #185
Conversation
Force-pushed from c1ba0aa to c53b7d5
rand_key = torch.randint(
    low=0, high=i32max, size=[n], dtype=torch.int32, device=device
)
else:
check n < i64max here?
done
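For reference, a minimal sketch of the key-width selection with the n < i64max guard discussed above; i32max / i64max follow the snippet, while the helper name and structure are illustrative, not the actual FlagGems code:

```python
import torch

i32max = torch.iinfo(torch.int32).max
i64max = torch.iinfo(torch.int64).max


def make_rand_key(n: int, device) -> torch.Tensor:
    # Hypothetical helper: pick the narrowest key dtype that can hold n
    # distinct sort keys, and reject sizes beyond the int64 range.
    if n >= i64max:
        raise ValueError(f"randperm size {n} is too large for int64 keys")
    if n <= i32max:
        return torch.randint(
            low=0, high=i32max, size=[n], dtype=torch.int32, device=device
        )
    return torch.randint(
        low=0, high=i64max, size=[n], dtype=torch.int64, device=device
    )
```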
if n_elements > 2 * 1024:
    # radix method
    BLOCK_SIZE = 1024
    bits_per_pass = 4
How are these parameters determined? Do they need to be tuned?
These parameters were set based on the performance of some test cases.
Actually, BLOCK_SIZE and bits_per_pass determine the shapes of digit_hist and d_lookback, which are intermediate tensors, so there doesn't seem to be a way to extract the tuned values when creating those tensors, since they must be allocated before the kernels run.
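To illustrate the point about the intermediate tensors, here is a rough sketch of how BLOCK_SIZE and bits_per_pass pin down their shapes; the (RADIX, num_blocks) layout and the helper below are assumptions for illustration, not the actual FlagGems allocation code:

```python
import torch
import triton

BLOCK_SIZE = 1024       # keys handled per program instance
bits_per_pass = 4       # each pass sorts on a 4-bit digit
RADIX = 2 ** bits_per_pass


def alloc_radix_workspaces(n_elements: int, device):
    # The grid size and the radix together fix these shapes, so the two
    # constants cannot simply become autotuned meta-parameters: the
    # tensors have to be allocated before the kernels are launched.
    num_blocks = triton.cdiv(n_elements, BLOCK_SIZE)
    digit_hist = torch.zeros((RADIX, num_blocks), dtype=torch.int32, device=device)
    d_lookback = torch.zeros((RADIX, num_blocks), dtype=torch.int32, device=device)
    return digit_hist, d_lookback
```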
Force-pushed from fdc1bab to 00aee14
src/flag_gems/ops/randperm.py (outdated)
BLOCK_SIZE = 1024
bits_per_pass = 4
bits_per_segment = 3
passes = triton.cdiv(key.element_size() * 8, bits_per_pass)
Can we cut down the passes here by calculating the start_bit and end_bit based on the given n_elements, to minimize iterations? I'm not sure if this approach will work.
Good suggestion! The bit width of the key is now adjusted according to the given n_elements, so passes can be cut down in most cases. The following tables show the current FlagGems perf vs. PyTorch CUDA on an A100:
| Size | Torch Latency (ms) | Gems Latency (ms) | Gems Speedup |
| --- | --- | --- | --- |
| 1024 | 0.027648 | 0.08192 | 0.338 |
| 6144 | 0.095232 | 0.288768 | 0.33 |
| 11264 | 0.103424 | 0.28672 | 0.361 |
| 16384 | 0.103424 | 0.287744 | 0.359 |
| 21504 | 0.099328 | 0.287744 | 0.345 |
| 26624 | 0.099328 | 0.287744 | 0.345 |
| 31744 | 0.099328 | 0.285696 | 0.348 |
| 36864 | 0.101376 | 0.284672 | 0.356 |
| 41984 | 0.104448 | 0.284672 | 0.367 |
| 47104 | 0.103424 | 0.287744 | 0.359 |
| 52224 | 0.106496 | 0.28672 | 0.371 |
| 57344 | 0.109568 | 0.288768 | 0.379 |
| 62464 | 0.106496 | 0.289792 | 0.367 |
| 67584 | 0.10752 | 0.34304 | 0.313 |
| 72704 | 0.111616 | 0.34304 | 0.325 |
| 77824 | 0.10752 | 0.345088 | 0.312 |
Operator randperm Performance Test (dtype=torch.int64, mode=cuda)

| Size | Torch Latency (ms) | Gems Latency (ms) | Gems Speedup |
| --- | --- | --- | --- |
| 1024 | 0.0256 | 0.123904 | 0.207 |
| 6144 | 0.075776 | 0.290816 | 0.261 |
| 11264 | 0.078848 | 0.290816 | 0.271 |
| 16384 | 0.08192 | 0.287744 | 0.285 |
| 21504 | 0.083968 | 0.288768 | 0.291 |
| 26624 | 0.086016 | 0.288768 | 0.298 |
| 31744 | 0.10752 | 0.288768 | 0.372 |
| 36864 | 0.108544 | 0.288768 | 0.376 |
| 41984 | 0.108544 | 0.288768 | 0.376 |
| 47104 | 0.109568 | 0.287744 | 0.381 |
| 52224 | 0.109568 | 0.288768 | 0.379 |
| 57344 | 0.11264 | 0.287744 | 0.391 |
| 62464 | 0.113664 | 0.287744 | 0.395 |
| 67584 | 0.114688 | 0.344064 | 0.333 |
| 72704 | 0.115712 | 0.34304 | 0.337 |
| 77824 | 0.115712 | 0.344064 | 0.336 |
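To make the bit-narrowing idea above concrete, here is a small sketch of how the number of passes could be derived from n_elements; the helper name and the 2x-headroom policy are assumptions for illustration, not the exact FlagGems logic:

```python
import triton


def plan_radix_passes(n_elements: int, bits_per_pass: int = 4):
    # Keys only need enough bits to cover n_elements values (assumed
    # policy: 2x headroom above n_elements), so the high-order radix
    # passes over a full int32/int64 key can be skipped.
    end_bit = max(1, (2 * n_elements - 1).bit_length())
    passes = triton.cdiv(end_bit, bits_per_pass)
    return end_bit, passes


# Example: 77824 elements need about 18 significant key bits, i.e. 5 passes
# of 4 bits instead of 8 (int32) or 16 (int64) passes over the full key.
end_bit, passes = plan_radix_passes(77824)
```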
Force-pushed from d525e23 to 181d15e
LGTM
Add randperm operator