
[Operator] Add nonzero op [MooreThreads] #178

Merged: 1 commit merged into FlagOpen:master from dev_nonzero on Sep 2, 2024.

Conversation

wuke1993 (Contributor):

Add nonzero op.

  • nonzero needs to call cumsum, but the current cumsum does not support N > 1024*1024. So we add a scan op (aligned to cub InclusiveScan with scan_op=add) that only supports the M==1 and K==1 cumsum cases. A sketch of the cumsum-based nonzero idea is shown after the tables below.
  • nonzero perf on A100, batch=1024:
Operator nonzero Performance Test (torch.float16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                   0.05632             0.37648
6144                  0.155648            0.561696
11264                 0.270336            0.671232
16384                 0.377856            0.785984
21504                   0.4864            0.898528
26624                 0.591872             1.01238
31744                  0.69632             1.12637
36864                 0.802816             1.25216
41984                 0.913408             1.40141
47104                  1.02298             1.55158
52224                  1.12947             1.70736
57344                  1.23699             1.85536
62464                  1.34656             2.00752
67584                  1.45101             2.16147
72704                   1.5616             2.31488
77824                  1.67117              2.4664
Operator nonzero Performance Test (torch.float32)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.057344             0.36144
6144                  0.185344            0.557696
11264                  0.29696            0.675136
16384                   0.4096            0.784192
21504                  0.52736            0.895136
26624                 0.641024              1.0095
31744                 0.756736             1.13766
36864                 0.872448             1.29763
41984                 0.996352             1.45187
47104                  1.11002             1.61133
52224                  1.22675              1.7689
57344                  1.34451             1.92502
62464                   1.4592             2.08288
67584                  1.57696             2.24349
72704                  1.68858             2.40195
77824                  1.80326             2.55888
Operator nonzero Performance Test (torch.bfloat16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                   0.05632              0.3632
6144                  0.155648            0.555104
11264                 0.270336            0.669632
16384                  0.37888            0.782176
21504                   0.4864            0.895072
26624                 0.592896             1.00518
31744                  0.69632             1.12154
36864                 0.801792             1.25072
41984                 0.914432             1.39862
47104                  1.02195             1.55526
52224                   1.1305             1.70653
57344                  1.23699              1.8545
62464                  1.34656             2.00531
67584                  1.45203             2.15994
72704                  1.56365             2.31178
77824                   1.6681             2.46554
Operator nonzero Performance Test (torch.bool)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.055296              0.2896
6144                  0.146432            0.483616
11264                 0.253952            0.595648
16384                 0.357376             0.70848
21504                 0.461824             0.81936
26624                 0.566272            0.931808
31744                 0.667648             1.05053
36864                 0.770048             1.17546
41984                 0.876544             1.31706
47104                  0.97792             1.45971
52224                  1.08032             1.60371
57344                  1.18579             1.74394
62464                  1.29126             1.88778
67584                  1.39366             2.03043
72704                  1.49811             2.17069
77824                  1.60358             2.31408
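For context, the idea behind a cumsum-based nonzero can be sketched at the PyTorch level as follows. This is only an illustration of the algorithm, not the Triton kernels in this PR; the function name and the flat-index decoding are made up for the example.

```python
import torch

def nonzero_via_cumsum(inp: torch.Tensor) -> torch.Tensor:
    # Flag the nonzero elements of the flattened input.
    inp_bool = inp.reshape(-1) != 0
    # Inclusive scan of the flags: prefix_sum[i] - 1 is the output slot
    # of element i whenever element i is nonzero.
    prefix_sum = inp_bool.cumsum(axis=0)
    num_nonzeros = int(prefix_sum[-1])
    flat_idx = torch.arange(inp_bool.numel(), device=inp.device)
    out_flat = torch.empty(num_nonzeros, dtype=torch.int64, device=inp.device)
    out_flat[prefix_sum[inp_bool] - 1] = flat_idx[inp_bool]
    # Decode flat indices into per-dimension coordinates, shape (N, ndim),
    # matching the layout of torch.nonzero.
    coords = []
    rem = out_flat
    for size in reversed(inp.shape):
        coords.append(rem % size)
        rem = rem // size
    return torch.stack(list(reversed(coords)), dim=1)
```

The inclusive scan is the step that does not fit the existing cumsum kernel once N grows past 1024*1024, which is what motivates the new scan op.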

@wuke1993 (Contributor, Author):

Branch selection for cumsum:

  • we set THRESHOLD_SCAN_THEN_FAN = 1024 * 4. Cases larger than THRESHOLD_SCAN_THEN_FAN call scan.
  • perf on A100, M==1 and K==1:
Operator cumsum Performance Test (torch.float32)
Size        cumsum_current Latency (ms)   cumsum_scan Latency (ms)
--------------------------------------------------
1024                   0.01024            0.006144
2048                  0.013312            0.080896
3072                  0.032768            0.080896
4096                  0.032768            0.080896
5120                   0.11264            0.079872
6144                  0.113664            0.079872
7168                  0.114688            0.080896
8192                  0.115712            0.079872

Operator cumsum Performance Test (torch.float32)
Size        Torch Latency (ms)   cumsum_scan Latency (ms)
--------------------------------------------------
1024                  0.009216            0.006144
32768                 0.013312             0.07168
65536                 0.014336             0.07168
131072                0.014336            0.070656
196608                 0.01536            0.072704
262144                0.016384            0.072704
327680                0.017408             0.07168
393216                0.018432             0.07168
458752                0.018432             0.07168
524288                0.019456            0.072704
589824                0.019456            0.072704
655360                 0.02048            0.069632
720896                 0.02048             0.07168
786432                 0.02048             0.07168
851968                 0.02048            0.069632
917504                0.021504            0.069632
983040                0.021504             0.07168
1048576               0.021504            0.069632

src/flag_gems/ops/scan.py
@@ -0,0 +1,78 @@
import logging
Contributor:

The functions in this file are better building blocks for cumsum. Shall we pack and move these functions into cumsum.py?

Contributor Author:

Is there any plan to develop a scan operator? If there is, then I think a separate scan operator is better.

Contributor:

There's no corresponding scan operator in aten, is there?

Contributor Author:

No, there is no corresponding scan operator in aten.
There is a 'ScanKernels' in native CUDA which contains the implementations of cumsum, cumprod, etc. Could we also treat scan like this 'ScanKernels', as a building block rather than an operator?

Contributor Author:

Added an ABC version of scan_then_fan in scan, which supports M and K not equal to 1 (an algorithm sketch follows the tables below).
Perf on A100, batch=1024:

Operator cumsum Performance Test (torch.float16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.077824            0.014336
6144                   0.15872            0.078848
11264                 0.282624            0.089088
16384                 0.406528            0.131072
21504                 0.370688            0.166912
26624                 0.456704            0.200704
31744                  0.54272             0.23552
36864                 0.628736            0.272384
41984                 0.713728              0.3072
47104                 0.800768            0.345088
52224                  0.88576            0.375808
57344                   0.9728             0.41472
62464                  1.05779            0.454656
67584                 0.999424            0.482304
72704                  1.07622            0.524288
77824                  1.14995             0.55808
Operator cumsum Performance Test (torch.float32)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.080896            0.014336
6144                  0.147456              0.0768
11264                 0.263168            0.137216
16384                 0.376832             0.18944
21504                  0.33792            0.244736
26624                 0.415744            0.299008
31744                 0.493568            0.352256
36864                 0.571392              0.4096
41984                 0.649216            0.463872
47104                  0.72704            0.514048
52224                 0.804864            0.566272
57344                 0.883712            0.627712
62464                 0.961536            0.684032
67584                 0.945152             0.73216
72704                  1.01376            0.821248
77824                  1.08339            0.848896
Operator cumsum Performance Test (torch.bfloat16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.077824            0.013312
6144                  0.159744            0.079872
11264                 0.284672            0.089088
16384                   0.4096            0.130048
21504                 0.380928            0.166912
26624                 0.467968            0.200704
31744                 0.557056            0.236544
36864                  0.64512             0.27136
41984                 0.733184            0.306176
47104                 0.821248            0.344064
52224                 0.909312            0.376832
57344                   0.9984             0.41472
62464                  1.08544            0.454656
67584                  1.02707            0.482304
72704                  1.10387             0.52224
77824                  1.17965            0.553984
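For readers unfamiliar with the scheme, here is a rough PyTorch-level sketch of scan-then-fan for the 1-D case. The partition size and function name are illustrative assumptions; the real implementation is a set of Triton kernels, and supporting M and K greater than 1 amounts to doing the same thing independently per (m, k) slice.

```python
import torch

def scan_then_fan_cumsum(x: torch.Tensor, part_size: int = 1024) -> torch.Tensor:
    # 1) Split the input into parts and scan each part independently
    #    (each part would be one Triton program in the real kernel).
    parts = torch.split(x, part_size)
    scanned = [p.cumsum(0) for p in parts]
    # 2) Scan the per-part totals to get each part's base offset
    #    (exclusive prefix of the part sums).
    part_sums = torch.stack([s[-1] for s in scanned])
    bases = torch.cumsum(part_sums, 0) - part_sums
    # 3) Fan the base offsets back out over the per-part scans.
    return torch.cat([s + b for s, b in zip(scanned, bases)])
```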

Contributor:

The Triton-flavored scan operator should support combine functions other than add. Triton already has its block-level builtin baked in. You are free to provide a general scan operator. For now let's just replace the original cumsum. Shall we?
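As a pointer to the builtin being referred to, here is a minimal single-block sketch of a generalized scan on top of `tl.associative_scan`. The kernel and wrapper names are illustrative, and a real operator would still need the multi-block scan-then-fan machinery from this PR.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def combine_add(a, b):
    # Any associative combine function works here (mul, max, ...).
    return a + b


@triton.jit
def block_scan_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask, other=0.0)
    # Triton's block-level builtin: the combine function is a parameter,
    # so the same structure covers cumsum, cumprod, etc.
    y = tl.associative_scan(x, 0, combine_add)
    tl.store(out_ptr + offsets, y, mask=mask)


def block_scan(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(x.numel())
    block_scan_kernel[(1,)](x, out, x.numel(), BLOCK_SIZE=BLOCK_SIZE)
    return out
```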

Contributor Author:

Sure, done.

Contributor Author:

> How does the time it takes to autotune the scan op compare to the time it takes to autotune the original cumsum op?

Please see the comment below.

Collaborator @zhzhcookie (Aug 29, 2024):


I ran the cumsum benchmark on A100; the float32 results are different:

test_reduction_perf.py Operator cumsum Performance Test (torch.float16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.074944            0.012896
6144                  0.124256             0.04704
11264                 0.220736            0.079232
16384                 0.315392            0.125248
21504                 0.284992            0.159616
26624                  0.35088            0.191744
31744                 0.416768            0.225632
36864                 0.482048            0.260576
41984                 0.548352            0.293824
47104                 0.613568            0.337056
52224                 0.678784            0.359104
57344                 0.746112            0.401952
62464                   0.8104            0.442272
67584                 0.771552            0.464224
72704                 0.828736            0.508992
77824                 0.885536            0.537536
Operator cumsum Performance Test (torch.float32)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                   0.06416            0.013664
6144                  0.117504            0.070144
11264                 0.207968            0.150912
16384                 0.295808            0.213248
21504                 0.263392            0.274528
26624                  0.32352            0.336928
31744                 0.383744            0.398848
36864                 0.444832            0.470752
41984                 0.504832              0.5312
47104                 0.565056            0.590976
52224                 0.624896            0.649152
57344                 0.684736            0.723904
62464                 0.745184            0.789696
67584                 0.741888            0.844032
72704                 0.796416            0.952448
77824                 0.850624             0.98816
Operator cumsum Performance Test (torch.bfloat16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.061248              0.0128
6144                  0.125184            0.047104
11264                 0.221408            0.078752
16384                  0.31744            0.124096
21504                  0.29152             0.15824
26624                 0.359296            0.191296
31744                 0.426912             0.22592
36864                 0.494432            0.260224
41984                 0.561792             0.29264
47104                 0.628352            0.333312
52224                 0.695808            0.360288
57344                 0.763744            0.407584
62464                 0.829856            0.449312
67584                   0.7904            0.464032
72704                 0.849376            0.514336
77824                 0.907168            0.539552

Contributor Author:

I switched to another A100 and got Torch performance close to yours. The Gems perf also became faster, and it is close to Torch's perf for the bigger cases.

Operator cumsum Performance Test (torch.float32)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.063488            0.012288
6144                  0.115712            0.083968
11264                   0.2048            0.131072
16384                 0.293888            0.183296
21504                 0.260096            0.237568
26624                 0.320512             0.29184
31744                 0.380928            0.344064
36864                  0.44032            0.401408
41984                 0.499712             0.45568
47104                 0.559104            0.504832
52224                 0.618496             0.55808
57344                 0.678912            0.621568
62464                 0.738304            0.679936
67584                 0.731136            0.720896
72704                 0.786432            0.787456
77824                  0.83968            0.836608

@wuke1993 force-pushed the dev_nonzero branch 2 times, most recently from 7e6a8e3 to 9f0ccf4 on August 29, 2024 02:31

prefix_sum = inp_bool.cumsum(axis=0)

num_nonzeros = n_elements
Collaborator:

Can it be `num_nonzeros = prefix_sum[-1]` here?

Contributor Author:

The original cumsum kernel only uses a single block along N, so performance drops sharply as the cases get bigger.
It takes hours to run the entire performance test suite with the original cumsum... Here is the performance of the first 4 cases for batch=1024 on A100:

Operator cumsum Performance Test (torch.float16)
Size    Torch Latency (ms)     STF Latency (ms)   original Latency (ms)
----------------------------------------------------------------------
1024              0.077824            0.014336            0.018432
6144               0.15872            0.078848            0.270336
11264             0.282624            0.089088            0.733184
16384             0.406528            0.131072            0.804864

Contributor Author:

> Can it be `num_nonzeros = prefix_sum[-1]` here?

Yes, and actually the output size should be `prefix_sum[-1]`.
Here we are trading space for time: `prefix_sum[n_elements - 1].item()` triggers a device-to-host copy, so we move that read to the end, and the CUDA kernels are not interrupted by a synchronization (see the ordering sketch after the table below).
Here is the performance for batch=1024 on A100:

Operator nonzero Performance Test (torch.bool)
Size        fun1 Latency (ms)   fun2 Latency (ms)
--------------------------------------------------
1024                  0.302656            0.304128
6144                  0.505504            0.530432
11264                 0.620224             0.64512
16384                 0.741536            0.769024
21504                 0.858496            0.884736
26624                  0.97568             1.00557
31744                  1.09635             1.12435
36864                  1.24224             1.26566
41984                   1.3944             1.40698
47104                  1.54192             1.56365
52224                  1.68931             1.70906
57344                  1.83971              1.8647
62464                  1.99014             2.01216
67584                  2.14531             2.16474
72704                  2.29331             2.31322
77824                  2.44122             2.46579
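A rough PyTorch-level illustration of the ordering being discussed (names and helper logic are made up; the real code launches Triton kernels): the single device-to-host read of `prefix_sum[-1]` happens only after every kernel has been enqueued against the oversized n_elements buffer.

```python
import torch

def nonzero_flat_deferred_sync(inp_bool: torch.Tensor) -> torch.Tensor:
    n_elements = inp_bool.numel()
    prefix_sum = inp_bool.cumsum(0)
    flat_idx = torch.arange(n_elements, device=inp_bool.device)
    # Output slot for each element; zero elements are parked in a scratch
    # slot past the end so that no host synchronization is needed yet.
    slots = torch.where(inp_bool, prefix_sum - 1,
                        torch.full_like(flat_idx, n_elements))
    out = torch.empty(n_elements + 1, dtype=torch.int64,
                      device=inp_bool.device)
    out.scatter_(0, slots, flat_idx)
    # The only device-to-host transfer happens here, after all work has
    # been queued: space (the oversized buffer) is traded for time.
    num_nonzeros = prefix_sum[-1].item()
    return out[:num_nonzeros]
```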

op_name="nonzero",
torch_op=torch.nonzero,
arg_func=nonzero_args,
dtypes=FLOAT_DTYPES + [torch.bool],
Contributor:

Do we need an INT_DTYPES in the test utils?

Contributor Author:

Done, added INT_DTYPES to the benchmark and tests.

}
)

@triton.jit
Contributor:

Adding `do_not_specialize=['n_elements', 'part_num']` can help avoid redundant kernel generation when recursion does happen.
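For illustration, the suggestion would look roughly like this on one of the kernels (the kernel body below is a made-up fan-out step, not the code in this PR). Marking the scalar sizes as do_not_specialize keeps Triton from specializing the compiled kernel on their concrete values, so calls with different `n_elements` / `part_num` reuse the same binary instead of triggering another compilation.

```python
import triton
import triton.language as tl


@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_kernel(out_ptr, base_ptr, n_elements, part_num,
                    BLOCK_SIZE: tl.constexpr):
    # One program per part: add the scanned base offset of this part
    # to every element it owns.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = (offsets < n_elements) & (pid < part_num)
    vals = tl.load(out_ptr + offsets, mask=mask, other=0.0)
    base = tl.load(base_ptr + pid)  # exclusive prefix of the part sums
    tl.store(out_ptr + offsets, vals + base, mask=mask)
```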

Contributor Author:

done

tl.store(out_ptrs, out_vals, mask=mask)


@triton.jit
Contributor:

`do_not_specialize=['part_num']` may help too.

Contributor Author:

done

tl.store(partial_sum_ptrs, part_sum_via_sum)


@triton.jit
Contributor:

ditto.

Contributor Author:

done

Comment on lines 55 to 60
if dtype == torch.bfloat16:
atol = 1e-3 * reduce_dim
Contributor:

Don't change this.

Contributor Author @wuke1993 (Aug 30, 2024):

Done.
bf16 precision and different partition numbers may lead to different cumsum results.

Tested with cases (4096, 256 * i) for i in range(1, 99, 2):

The original BLOCK_SIZE was set to 1024, which randomly leads to mismatched elements like the ones below (it usually happens in the (4096, 1792) and (4096, 2304) cases; not sure why, any thoughts on that?):

A mismatch of 0.24+ is about at the precision limit of bf16 (approximately 2 to 3 significant digits).
[screenshot: mismatched-elements report]

After changing BLOCK_SIZE to next_power_of_2 for the small cases (maybe the same strategy as torch?), all the cases passed, so I use this as the final BLOCK_SIZE.
[screenshot: all cases passing]
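For reference, the block-size choice described above amounts to something like the sketch below for the small cases; larger cases still go through the partitioned scan path. The helper name is an assumption for the example.

```python
import triton

def single_block_size(n: int) -> int:
    # For the small cases discussed above, size the (single) block to the
    # next power of two covering the whole row, so there is only one
    # accumulation order regardless of n.
    return triton.next_power_of_2(n)
```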

Contributor:

Floating point results are very sensitive to ordering, rounding and accumulation nuances. When it comes to GPUs it's almost impossible to reason about what makes the results change. Let's not be too bothered about it; one or two mismatching runs doesn't mean the result is incorrect. We may come up with better accuracy tests, but let's just stick with this for now.

Contributor @tongxin left a comment:

Please also test cumsum and strengthen benchmarks.

@@ -250,6 +251,22 @@ def test_accuracy_cumsum(shape, dtype):
gems_assert_close(res_out, ref_out, dtype, reduce_dim=shape[dim])


@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
Collaborator:

I tested shape=(1024 * 1024 * 64, 1) and got a wrong result. Can you try this shape in your environment?

Contributor Author:

I got a wrong result too.
The bug is located below. The root cause is float precision: inp is bool and is converted to float32 to compute the cumsum, when it should be int32. With 1024 * 1024 * 64 = 2^26 elements the running sum can exceed 2^24, beyond which float32 can no longer represent consecutive integers exactly.
The same precision issue may happen when the input dtype is float16 or bfloat16. I will correct them together later.

# wrong
inp_vals = tl.load(inp_ptrs, mask=mask).to(tl.float32)
result = tl.cumsum(inp_vals, axis=0)

# correct
inp_vals = tl.load(inp_ptrs, mask=mask).to(tl.int32)
result = tl.cumsum(inp_vals, axis=0)

Contributor Author:

done

@wuke1993 force-pushed the dev_nonzero branch 2 times, most recently from 39f71f9 to 9a82815 on August 30, 2024 09:21
Contributor @tongxin left a comment:

LGTM

@tongxin merged commit 40945c3 into FlagOpen:master on Sep 2, 2024. 4 checks passed.