
Introduce autotuning to conv2d and conv_transpose2d with a new im2col/GEMM algorithm #2287

Merged · 28 commits into tracel-ai:main · Sep 23, 2024

Conversation

@wingertge (Contributor) commented on Sep 17, 2024

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

This goes some way to resolving this old enhancement suggestion: #805

Changes

Adds the infrastructure required to autotune conv2d and conv_transpose2d, and adds a second algorithm based on im2col that provides significant speedups at the cost of memory usage. I'd like to add more algorithms when I have time, but this PR already puts the infrastructure in place to make that much easier and less breaking. A minimal sketch of the im2col/GEMM idea follows.
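For readers unfamiliar with the approach, here is a minimal, illustrative sketch of im2col followed by a GEMM (stride 1, no padding; the function name and layout are mine, not the PR's API). The [m, k] column buffer it materializes is exactly the memory cost mentioned above; the payoff is that the convolution becomes one large, well-optimized matrix multiply.

```rust
// Illustrative im2col + GEMM for conv2d (stride 1, no padding).
// `input` is NCHW flattened; `weight` is [out_c, in_c, kh, kw] flattened.
fn conv2d_im2col(
    input: &[f32], weight: &[f32],
    n: usize, c: usize, h: usize, w: usize,
    out_c: usize, kh: usize, kw: usize,
) -> Vec<f32> {
    let (out_h, out_w) = (h - kh + 1, w - kw + 1);
    let k = c * kh * kw;       // GEMM reduction dimension
    let m = n * out_h * out_w; // GEMM rows
    // 1) im2col: materialize an [m, k] column matrix (the memory cost).
    let mut cols = vec![0.0f32; m * k];
    for b in 0..n {
        for oy in 0..out_h {
            for ox in 0..out_w {
                let row = (b * out_h + oy) * out_w + ox;
                for ci in 0..c {
                    for ky in 0..kh {
                        for kx in 0..kw {
                            let col = (ci * kh + ky) * kw + kx;
                            cols[row * k + col] =
                                input[((b * c + ci) * h + oy + ky) * w + ox + kx];
                        }
                    }
                }
            }
        }
    }
    // 2) GEMM: out[m, out_c] = cols[m, k] × weight^T[k, out_c].
    let mut out = vec![0.0f32; m * out_c];
    for row in 0..m {
        for oc in 0..out_c {
            out[row * out_c + oc] = (0..k)
                .map(|i| cols[row * k + i] * weight[oc * k + i])
                .sum();
        }
    }
    out // laid out as [n, out_h, out_w, out_c]; a final transpose yields NCHW
}
```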

Update: now also includes an implicit GEMM algorithm, used when all of the following hold (see the eligibility sketch after this list):

  • CMMA is available
  • batch_size * out_h * out_w is divisible by 16
  • out_channels is divisible by 16
  • in_channels * kernel_h * kernel_w is divisible by 16
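
A hypothetical helper expressing these constraints (the function name and the `cmma_available` flag are illustrative; the PR's actual check lives in implicit_gemm.rs):

```rust
// Illustrative eligibility check for the implicit GEMM path. The 16s come
// from the CMMA fragment dimensions; `cmma_available` stands in for however
// the runtime reports tensor-core (CMMA) support.
fn can_use_implicit_gemm(
    cmma_available: bool,
    batch_size: usize, out_h: usize, out_w: usize,
    out_channels: usize,
    in_channels: usize, kernel_h: usize, kernel_w: usize,
) -> bool {
    cmma_available
        && (batch_size * out_h * out_w) % 16 == 0
        && out_channels % 16 == 0
        && (in_channels * kernel_h * kernel_w) % 16 == 0
}
```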

Testing

The autotuned, direct (current), and im2col implementations pass all existing tests (except matches_reference_backend; see below).

Benchmarking

Conv2d

Direct

batch_size = 4

| Benchmark | Feature | Backend | Device | Median |
|---|---|---|---|---|
| conv2d | wgpu | jit<wgpu> | BestAvailable | 8.688ms |
| conv2d | wgpu-fusion | fusion<jit<wgpu>> | BestAvailable | 8.351ms |
| conv2d | cuda-jit | jit<cuda> | CudaDevice { index: 0 } | 6.250ms |

batch_size = 16

| Benchmark | Feature | Backend | Device | Median |
|---|---|---|---|---|
| conv2d | wgpu | jit<wgpu> | BestAvailable | 31.056ms |
| conv2d | wgpu-fusion | fusion<jit<wgpu>> | BestAvailable | 32.943ms |
| conv2d | cuda-jit | jit<cuda> | CudaDevice { index: 0 } | 26.165ms |

Im2col

batch_size = 4

| Benchmark | Feature | Backend | Device | Median |
|---|---|---|---|---|
| conv2d | wgpu | jit<wgpu> | BestAvailable | 4.831ms |
| conv2d | wgpu-fusion | fusion<jit<wgpu>> | BestAvailable | 4.759ms |
| conv2d | cuda-jit | jit<cuda> | CudaDevice { index: 0 } | 2.617ms |

batch_size = 16 (split into 4 sub-batches)

| Benchmark | Feature | Backend | Device | Median |
|---|---|---|---|---|
| conv2d | wgpu | jit<wgpu> | BestAvailable | 21.284ms |
| conv2d | wgpu-fusion | fusion<jit<wgpu>> | BestAvailable | 21.207ms |
| conv2d | cuda-jit | jit<cuda> | CudaDevice { index: 0 } | 17.202ms |

Implicit GEMM

batch_size = 16

| Benchmark | Feature | Backend | Device | Median |
|---|---|---|---|---|
| conv2d | cuda-jit | jit<cuda> | CudaDevice { index: 0 } | 4.859ms |

codecov bot commented on Sep 17, 2024

Codecov Report

Attention: Patch coverage is 49.16733% with 641 lines in your changes missing coverage. Please review.

Project coverage is 85.21%. Comparing base (2c8514c) to head (3cfce1b).
Report is 1 commit behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...s/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs | 7.25% | 422 Missing ⚠️ |
| crates/burn-jit/src/kernel/conv/conv2d/im2col.rs | 53.13% | 112 Missing ⚠️ |
| crates/burn-jit/src/kernel/conv/conv2d/col2im.rs | 62.91% | 89 Missing ⚠️ |
| crates/burn-jit/src/kernel/conv/conv2d/base.rs | 77.55% | 11 Missing ⚠️ |
| crates/burn-jit/src/fusion/tracing/builder.rs | 0.00% | 3 Missing ⚠️ |
| crates/burn-jit/src/kernel/conv/conv2d/direct.rs | 87.50% | 2 Missing ⚠️ |
| crates/burn-jit/src/tune_key.rs | 0.00% | 2 Missing ⚠️ |
```diff
@@            Coverage Diff             @@
##             main    #2287      +/-   ##
==========================================
- Coverage   85.67%   85.21%   -0.47%     
==========================================
  Files         760      766       +6     
  Lines       99082   100293    +1211     
==========================================
+ Hits        84888    85462     +574     
- Misses      14194    14831     +637     
```


@nathanielsimard (Member) left a comment:

LGTM, waiting for @louisfd to review before merging, but thanks a lot 🙏

@louisfd (Member) left a comment:

I have a few questions, but overall it's very good, awesome!

Resolved review threads (outdated):

  • crates/burn-jit/src/kernel/conv/conv2d/base.rs (×2)
  • crates/burn-jit/src/kernel/conv/conv2d/im2col.rs
  • crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs (×4)

```rust
/**************************** Bounds Check + CMMA Op *********************************/
if a_row < gemm_m && k < gemm_k && b_col < gemm_n {
    cmma::load(&matrix_a, input_tile.as_slice(), CMMA_K);
```
@louisfd (Member) commented:

When doing the matmul, I found you can safely reuse a fragment across several executions, so this line could go in the outer loop.

@wingertge (Contributor, Author) replied:

I'm not sure what you mean? I am reusing the fragments, but I still need to load a tile and execute the matmul for each k (see the sketch below).
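
For readers following the thread, a sketch of the structure being discussed (illustrative only: the loop bounds, tile variables, and the `cmma::execute` call are assumptions modeled on the snippet above, not the PR's exact kernel code):

```rust
// Fragments are created once, outside the loop, and reused...
let matrix_a = cmma::Matrix::new(/* input fragment */);
let matrix_b = cmma::Matrix::new(/* weight fragment */);
let matrix_acc = cmma::Matrix::new(/* accumulator */);

// ...but each k-step consumes a different K-tile, so the loads and the
// matmul still have to run once per iteration.
let mut k = 0;
while k < gemm_k {
    cmma::load(&matrix_a, input_tile.as_slice(), CMMA_K);
    cmma::load(&matrix_b, weight_tile.as_slice(), CMMA_N);
    cmma::execute(&matrix_a, &matrix_b, &matrix_acc, &matrix_acc);
    k += CMMA_K;
}
```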

Resolved review thread (outdated): crates/burn-jit/src/kernel/conv/conv2d/col2im.rs
@louisfd (Member) left a comment:

LGTM

@louisfd merged commit 97af8c6 into tracel-ai:main on Sep 23, 2024
11 checks passed
@wingertge deleted the im2col branch on September 23, 2024