SpeedyTransforms: Differentiable transforms via custom Enzyme rule #589
Conversation
The forward and backward transforms are differentiable via Enzyme! @milankl This might be a good point for you to already have a brief look at it, especially at what I wrote about how we want to organise this code and testing going forward.
There seem to be some issues running those tests in the CI that I don't have on my computer. All tests are passing on my laptop (except for the EnzymeTestUtils one).
With the last commit: You think we can have all Enzyme functionality as an extension? Happy to try, but also loading/precompiling it is not as expensive as CUDA, I believe?
I think so. Because the only thing we'd need Enzyme for directly in the package is custom rules. And I don't think we need that many. If we are lucky, those two might even be the only ones (because they are calling the only non-Julia code). Enzyme.jl installs the C++ Enzyme library. So I thought, maybe good to move it to an extension when it's not needed for every user.
# compute the actual vjp
dfnorthval = zero(f_north.val)
dfsouthval = zero(f_south.val)
Because these allocate: Is that because we can't reuse the forward memory for the gradient?
# seed
dspecs = zero(specs)
fill!(dspecs, 1+1im)
What does "seed" mean in this context and why is it 1+1im?
We do a reverse-mode differentiation here. So you propagate a sensitivity of the output in the reverse direction to determine the sensitivity of the output wrt its inputs. The seed is that output sensitivity, so to say. Reverse mode computes a vector-Jacobian product, so if you set the seed to one, you directly compute the gradient of the function. That's only a very short explanation, should I send you some more resources about that?
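As a toy illustration of the seeding idea (plain Python with a made-up scalar function, not SpeedyWeather code): reverse mode propagates an output sensitivity, the seed, back to the inputs, and a seed of one makes that pullback exactly the gradient.

```python
import math

# f(x, y) = x*y + sin(x), with a hand-written reverse pass
def f(x, y):
    return x * y + math.sin(x)

def vjp(x, y, seed):
    # reverse pass: multiply the output sensitivity `seed`
    # by the partial derivatives of f
    dx = seed * (y + math.cos(x))  # df/dx = y + cos(x)
    dy = seed * x                  # df/dy = x
    return dx, dy

# with seed = 1 the vector-Jacobian product is exactly the gradient of f
gx, gy = vjp(2.0, 3.0, 1.0)
print(gx, gy)  # gx = 3 + cos(2), gy = 2
```

A different seed just scales the result, which is what "Jacobian vector product with a one seed gives the gradient" means above.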
As I wrote above in the big post, differentiating complex numbers is a bit of a special topic in itself. In our case we treat the imaginary and real part kind of like separate numbers, so setting both to one made sense to me intuitively. But differentiating complex numbers is something we should still think about a bit.
Yes please! Just checking, because I think of
julia> z = exp(im*π/4)
0.7071067811865476 + 0.7071067811865475im
julia> abs(z)
1.0
more as a complex "1" than 1 + 1im, because it's actually a unit vector
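To make the "separate real numbers" view concrete, here is a minimal Python sketch (with a made-up complex-linear map, nothing from the package): a complex array viewed as (re, im) pairs gets a seed of one on each real component, which is what fill!(dspecs, 1+1im) amounts to.

```python
# complex map z -> 2z, written out on the real pair (re, im)
def f(re, im):
    return 2.0 * re, 2.0 * im

def vjp(seed_re, seed_im):
    # reverse pass of the linear map: Jacobian-transpose times the seed
    return 2.0 * seed_re, 2.0 * seed_im

# seeding with 1 + 1im, i.e. a one on both real components
d_re, d_im = vjp(1.0, 1.0)
print(d_re, d_im)  # 2.0 2.0, i.e. the pullback is 2 + 2im
```

Under this convention the two real components are seeded independently, which differs from treating 1 + 1im as a single complex direction of unit length.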
fd_jvp = FiniteDifferences.j′vp(central_fdm(5,1), x -> transform(x, S), dgrid2, specs)
@test isapprox(dspecs, fd_jvp[1])
We could also test that the roundtrip transform is identity no?
What roundtrip exactly? I am not sure I understand. You could differentiate through
Vanish=1 no? It's d(S^-1(S(x)))/dx = dx/dx = 1 I thought? And we could test this starting in either space?
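A minimal sketch of that roundtrip test (plain Python, with a hypothetical invertible stand-in for the transform pair, not the real SpeedyTransforms functions): since S⁻¹(S(x)) = x, every finite-difference derivative of the roundtrip should be 1.

```python
def S(x):       # hypothetical stand-in for the forward transform
    return [3.0 * xi for xi in x]

def S_inv(y):   # its exact inverse
    return [yi / 3.0 for yi in y]

def roundtrip(x):
    return S_inv(S(x))

x = [0.5, -1.2, 2.0]
h = 1e-6
for i in range(len(x)):
    xp = list(x)
    xp[i] += h
    deriv = (roundtrip(xp)[i] - roundtrip(x)[i]) / h
    assert abs(deriv - 1.0) < 1e-6  # d(S^-1(S(x)))/dx = 1
```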
Ah yes sure, you are right.
I added this as tests.
Starting in spectral space it works very well.
Starting in grid space, the result is sort of close to 1, but not quite. Currently I am not sure what goes wrong there.
The test starting in grid space is everywhere 0.961122 instead of 1. Looks like a normalisation issue. I'll look into this later.
I am not surprised. This depends on what data you start with in grid space, which can hold more information than spectral space, so the first roundtrip is usually lossy. But subsequent ones should be fine. I would start in spectral space, transform to grid space then start the test grid->spectral->grid there
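The lossy-first-roundtrip argument can be sketched with a hypothetical projection in place of the grid→spectral truncation (plain Python): a projection P satisfies P(P(x)) = P(x), so the first application may lose information but every later one is exact.

```python
def roundtrip(x):
    # hypothetical stand-in: keep only the mean, i.e. a single
    # "spectral coefficient", then transform back to "grid" values
    mean = sum(x) / len(x)
    return [mean] * len(x)

x = [1.0, 2.0, 3.0, 4.0]
once = roundtrip(x)       # lossy: x held more information than the mean
twice = roundtrip(once)   # exact: already inside the projected subspace
print(once == x, once == twice)  # False True
```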
We already do that in the tests. But in principle you are right, there's something wrong with the setup for this test, because the finite-difference differentiation yields the same gradient as Enzyme. So it's not an issue of the custom rule.
The CI is passing for Julia 1.10. There are some problems with Enzyme for Julia 1.11, but they are working on it.
@vchuravy we (well, Max!) are making some progress to differentiate through SpeedyWeather!
@swilliamson7 this might also be relevant for you, let us know if you have any more insights / general wisdom!!
Looking through the thread, but rather exciting for SpeedyWeather.jl! In addition to Valentin, maybe @wsmoses would be interested in seeing this?
I had a very quick check if some non-sensical differentiation through a timestep of a barotropic model is possible. I get an error from Enzyme and the compiler from the
We want to make the spherical harmonics transforms differentiable.
The complete model is supposed to rely on Enzyme for that; Enzyme can already deal with many of our functions by itself in principle. However, the FFTs are not computed by Julia code but with FFTW and cuFFT. Transforms are also the most performance-critical part of the model, so optimising the gradients of the transforms also makes sense.
Enzyme
`_legendre!` Enzyme can handle by itself. For `_fourier!`, the forward version uses `rfft` plans, the inverse `brfft` plans. `_fourier!` isn't normalised; the normalisation is done in `_legendre!`. The adjoint of the unnormalised forward FFT is the unnormalised inverse (`brfft`). For real-valued FFTs you have an additional scale in the adjoint reflecting the different coefficients in the formula (there's a two in front of the cos/sine terms).
ChainRules/Zygote/…
For uses outside of our model (e.g. spherical FNOs) it would be great to also have this differentiable by other ADs. These ADs usually don't support mutation, so I would define rules for the non-mutating `transform(grid, S)`. I'll do that a bit later. This might also need some changes to the code.
This PR
Implements a custom Enzyme rule for `_fourier!` and tests it.
Testing
Complex Numbers
Differentiating complex numbers is a bit of a topic in itself. Enzyme also has a quite long explanation of their approach in their FAQ (https://enzymead.github.io/Enzyme.jl/stable/faq/#Complex-numbers). I think for us we can treat the real and imaginary parts basically as separate numbers, as they would be for real SPH (for which coefficients with m<0 correspond to our imaginary part, sort of). I am not quite sure yet if this works well out of the box with Enzyme or if we have to be more careful somewhere. It's something we should think about.