Introduce autotuning to `conv2d` and `conv_transpose2d` with a new im2col/GEMM algorithm #2287
Conversation
Codecov Report

Attention: Patch coverage is

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #2287      +/-   ##
==========================================
- Coverage   85.67%   85.21%   -0.47%
==========================================
  Files         760      766       +6
  Lines       99082   100293    +1211
==========================================
+ Hits        84888    85462     +574
- Misses      14194    14831     +637
```

☔ View full report in Codecov by Sentry.
LGTM, waiting for @louisfd to review before merging, but thanks a lot 🙏
I have a few questions, but overall it's very good, awesome!
```rust
/**************************** Bounds Check + CMMA Op *********************************/
if a_row < gemm_m && k < gemm_k && b_col < gemm_n {
    cmma::load(&matrix_a, input_tile.as_slice(), CMMA_K);
```
When doing the matmul, I found you can safely reuse a fragment for several executions, so this line could go in the outer loop.
I'm not sure what you mean? I am reusing the fragments, but I still need to load a tile and execute the matmul for each k.
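The pattern under discussion can be sketched with a scalar analogue (plain Rust with hypothetical names, not the actual GPU kernel code): the accumulator is allocated once per output element and reused across the entire k-loop, while the input tiles change every iteration, so their loads and the multiply-accumulate must stay inside the loop.

```rust
// Scalar analogue of the fragment-reuse pattern (hypothetical names,
// not the real CMMA kernel). Computes C = A * B over K in tiles of TILE.
const TILE: usize = 4;

fn tiled_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    for row in 0..m {
        for col in 0..n {
            // Accumulator "fragment": initialized once, reused for every k-tile.
            let mut acc = 0.0f32;
            for kt in (0..k).step_by(TILE) {
                // The input "tiles" differ per iteration, so their loads
                // (and the multiply-accumulate) remain inside the k-loop.
                for kk in kt..(kt + TILE).min(k) {
                    acc += a[row * k + kk] * b[kk * n + col];
                }
            }
            c[row * n + col] = acc;
        }
    }
    c
}

fn main() {
    // Identity * B == B, checked against the tiled result.
    let a = vec![1.0, 0.0, 0.0, 1.0];
    let b = vec![3.0, 4.0, 5.0, 6.0];
    assert_eq!(tiled_matmul(&a, &b, 2, 2, 2), vec![3.0, 4.0, 5.0, 6.0]);
    println!("ok");
}
```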
LGTM
Checklist

- [x] The `run-checks all` script has been executed.

Related Issues/PRs
This goes some way to resolving this old enhancement suggestion: #805
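As background for the autotuning described in the changes below, here is a minimal, self-contained sketch of shape-keyed kernel selection (a hypothetical standalone harness, not burn's actual autotune API): each candidate kernel is timed once per input shape, and the winner is cached and dispatched on subsequent calls.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

// Hypothetical autotune harness for illustration only; the real
// infrastructure in the PR differs. Two interchangeable candidate
// "kernels" stand in for the direct and im2col conv implementations.
type Kernel = fn(&[f32]) -> f32;

fn sum_iter(data: &[f32]) -> f32 { data.iter().sum() }
fn sum_fold(data: &[f32]) -> f32 { data.iter().fold(0.0, |s, x| s + x) }

struct Autotuner {
    kernels: Vec<Kernel>,
    cache: HashMap<usize, usize>, // input length -> index of fastest kernel
}

impl Autotuner {
    fn run(&mut self, data: &[f32]) -> f32 {
        let key = data.len();
        if !self.cache.contains_key(&key) {
            // First sighting of this shape: benchmark every candidate.
            let mut best = (0, Duration::MAX);
            for (i, kernel) in self.kernels.iter().enumerate() {
                let start = Instant::now();
                let _ = kernel(data);
                let elapsed = start.elapsed();
                if elapsed < best.1 {
                    best = (i, elapsed);
                }
            }
            self.cache.insert(key, best.0);
        }
        // Dispatch to the cached winner.
        (self.kernels[self.cache[&key]])(data)
    }
}

fn main() {
    let mut tuner = Autotuner {
        kernels: vec![sum_iter, sum_fold],
        cache: HashMap::new(),
    };
    let data = vec![1.0f32, 2.0, 3.0];
    assert_eq!(tuner.run(&data), 6.0); // tunes on the first call
    assert_eq!(tuner.run(&data), 6.0); // cache hit on the second call
    println!("ok");
}
```

Since both candidates compute the same result, the choice only affects speed, which is what makes the tuning transparent to callers.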
Changes

Adds the required infrastructure to autotune `conv2d` and `conv_transpose2d`, as well as a second algorithm based on `im2col`, which provides significant speedups at the cost of memory usage. I'd like to add more algorithms when I have time, but this already puts the infrastructure in place to make that much easier and less breaking.

Update: now also includes implicit GEMM when:

- `batch_size * out_h * out_w` is divisible by 16
- `out_channels` is divisible by 16
- `in_channels * kernel_h * kernel_w` is divisible by 16

Testing

The autotuned, direct (current) and `im2col` implementations pass all existing tests (except `matches_reference_backend`, see below).

Benchmarking
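To make the im2col path being benchmarked concrete, here is a minimal CPU sketch (plain Rust, stride 1, no padding, a single image; the PR's real implementation is a GPU kernel): each receptive field is unfolded into a column, and the convolution then becomes a single GEMM, trading extra memory for a faster, more regular computation.

```rust
/// Direct 2D convolution: single image, stride 1, no padding.
fn direct_conv2d(
    input: &[f32], weight: &[f32],
    c_in: usize, h: usize, w: usize,
    c_out: usize, kh: usize, kw: usize,
) -> Vec<f32> {
    let (oh, ow) = (h - kh + 1, w - kw + 1);
    let mut out = vec![0.0f32; c_out * oh * ow];
    for co in 0..c_out {
        for y in 0..oh {
            for x in 0..ow {
                let mut acc = 0.0f32;
                for ci in 0..c_in {
                    for ky in 0..kh {
                        for kx in 0..kw {
                            acc += input[ci * h * w + (y + ky) * w + (x + kx)]
                                * weight[co * c_in * kh * kw + ci * kh * kw + ky * kw + kx];
                        }
                    }
                }
                out[co * oh * ow + y * ow + x] = acc;
            }
        }
    }
    out
}

/// im2col: unfold every kernel-sized patch into a column, producing a
/// (c_in*kh*kw) x (oh*ow) matrix. This is the extra memory being traded.
fn im2col(input: &[f32], c_in: usize, h: usize, w: usize, kh: usize, kw: usize) -> Vec<f32> {
    let (oh, ow) = (h - kh + 1, w - kw + 1);
    let mut cols = vec![0.0f32; c_in * kh * kw * oh * ow];
    for ci in 0..c_in {
        for ky in 0..kh {
            for kx in 0..kw {
                let row = ci * kh * kw + ky * kw + kx;
                for y in 0..oh {
                    for x in 0..ow {
                        cols[row * oh * ow + y * ow + x] =
                            input[ci * h * w + (y + ky) * w + (x + kx)];
                    }
                }
            }
        }
    }
    cols
}

/// Naive GEMM: (m x k) * (k x n) -> (m x n).
fn gemm(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for kk in 0..k {
            for j in 0..n {
                c[i * n + j] += a[i * k + kk] * b[kk * n + j];
            }
        }
    }
    c
}

fn main() {
    // 1 input channel, 3x3 image, 1 output channel, 2x2 kernel.
    let input: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    let weight = vec![1.0, 0.0, 0.0, 1.0];
    let direct = direct_conv2d(&input, &weight, 1, 3, 3, 1, 2, 2);
    let cols = im2col(&input, 1, 3, 3, 2, 2);
    let via_gemm = gemm(&weight, &cols, 1, 4, 4);
    assert_eq!(direct, via_gemm); // both algorithms agree
    println!("ok");
}
```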
Conv2d

Direct

- batch_size = 4: `jit<wgpu>`, `fusion<jit<wgpu>>`, `jit<cuda>`
- batch_size = 16: `jit<wgpu>`, `fusion<jit<wgpu>>`, `jit<cuda>`

Im2col

- batch_size = 4: `jit<wgpu>`, `fusion<jit<wgpu>>`, `jit<cuda>`
- batch_size = 16 (split into 4 sub-batches): `jit<wgpu>`, `fusion<jit<wgpu>>`, `jit<cuda>`

Implicit GEMM

- batch_size = 16: `jit<cuda>`