[Operator] Add tile op #148

Merged: zfu82 merged 2 commits into master from dev_tile on Aug 23, 2024
Conversation

@zfu82 (Collaborator) commented on Aug 2, 2024

Performance results on an NVIDIA A100:

benchmark/test_pointwise_perf.py::test_perf_tile
Operator tile Performance Test (torch.float16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.048128             6.72563
6144                   0.21504             6.88333
11264                 0.387072             6.96422
16384                  0.55808             6.96627
21504                 0.729088             7.12192
26624                  0.90112             7.08198
31744                  1.07213              7.0103
36864                  1.24109             7.02054
41984                   1.4121              7.0615
47104                  1.57901              7.0871
52224                  1.74899             7.12909
57344                  1.93638             7.15571
62464                   2.0992             7.19155
67584                  2.25997             7.17824
72704                  2.43712             7.31443
77824                   2.6071             7.31853
Operator tile Performance Test (torch.float32)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                  0.043008             6.97958
6144                  0.231424             7.29907
11264                 0.415744             7.11475
16384                 0.601088              7.1465
21504                 0.782336             7.16083
26624                 0.965632             7.19258
31744                  1.15405             7.20896
36864                  1.33018             7.21613
41984                  1.51859             7.24275
47104                  1.69677             7.26016
52224                  1.87904             7.27347
57344                  2.11456             7.28166
62464                  2.33779             7.31648
67584                  2.35418             7.28269
72704                  2.46784             7.28986
77824                  2.59686              7.3257
Operator tile Performance Test (torch.bfloat16)
Size        Torch Latency (ms)   Gems Latency (ms)
--------------------------------------------------
1024                   0.04096             6.98675
6144                   0.21504             7.20282
11264                 0.387072             7.21818
16384                  0.55808             7.23456
21504                 0.730112             7.25197
26624                  0.90112             7.26323
31744                  1.07213             7.17619
36864                  1.24211             7.27962
41984                   1.4121             7.28371
47104                  1.57901             7.29498
52224                  1.74797             7.32365
57344                  1.93638             7.33696
62464                   2.0992             7.37997
67584                  2.26099             7.30522
72704                  2.43814             7.25402
77824                  2.60608             7.25606
PASSED

code.writeline(
    f"in{i}_strides = broadcasted_stride(in{i}.shape, in{i}.stride(), shape)"
)
code.writeline(f"if 'in{i}_shape' in kwargs:")
@iclementine (Collaborator) commented on Aug 12, 2024

This modification relaxes the requirement that operand shapes be broadcastable when an input shape is explicitly passed in.

It may conflict with our future changes to the codegen functionality, but we are considering adding a more powerful analysis to handle these cases. Thanks~
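For readers following along, here is a minimal, self-contained sketch (not the actual generated code) of the behaviour this change gives the generated wrapper: broadcast-derived strides are computed as before, but an explicitly passed `in0_shape` kwarg overrides the default shape. `broadcasted_stride` below is a simplified stand-in for the real helper, and `prepare_input_meta` is a hypothetical name used only for illustration.

```python
def broadcasted_stride(shape, stride, target_shape):
    # Simplified stand-in for the codegen helper: dimensions that are
    # broadcast (size 1, or missing on the left) get stride 0.
    ndim = len(target_shape)
    padded_shape = (1,) * (ndim - len(shape)) + tuple(shape)
    padded_stride = (0,) * (ndim - len(stride)) + tuple(stride)
    return [s if d != 1 else 0 for d, s in zip(padded_shape, padded_stride)]


def prepare_input_meta(in0, shape, num_tasks, **kwargs):
    # Sketch of the shape/stride preparation emitted for one input operand.
    in0_strides = broadcasted_stride(in0.shape, in0.stride(), shape)
    if "in0_shape" in kwargs:
        # Explicit shape supplied by the op: the operand no longer has to be
        # broadcastable against the task shape.
        in0_shape = kwargs["in0_shape"]
    else:
        # Sentinel larger than any task index, so a `% in0_shape[d]` in the
        # kernel never changes the index (see the discussion below).
        in0_shape = [(num_tasks + 1) for _ in range(len(shape))]
    return in0_shape, in0_strides
```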

)
code.writeline(
    f"in{i}_shape = [(num_tasks + 1) for _ in range(len(shape))]"
)
@iclementine (Collaborator) commented on Aug 12, 2024

Why use `in{i}_shape = [(num_tasks + 1) for _ in range(len(shape))]` as the input shape in this case? Why not just use `shape`?

@zfu82 (Collaborator, Author) commented on Aug 13, 2024

In the tile op, `in{i}_shape` is used in the `tl.load`. Specifically, when rank == 2, the load changes from `in0 = tl.load(in0_ptr + i0 * in0_stride0 + i1 * in0_stride1, mask=mask)` to `in0 = tl.load(in0_ptr + (i0 % in0.shape[0]) * in0_stride0 + (i1 % in0.shape[1]) * in0_stride1, mask=mask)`.

But other ops do not need the `% in0.shape[0]` wrapping. To make sure this change does not affect them, `(num_tasks + 1)` is used: every task index is smaller than that value, so the modulo is a no-op.
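To make the modulo trick concrete, here is a rank-2 Triton sketch under the same assumptions (the kernel and parameter names such as `tile_load_example` and `in0_shape0` are illustrative, not the actual generated code): with the modulo, indices that run past the input extent wrap around, which is exactly the repetition tile needs; when the shape is set to `num_tasks + 1` the modulo never triggers and the load reduces to the ordinary broadcast form.

```python
import triton
import triton.language as tl


@triton.jit
def tile_load_example(
    in0_ptr, out0_ptr,
    in0_shape0, in0_shape1,
    in0_stride0, in0_stride1,
    out_shape0, out_shape1,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < out_shape0 * out_shape1
    # Row/column indices of each output element (row-major, contiguous output).
    i0 = offsets // out_shape1
    i1 = offsets % out_shape1
    # Ordinary broadcast load would be:
    #   in0 = tl.load(in0_ptr + i0 * in0_stride0 + i1 * in0_stride1, mask=mask)
    # Tile-aware load: wrap the indices back into the input extent so the
    # input is read repeatedly along each dimension.
    in0 = tl.load(
        in0_ptr + (i0 % in0_shape0) * in0_stride0 + (i1 % in0_shape1) * in0_stride1,
        mask=mask,
    )
    tl.store(out0_ptr + offsets, in0, mask=mask)
```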

@zfu82 force-pushed the dev_tile branch 2 times, most recently from 6b01543 to 3cca5e5 on August 14, 2024 04:56
@iclementine previously approved these changes on Aug 23, 2024

☂️ Python Coverage

current status: ✅

Overall Coverage

Lines   Covered   Coverage   Threshold   Status
8319    5787      70%        0%          🟢

New Files

File                         Coverage   Status
src/flag_gems/ops/tile.py    99%        🟢
TOTAL                        99%        🟢

Modified Files

File                                Coverage   Status
src/flag_gems/__init__.py           100%       🟢
src/flag_gems/ops/__init__.py       100%       🟢
tests/accuracy_utils.py             93%        🟢
tests/test_unary_pointwise_ops.py   100%       🟢
TOTAL                               98%        🟢

updated for commit: c6aa7f6 by action🐍

@iclementine merged commit 6404d38 into master on Aug 23, 2024
3 of 4 checks passed
@iclementine deleted the dev_tile branch on August 23, 2024 06:58
@zfu82 mentioned this pull request on Aug 30, 2024