
[Operator] Add randperm op [MooreThreads] #185

Merged: 1 commit merged into FlagOpen:master on Nov 5, 2024

Conversation

@yjl0101 (Contributor) commented Aug 28, 2024

Add the randperm operator.

  • Add two branches for randperm: a bitonic-sort method for small sizes and a radix-sort method for larger sizes (a sketch of this dispatch follows the tables below).
  • The radix method is composed of a radix sort plus a shuffle; for performance reasons we do a block-level shuffle rather than the adjacent shuffle used in PyTorch CUDA: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/Randperm.cu#L106
  • Performance is 3-4x slower than PyTorch's CUDA (CUB-based) implementation. The key reasons, I think, are: 1. Triton doesn't support shared-memory atomic primitives, so the "digit_hist_kernel" is much slower than CUB's (a minimal histogram sketch also follows the tables); 2. Triton can't control thread-level operations, so the lookback phase of the radix sort is serialized rather than parallelized as in CUB.
  • The above understanding may not be accurate; any suggestions are welcome. The following tables show the performance of some cases on an NVIDIA A100, based on torch 2.3.1 and triton 2.3.1:
benchmark/test_tensor_constructor_perf.py
Operator randperm Performance Test (torch.int32)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                   0.03584             0.08704
11264                 0.134144             0.35328
21504                 0.130048            0.347136
31744                 0.129024             0.34816
41984                  0.13824            0.350208
52224                 0.135168            0.351232
62464                 0.141312             0.35328
72704                 0.147456            0.369664
82944                 0.145408            0.386048
93184                 0.144384            0.402432
103424                0.150528            0.416768
113664                0.149504            0.438272
123904                0.149504             0.45056
134144                0.155648            0.463872
144384                0.155648            0.483328
154624                0.154624            0.497664
164864                0.159744            0.505856
175104                 0.15872            0.530432
185344                0.159744            0.546816
195584                0.164864             0.55808
Operator randperm Performance Test (torch.int64)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.032768             0.13824
11264                 0.101376            0.347136
21504                 0.108544            0.352256
31744                 0.140288            0.351232
41984                  0.14336            0.354304
52224                 0.147456            0.352256
62464                 0.151552            0.364544
72704                 0.154624            0.380928
82944                 0.154624             0.39936
93184                 0.155648            0.418816
103424                0.156672            0.429056
113664                 0.15872             0.44544
123904                0.160768            0.464896
134144                0.162816            0.482304
144384                0.164864            0.495616
154624                0.165888            0.508928
164864                0.166912            0.520192
175104                 0.16896            0.545792
185344                0.171008            0.557056
195584                 0.17408            0.570368
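For reference, a minimal host-side sketch of the dispatch described in the first bullet. This is not the PR's actual code: the threshold, the function name randperm_sketch, and the use of torch.argsort as a stand-in for the Triton bitonic/radix kernels are assumptions for illustration.

import torch

def randperm_sketch(n: int, device: str = "cuda") -> torch.Tensor:
    # Sorting random keys and returning the ordering of their original indices
    # yields a random permutation of [0, n); ties among duplicate keys are what
    # the shuffle step is for.
    i32max = torch.iinfo(torch.int32).max
    keys = torch.randint(0, i32max, (n,), dtype=torch.int32, device=device)
    if n <= 2 * 1024:
        # small case: a single bitonic-sort block is enough
        return torch.argsort(keys)  # stand-in for the bitonic kernel
    # large case: radix-sort the keys, then shuffle within blocks of equal keys
    # (the PR uses a block shuffle instead of PyTorch's adjacent shuffle)
    return torch.argsort(keys, stable=True)  # stand-in for radix sort + shuffle

And a minimal Triton sketch of the per-digit histogram mentioned in the third bullet, again an assumption rather than the PR's digit_hist_kernel, showing why the counts have to go through global-memory atomics:

import triton
import triton.language as tl

@triton.jit
def digit_hist_sketch(key_ptr, hist_ptr, n_elements, start_bit,
                      BITS_PER_PASS: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    keys = tl.load(key_ptr + offsets, mask=mask, other=0)
    digits = (keys >> start_bit) & ((1 << BITS_PER_PASS) - 1)
    # Triton has no shared-memory atomics, so counts accumulate directly in
    # global memory, which is the bottleneck compared with CUB
    tl.atomic_add(hist_ptr + digits, 1, mask=mask)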

@yjl0101 yjl0101 force-pushed the dev_randperm branch 2 times, most recently from c1ba0aa to c53b7d5 on August 29, 2024 02:04
@iclementine iclementine self-assigned this Sep 12, 2024
@Bowen12992 Bowen12992 self-assigned this Oct 23, 2024
    rand_key = torch.randint(
        low=0, high=i32max, size=[n], dtype=torch.int32, device=device
    )
else:
Collaborator:
check n < i64max here?

Contributor Author:
done
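For context, a hedged sketch of what the key-dtype selection with the requested n < i64max check might look like; the function name, the assert, and the int64 branch are assumptions, not the PR's actual diff.

import torch

def make_rand_key(n: int, device: str = "cuda") -> torch.Tensor:
    i32max = torch.iinfo(torch.int32).max
    i64max = torch.iinfo(torch.int64).max
    if n < i32max:
        return torch.randint(
            low=0, high=i32max, size=[n], dtype=torch.int32, device=device
        )
    # the range check asked for above: int64 sort keys only cover n < i64max
    assert n < i64max, "randperm: n is too large for int64 sort keys"
    return torch.randint(
        low=0, high=i64max, size=[n], dtype=torch.int64, device=device
    )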

if n_elements > 2 * 1024:
    # radix method
    BLOCK_SIZE = 1024
    bits_per_pass = 4
Collaborator:
How are these parameters determined? Do they need to be tuned?

Contributor Author:
These parameters were set according to the performance of some test cases.
Actually, BLOCK_SIZE and bits_per_pass determine the shapes of digit_hist & d_lookback, which are intermediate tensors that have to be allocated up front, so there doesn't seem to be a way to extract tuned values and still create those tensors.
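To illustrate the coupling described above, a hypothetical sketch of how the intermediate tensors might depend on these parameters; the shapes, dtypes, and the name alloc_radix_workspace are assumptions, not the PR's code.

import torch
import triton

def alloc_radix_workspace(n_elements: int, device: str = "cuda"):
    BLOCK_SIZE = 1024
    bits_per_pass = 4
    radix = 1 << bits_per_pass                      # 16 digit buckets per pass
    num_blocks = triton.cdiv(n_elements, BLOCK_SIZE)
    # per-block digit histogram and decoupled-lookback state: their shapes are
    # fixed by BLOCK_SIZE and bits_per_pass and must exist before the kernels
    # launch, so the values can't simply be exposed as autotuned parameters
    digit_hist = torch.zeros((num_blocks, radix), dtype=torch.int32, device=device)
    d_lookback = torch.zeros((num_blocks, radix), dtype=torch.int32, device=device)
    return digit_hist, d_lookback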

@yjl0101 yjl0101 force-pushed the dev_randperm branch 4 times, most recently from fdc1bab to 00aee14 on October 25, 2024 16:22
BLOCK_SIZE = 1024
bits_per_pass = 4
bits_per_segment = 3
# one radix pass per bits_per_pass bits of the key width
passes = triton.cdiv(key.element_size() * 8, bits_per_pass)
Collaborator:
Can we cut down the passes here by calculating the start_bit and end_bit based on the given n_elements, to minimize iterations? I'm not sure if this approach will work.

Contributor Author:
Good suggestion! I adjusted the bits of the key according to the given n_elements, so the passes can be cut down in most cases (a sketch of this idea follows the tables below). The following tables show the current FlagGems performance vs. PyTorch CUDA on an A100:

Operator randperm Performance Test (dtype=torch.int32, mode=cuda)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024              0.027648              0.08192           0.338
6144              0.095232             0.288768            0.33
11264             0.103424              0.28672           0.361
16384             0.103424             0.287744           0.359
21504             0.099328             0.287744           0.345
26624             0.099328             0.287744           0.345
31744             0.099328             0.285696           0.348
36864             0.101376             0.284672           0.356
41984             0.104448             0.284672           0.367
47104             0.103424             0.287744           0.359
52224             0.106496              0.28672           0.371
57344             0.109568             0.288768           0.379
62464             0.106496             0.289792           0.367
67584              0.10752              0.34304           0.313
72704             0.111616              0.34304           0.325
77824              0.10752             0.345088           0.312
Operator randperm Performance Test (dtype=torch.int64, mode=cuda)
Size    Torch Latency (ms)    Gems Latency (ms)    Gems Speedup
---------------------------------------------------------------
1024                0.0256             0.123904           0.207
6144              0.075776             0.290816           0.261
11264             0.078848             0.290816           0.271
16384              0.08192             0.287744           0.285
21504             0.083968             0.288768           0.291
26624             0.086016             0.288768           0.298
31744              0.10752             0.288768           0.372
36864             0.108544             0.288768           0.376
41984             0.108544             0.288768           0.376
47104             0.109568             0.287744           0.381
52224             0.109568             0.288768           0.379
57344              0.11264             0.287744           0.391
62464             0.113664             0.287744           0.395
67584             0.114688             0.344064           0.333
72704             0.115712              0.34304           0.337
77824             0.115712             0.344064           0.336
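A hedged sketch of the bit-trimming idea from this thread: draw the random keys from a range just wide enough for n_elements so that only the low bits can be non-zero, then sort only those bits. The width formula and names are assumptions for illustration, not the PR's implementation.

import torch
import triton

def make_trimmed_keys(n_elements: int, bits_per_pass: int = 4, device: str = "cuda"):
    # a few bits more than log2(n_elements) keeps duplicate keys rare; any
    # duplicates are still handled by the shuffle step afterwards
    key_bits = min(30, max(bits_per_pass, (n_elements - 1).bit_length() + 4))
    rand_key = torch.randint(
        0, 1 << key_bits, [n_elements], dtype=torch.int32, device=device
    )
    # only key_bits of each key can be non-zero, so fewer radix passes are
    # needed than the full key.element_size() * 8 // bits_per_pass
    passes = triton.cdiv(key_bits, bits_per_pass)
    return rand_key, passes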

@yjl0101 yjl0101 force-pushed the dev_randperm branch 2 times, most recently from d525e23 to 181d15e on October 31, 2024 06:27
Collaborator @Bowen12992 left a comment:
LGTM

@Bowen12992 Bowen12992 merged commit 3e47645 into FlagOpen:master Nov 5, 2024
4 checks passed