
Do not interweave transform #587

Merged
merged 27 commits into from
Oct 21, 2024

Conversation

milankl
Member

@milankl milankl commented Oct 11, 2024

  • introduces SpectralTransform.scratch_memory_north and _south for fields that are Legendre- but not yet Fourier-transformed, and vice versa
  • de-interweaves the transforms: instead of LT, FT, next ring, LT, FT, next ring, we first do all LTs, then all FTs

The scratch memory (needed because an interweaved transform can reuse small vectors, whereas now we need to store the half-transformed fields somewhere) adds a little memory (octahedral Gaussian grid, dealiasing of 2):

| Resolution | Grid (MB) | Spec (MB) | Legendre polynomials (MB) | Scratch memory (MB) | Added (%) |
|------------|-----------|-----------|---------------------------|---------------------|-----------|
| T31L8      | 0.2       | 0.04      | 0.1                       | 0.35                | 103       |
| T127L8     | 10        | 0.5       | 6.45                      | 4.94                | 29        |
| T127L32    | 39.5      | 2.1       | 6.45                      | 19.8                | 40        |
| T511L32    | 611       | 34        | 405                       | 305                 | 29        |

The problem with this scratch memory: currently you can create a SpectralTransform and use it for any number of layers (because the same operations are applied on every layer), but with the scratch memory, if you allocate for 3 layers and then want to transform 8 layers, what do you do? Two options:

  • make SpectralTransform mutable and replace scratch memory with bigger memory in that case
  • return an error, a SpectralTransform for 3 layers can only be used for <=3 layers

@jackleland @maximilian-gelbrecht what do you think?
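Option 2 could be as simple as a size check at the top of the transform. A minimal sketch (the struct, the field `nlayers`, and the function `check_nlayers` here are hypothetical stand-ins, not the actual SpeedyWeather API):

```julia
# Sketch of option 2: error if the transform was planned for fewer layers
# than the field to be transformed. `nlayers` is a hypothetical field
# recording how many layers the scratch memory was allocated for.
struct SpectralTransformSketch
    nlayers::Int                                # layers the scratch memory holds
    scratch_memory_north::Array{ComplexF64, 3}  # half-transformed fields
end

function check_nlayers(S::SpectralTransformSketch, nlayers_field::Integer)
    nlayers_field <= S.nlayers || throw(DimensionMismatch(
        "SpectralTransform planned for $(S.nlayers) layers, " *
        "cannot transform $nlayers_field layers"))
    return nothing
end

S = SpectralTransformSketch(3, zeros(ComplexF64, 8, 16, 3))
check_nlayers(S, 2)     # fine: fewer layers than planned
# check_nlayers(S, 8)   # would throw a DimensionMismatch
```

Option 1 would instead replace `scratch_memory_north` with a bigger array, which requires the struct to be mutable.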

@milankl milankl added the transform ⬅️ ➡️ Our spherical harmonic transform, grid to spectral, spectral to grid label Oct 11, 2024
@milankl milankl self-assigned this Oct 11, 2024
@jackleland
Collaborator

Am I right in thinking it's basically a trade-off between flexibility (3 -> 8 layers easily) and ease of kernel-ising the transform? We could feasibly try both ways and then make a decision based on which provides the greater speed-up

@jackleland
Collaborator

If having both is as easy as making the SpectralTransform mutable, why wouldn't we just do that?

@milankl
Member Author

milankl commented Oct 11, 2024

CI is failing for various small reasons, but de-interweaving the transforms and batching the Fourier transforms in the vertical (no batching for Legendre for now) is already 30% faster 🚀. Transforming T31L8 on main

julia> @btime transform!($grid, $spec, $S);
  267.078 μs (972 allocations: 50.16 KiB)

on this branch

julia> @btime transform!($grid, $spec, $S);
  193.818 μs (248 allocations: 9.12 KiB)

@milankl
Member Author

milankl commented Oct 12, 2024

If having both is as easy as making the SpectralTransform mutable, why wouldn't we just do that?

Yes I think in the end it's just a performance-flexibility trade-off and something we can still experiment with and change later. I'll go with option 1 for now and we can see whether option 2 makes a difference. First checks said that a mutable transform isn't slower but we should also experiment with type stability in its fields.

@maximilian-gelbrecht
Member

maximilian-gelbrecht commented Oct 14, 2024

Nice PR! I'll do a more thorough review later

Am I seeing it correctly that currently, with the non-mutable struct, it would work with fewer layers (e.g. 8 -> 3), which includes 2D transforms (8 -> 1)? Just not with more layers?

The flexibility with regards to the number of layers isn't needed for the SpeedyWeather model itself as far as I can see. Having some flexibility with the number of layers is nice for uses outside of the SpeedyWeather model though, e.g. for ML tasks like Spherical Fourier Neural Operators. But if you can just set an upper limit for the number of layers and everything below that still works, I think that's also fine in practice.

Initially I thought that making structs mutable is bad in most situations, but now I am just confused about when that's the case and when it isn't. So it'd need proper benchmarking.

@milankl
Member Author

milankl commented Oct 14, 2024

Am I seeing it correctly that currently, with the non-mutable struct, it would work with fewer layers (e.g. 8 -> 3), which includes 2D transforms (8 -> 1)? Just not with more layers?

Yes. Although in practice it currently may not work in 2D, but that should just be an indexing generalisation issue. I'm on it.

The flexibility with regards to the number of layers isn't needed for the SpeedyWeather model itself as far as I can see.

Yes, no it's not needed.

if you can just set an upper limit for the number of layers and everything below that still works, I think that's also fine in practice.

I agree, and if that's an issue then we could also bring the old transform back as transform_interweaved! and, if nlayers > spectral_transform.nlayers, just call that version instead?

Initially I thought that making structs mutable is bad in most situations, but now I am just confused about when that's the case and when it isn't. So it'd need proper benchmarking.

I still find it difficult to give a rule of thumb for when mutability harms performance. As far as I understand it now, it's mostly about "small" structs. Say you put an integer and a float in one struct. If the struct is immutable, the compiler can create many copies of it, likely in some lower-level caches important for performance. However, if the struct is mutable, then its state needs to be "synchronized": it effectively has to live as a single copy in memory, such that reading it always means reading from memory rather than from some faster caches.

Many of our structs are essentially just containers for other mutable (!) types. If a spectral transform stores an array (say the Legendre polynomials), that array needs to live in memory anyway, and it doesn't make a difference whether the surrounding struct is mutable or not.

I'm not 100% sure that intuitive explanation is always the case, but that's my mental construct of mutable vs immutable.
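The point above can be illustrated with a tiny sketch (made-up struct names, not SpeedyWeather types): an immutable struct whose field is an array still allows mutating the array's contents, so immutability only restricts reassigning the field itself.

```julia
# An immutable struct wrapping a mutable array: the field cannot be
# reassigned, but the array contents behind it can still be mutated.
struct Wrapper
    data::Vector{Float64}   # mutable contents behind an immutable wrapper
end

w = Wrapper(zeros(3))
w.data[1] = 42.0            # fine: mutates the array, not the struct
# w.data = ones(3)          # would error: Wrapper is immutable

# A mutable struct additionally allows replacing the field itself,
# e.g. swapping scratch memory for a bigger array.
mutable struct MutableWrapper
    data::Vector{Float64}
end

m = MutableWrapper(zeros(3))
m.data = ones(3)            # fine: the field can be replaced
```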

@maximilian-gelbrecht
Member

Yes, this makes sense.

I think for now it makes sense to keep it as it is (so immutable) as this is enough for our model and most other tasks. But we can look into this again later, together with e.g. 4D transforms that are also useful for things like SFNOs but not needed for our model.

@milankl
Member Author

milankl commented Oct 14, 2024

4D transforms that are also useful for things

Yeah, exactly. I'm currently hardcoding things for 2D/3D, but for a 4th dimension I'd suggest we then loop over 3D blocks, because that's easier than allocating all that memory as well.
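Looping a trailing 4th dimension over 3D blocks could look roughly like this (a sketch; `transform3d!` is a hypothetical stand-in for the actual 3D transform, here just a dummy operation):

```julia
# Sketch: apply a 3D transform block by block along a trailing 4th
# dimension, reusing the same scratch memory for every block.
# `transform3d!` is a stand-in dummy, not the actual SpeedyWeather transform.
transform3d!(out, inp) = (out .= 2 .* inp)

function transform4d!(out::AbstractArray{T, 4}, inp::AbstractArray{T, 4}) where T
    for l in axes(inp, 4)                    # loop over the 4th dimension
        transform3d!(view(out, :, :, :, l), view(inp, :, :, :, l))
    end
    return out
end

inp = rand(4, 4, 2, 5)
out = similar(inp)
transform4d!(out, inp)
```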

@maximilian-gelbrecht
Member

In terms of performance on GPU, just adding a trailing batch dimension and allowing a big allocation might actually not be that bad. Big ML models also constantly allocate huge tensors. Especially as it's not an application for our model, I'd try the simple thing first and see how it does. But I can also take a look at both possibilities after we get the 2D/3D version going.

@milankl
Member Author

milankl commented Oct 14, 2024

It's tricky to find the fastest dot product for the Legendre transform, because in the end it's a weird one: inputs are a complex and a real vector, fused multiply-adds are allowed, but the stride is 2 not 1 because of the odd-even splitting we need. I'm only getting this fast with handwritten Julia code that seems to compile better on my Intel MacBook, with some support from @giordano

@inline function fused_oddeven_dot(a::AbstractVector, b::AbstractVector)
    odd  = zero(eltype(a))      # dot product with elements 1, 3, 5, ... of a, b
    even = zero(eltype(a))      # dot product with elements 2, 4, 6, ... of a, b
    
    n = length(a)
    n_even = n - isodd(n)       # if n is odd do last odd element after the loop
    
    @inbounds for i in 1:2:n_even
        odd = muladd(a[i], b[i], odd)
        even = muladd(a[i+1], b[i+1], even)
    end
    
    # now do the last element if n is odd
    odd = isodd(n) ? muladd(a[end], b[end], odd) : odd
    return odd, even
end

the trick here (which we already had before but not put into its own function) is that the even/odd dot product we need is looped over simultaneously, presumably easing cache reads/writes. This is ~2x faster than any other dot product I could find despite the odd-even splitting, e.g.

julia> a = rand(ComplexF32, 1000);

julia> b = rand(Float32, 1000);

julia> @btime LinearAlgebra.dot($a, $b);
  1.207 μs (0 allocations: 0 bytes)

julia> @btime fused_oddeven_dot($a, $b);
  590.227 ns (0 allocations: 0 bytes)

julia> @btime *($(a'), $b);
  1.375 μs (0 allocations: 0 bytes)

@samhatfield you may find this interesting. I don't know how you optimised your dot product, but I'm not sure how widely supported this weird kind of dot product is in any kind of BLAS/LAPACK libraries.

I'll keep experimenting with this and also check whether the 3D generalisation has the same problem in standard libraries, because the dot product then turns into a matrix-vector multiplication, with a complex matrix and a real vector, but with the odd-even splitting still applying. So maybe here again just writing good Julia code and checking that LLVM turns it into good machine code is the better approach than relying on some external libraries.

@milankl
Member Author

milankl commented Oct 14, 2024

Oh, and now the default T31L8 transform with new benchmarks is about 43% faster: new ~200 μs, old ~350 μs

julia> @benchmark transform!($grid, $spec, $S)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  197.604 μs … 695.913 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     250.907 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   259.209 μs ±  43.481 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

        ▇ ▂▂ █ ▁▆  ▃                                             
  ▂▁▂▂█▄████▆████▅▇█▆▃▆▆▃▃▄▃▂▂▃▃▂▁▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  198 μs           Histogram: frequency by time          439 μs <

 Memory estimate: 3.00 KiB, allocs estimate: 48.

and old

julia> @benchmark transform!($grid, $spec, $S)
BenchmarkTools.Trial: 7834 samples with 1 evaluation.
 Range (min … max):  348.337 μs … 77.411 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     513.208 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   627.380 μs ±  1.127 ms  ┊ GC (mean ± σ):  1.09% ± 1.90%

    ▂██▇▅▃▁
  ▂▄████████▅▆▅▅▃▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  348 μs          Histogram: frequency by time         1.72 ms <

 Memory estimate: 49.16 KiB, allocs estimate: 972.

@milankl
Member Author

milankl commented Oct 14, 2024

@maximilian-gelbrecht could you do me a favour here and check those benchmarks ☝🏼 with your m3 too?

@maximilian-gelbrecht
Member

maximilian-gelbrecht commented Oct 15, 2024

@maximilian-gelbrecht could you do me a favour here and check those benchmarks ☝🏼 with your m3 too?

Sure:

julia> @btime LinearAlgebra.dot($a, $b);
  658.594 ns (0 allocations: 0 bytes)

julia> @btime fused_oddeven_dot($a, $b);
  361.310 ns (0 allocations: 0 bytes)

julia> @btime *($(a'), $b);
  659.683 ns (0 allocations: 0 bytes)
Setup:

using SpeedyWeather, BenchmarkTools

spectral_grid = SpectralGrid()
S = SpectralTransform(spectral_grid)
#S = SpectralTransform(spectral_grid.NF, spectral_grid.Grid, spectral_grid.trunc+1, spectral_grid.trunc)

grid = rand(spectral_grid.Grid{spectral_grid.NF}, spectral_grid.nlat_half, spectral_grid.nlayers)
spec = rand(LowerTriangularArray{spectral_grid.NF}, spectral_grid.trunc+2, spectral_grid.trunc+1, spectral_grid.nlayers)

@benchmark transform!($grid, $spec, $S)

On mk/transform:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  72.875 μs … 121.250 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     79.500 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   78.621 μs ±   2.751 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                                  ▄▄█▄▇▅▂▄▂▂                    
  ▁▁▁▃▄▆▇▆▇▆▇▆▆▆▄▅▄▄▃▂▂▂▂▂▁▁▁▁▁▂▃▆███████████▆▆▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁ ▃
  72.9 μs         Histogram: frequency by time         84.3 μs <

 Memory estimate: 3.00 KiB, allocs estimate: 48.

On main:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  269.000 μs …  75.545 ms  ┊ GC (min … max): 0.00% … 99.56%
 Time  (median):     292.625 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   302.090 μs ± 756.848 μs  ┊ GC (mean ± σ):  2.80% ±  2.02%

                    █▃▃▁▂▁▂▁                                     
  ▃▄▆▅▄▃▃▃▃▃▂▂▂▂▂▂▂▅████████▅▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂ ▃
  269 μs           Histogram: frequency by time          337 μs <

 Memory estimate: 49.19 KiB, allocs estimate: 970.

Inverse fast Fourier transform of Legendre-transformed inputs `g_north` and `g_south` to be stored in `grids`.
Not to be called directly, use `transform!` instead."""
function _fourier!(
    grids::AbstractGridArray,   # gridded output
Member

@maximilian-gelbrecht maximilian-gelbrecht Oct 15, 2024

Do we want to do this as a dispatch based on the array type, or do we want to add the device type to the SpectralTransform struct to make the dispatch based on that?

For me, the second option sounds better, I think.

The scratch memory also has to be on the correct device, so it's probably unavoidable to have a device type in the constructor.

Member Author

I agree, the SpectralTransform should decide where to compute things, and then we can probably even have convenience functions around it so that a CPU array is moved to the GPU e.g.

@milankl
Member Author

milankl commented Oct 15, 2024

With the last commits I somehow managed to get it another 25% faster; it's a lot about manually tweaking memory access patterns and avoiding branching. We now have a vertically-batched Legendre transform, which is a little faster than the -- now also more optimized -- non-batched version.

julia> @benchmark transform!($grid, $spec, $S)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  143.652 μs … 984.251 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     178.492 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   185.457 μs ±  33.265 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

     ▂ █ ▆  ▃ █  ▃  ▁                                            
  ▁▂▂█▃█▄█▇▇█▄█▅▅█▄▇█▃▃▆▃▂▅▃▂▂▄▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  144 μs           Histogram: frequency by time          314 μs <

Makes me happy to think that all of that optimization likely is turned immediately to a faster model 🚀

@jackleland the code here is now much closer to what the GPU version should look like. transform! is clearly split into _legendre! and _fourier! with some scratch memory being passed between them. Those should then have GPU versions which we can dispatch to given the parameters of the transform (as in parameters of the Julia type, i.e. SpectralTransform{NF, ArrayType})
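A vertically-batched variant of the odd-even dot product could roughly look like the following sketch (not the actual implementation; in the real transform `a` would hold complex spectral coefficients, one column per vertical layer):

```julia
# Sketch of a vertically-batched odd-even dot product: `a` has one column
# per vertical layer; the odd/even accumulators become vectors over layers,
# so all layers share one pass over the Legendre polynomials in `b`.
function fused_oddeven_dot_batched!(odd, even, a::AbstractMatrix, b::AbstractVector)
    fill!(odd, 0); fill!(even, 0)
    n = length(b)
    n_even = n - isodd(n)           # if n is odd, do the last element after
    @inbounds for i in 1:2:n_even
        for k in axes(a, 2)         # loop over layers (columns)
            odd[k]  = muladd(a[i,   k], b[i],   odd[k])
            even[k] = muladd(a[i+1, k], b[i+1], even[k])
        end
    end
    if isodd(n)                     # last odd element
        @inbounds for k in axes(a, 2)
            odd[k] = muladd(a[n, k], b[n], odd[k])
        end
    end
    return odd, even
end

a = rand(7, 3)                      # 7 coefficients, 3 layers
b = rand(7)
odd, even = zeros(3), zeros(3)
fused_oddeven_dot_batched!(odd, even, a, b)
```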

@maximilian-gelbrecht
Member

maximilian-gelbrecht commented Oct 16, 2024

I am thinking of starting to implement an Enzyme rule for the _fourier! function, which we'll need because the FFT is computed by FFTW (and later by CUFFT). For Enzyme, _legendre! might not need a manual rule; for ChainRules/Zygote it's a different story.

For my understanding: currently both Fourier transforms are completely unnormalised as we use rfft and brfft plans here, correct? So we apply all needed normalisation together with the quadrature weights in the _legendre! function.

@milankl
Member Author

milankl commented Oct 16, 2024

For my understanding: currently both Fourier transforms are completely unnormalised as we use rfft and brfft plans here, correct? So we apply all needed normalisation together with the quadrature weights in the _legendre! function.

Yes, the convention is that FFTs are scaled on the inverse, so ifft would scale, fft would not. Similarly rfft does not apply any scaling, and we use brfft for the inverse, which isn't scaled either. But it's a good question where this $1/N$ scaling actually went; I believe it's in the quadrature weights, which are, generally speaking, the solid angles covered by each grid cell, so there is a division by the number of points there, true.
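This convention can be checked with a naive unnormalised DFT pair (a self-contained sketch, ignoring the real-to-complex specifics of rfft/brfft): composing the unnormalised forward and backward transforms returns $N$ times the input, so the $1/N$ has to be applied somewhere else.

```julia
# Naive unnormalised DFT and its unnormalised backward transform:
# neither applies the 1/N factor, so their composition returns N * x.
function naive_dft(x::Vector{ComplexF64})
    N = length(x)
    [sum(x[n+1] * exp(-2im * π * k * n / N) for n in 0:N-1) for k in 0:N-1]
end

function naive_bdft(X::Vector{ComplexF64})   # like brfft: no 1/N scaling
    N = length(X)
    [sum(X[k+1] * exp(+2im * π * k * n / N) for k in 0:N-1) for n in 0:N-1]
end

x = rand(ComplexF64, 8)
y = naive_bdft(naive_dft(x))
# y ≈ 8 .* x: the 1/N normalisation has to happen elsewhere,
# e.g. folded into the quadrature weights
```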

@milankl
Member Author

milankl commented Oct 16, 2024

okay, 3D transforms work in both directions, current timings are

julia> @btime transform!($spec, $grid, $S);
  216.471 μs (96 allocations: 6.00 KiB)

julia> @btime transform!($grid, $spec, $S);
  159.764 μs (96 allocations: 6.75 KiB)

FFTW is a little annoying as it always requires constant-stride inputs and stride-1 outputs (@maximilian-gelbrecht I think you mentioned that?), so there's a bunch of scratch-memory views and copying back into strided arrays in _fourier!, especially in the forward transform, which requires it for both inputs and outputs.

FFTW also being a bitch about applying a planned fft for 8 layers to 1 layer, so we'll probably need to preplan also 1D ffts that are applied in that case instead. The Legendre transforms work like a charm though!

@maximilian-gelbrecht
Member

Yeah, it's annoying. I made a GitHub issue about that in AbstractFFTs.jl a year ago already, but there seems to be little interest in it. A more general behaviour was already implemented for the cuFFT Julia bindings, but not for FFTW, it seems.

@milankl
Member Author

milankl commented Oct 21, 2024

Tests failed because FFTW complained about a memory alignment problem: a view of an array can have a non-zero alignment depending on the number of elements the view skips, e.g.

julia> A = rand(3,4,5);

julia> FFTW.alignment_of(A)
0

julia> FFTW.alignment_of(view(A, :, 1, 1))
0

julia> FFTW.alignment_of(view(A, :, 2, 1))
8

because the first dimension is 3, which is odd; compare to

julia> A = rand(4,4,5);

julia> FFTW.alignment_of(view(A, :, 1, 1))
0

julia> FFTW.alignment_of(view(A, :, 2, 1))
0

@milankl milankl merged commit b25586d into main Oct 21, 2024
5 checks passed
Labels
transform ⬅️ ➡️ Our spherical harmonic transform, grid to spectral, spectral to grid