
SpeedyTransforms: Differentiable transforms via custom Enzyme rule #589

Draft
wants to merge 58 commits into base: main

Conversation

maximilian-gelbrecht
Member

@maximilian-gelbrecht commented Oct 17, 2024

We want to make the spherical harmonics transforms differentiable.

The complete model is supposed to rely on Enzyme for that, and Enzyme can already deal with many of our functions by itself in principle. However, the FFTs are not computed by Julia code but by FFTW and cuFFT. The transforms are also the most performance-critical part of the model, so optimising the gradients of the transforms makes sense as well.

Enzyme

  • The Legendre transformation _legendre! Enzyme can handle by itself.
  • We do Fourier transforms consecutively on each latitude ring in _fourier!; the forward version uses rfft plans, the inverse brfft plans. _fourier! isn't normalised; the normalisation is done in _legendre!.
  • The FFT needs a rule here: the adjoint of a Fourier transform is the unnormalised inverse Fourier transform (so brfft). For a real-valued FFT there is an additional scale in the adjoint, reflecting the different coefficients in the formula (there's a factor of two in front of the cos/sine terms).
  • The rules are tested with the test tools from Enzyme.
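The adjoint relation from the third bullet can be checked numerically with a naive pure-Julia DFT. Note `naive_rfft` and `naive_brfft` are illustrative stand-ins (following the usual FFTW conventions for even lengths), not SpeedyWeather code:

```julia
# Sketch: the adjoint of the real FFT is the unnormalised inverse (brfft)
# up to a per-frequency scale [1, 2, ..., 2, 1], coming from the factor two
# in front of the cos/sine terms of the interior frequencies.

function naive_rfft(x::Vector{Float64})
    n = length(x)
    [sum(x[j+1] * cis(-2π * k * j / n) for j in 0:n-1) for k in 0:n÷2]
end

function naive_brfft(y::Vector{ComplexF64}, n::Int)
    # unnormalised inverse of the real FFT (Hermitian extension implied, even n)
    x = zeros(Float64, n)
    for j in 0:n-1
        acc = real(y[1]) + (-1)^j * real(y[end])      # DC and Nyquist appear once
        for k in 1:n÷2-1
            acc += 2 * real(y[k+1] * cis(2π * k * j / n))  # interior bins carry a 2
        end
        x[j+1] = acc
    end
    return x
end

n = 8
x = randn(n)                                  # primal input
ȳ = [complex(randn(), randn()) for _ in 1:n÷2+1]   # cotangent (the "seed")

scale = [1; fill(2, n÷2 - 1); 1]              # the additional scale in the adjoint
x̄ = naive_brfft(ȳ ./ scale, n)                # candidate adjoint action

# adjoint identity ⟨ȳ, rfft(x)⟩_ℝ = ⟨x̄, x⟩, treating Re and Im as separate reals
lhs = real(sum(conj(ȳ) .* naive_rfft(x)))
rhs = sum(x̄ .* x)
@assert isapprox(lhs, rhs; rtol=1e-8)
```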

ChainRules/Zygote/…

For uses outside of our model (e.g. spherical FNOs) it would be great to have this differentiable by other ADs as well. These ADs usually don't support mutation, so I would define rules for the non-mutating transform(grid, S). I'll do that a bit later; this might also require some changes to the code.
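Such a rule could follow the usual ChainRulesCore pattern. The following is a non-runnable sketch only: `adjoint_transform` is a hypothetical helper that would apply the adjoint (the unnormalised inverse with the extra real-FFT scale) to the cotangent.

```julia
# Non-runnable sketch of an rrule for the non-mutating transform(grid, S).
import ChainRulesCore
import ChainRulesCore: NoTangent

function ChainRulesCore.rrule(::typeof(transform), grid, S)
    specs = transform(grid, S)
    function transform_pullback(dspecs)
        dgrid = adjoint_transform(dspecs, S)   # hypothetical adjoint helper
        return NoTangent(), dgrid, NoTangent()
    end
    return specs, transform_pullback
end
```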

This PR

  • This PR implements the Enzyme rule for _fourier! and tests it.
  • As far as I can see, this rule should also work as-is with the GPU version once we have it.
  • We could pre-compute the scaling. It's a relatively small Int-valued matrix; that would make the adjoint marginally faster.
  • I also introduce a new extension for FiniteDifferences.jl that makes our data structures compatible with that library.
  • I set up a new environment for the tests.
  • Currently I just add Enzyme, EnzymeRules and EnzymeTestUtils to the main environment to make testing a bit easier for me while developing this PR; that will be changed before merging.
    • Do we want the differentiability tests in the same files as the other tests, or in a separate folder?
    • Enzyme (and most of this code) could go in an extension. Is this what we want long-term? Probably yes, but I don't have a good impression yet of how many custom rules we have to define. Hopefully not many.
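The pre-computed scaling mentioned above could look something like this for a single ring of `nlon` longitudes (a sketch; `rfft_adjoint_scale` is a hypothetical name, assuming FFTW's rfft output length of `nlon÷2+1`):

```julia
# Int-valued adjoint scale per rfft frequency bin: 1 at DC and (for even
# lengths) at Nyquist, 2 for all interior bins.
rfft_adjoint_scale(nlon::Int) = [(k == 0 || 2k == nlon) ? 1 : 2 for k in 0:nlon÷2]

@assert rfft_adjoint_scale(8) == [1, 2, 2, 2, 1]   # even: Nyquist bin present
@assert rfft_adjoint_scale(9) == [1, 2, 2, 2, 2]   # odd: no Nyquist bin
```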

Testing

  • I do a manual test with FiniteDifferences comparing the full spherical harmonics transforms. Those tests pass.
  • I'd love to test the rules directly with EnzymeTestUtils as well, but there are several problems currently:
    • There's a problem with the FFT plans having uninitialized fields. There's a quite hacky way around this though.
    • There's a problem with the function having complex-valued outputs.
    • The tests get stuck; I waited 30 minutes without a result.
    • I have to look into this a bit further.
  • FiniteDifferences tests can take a while. Going forward, do we want all gradient-correctness checks in the regular CI? The four quite simple tests currently in this PR already take 5 minutes.

Complex Numbers

Differentiating complex numbers is a bit of a topic in itself. Enzyme has a quite long explanation of their approach in their FAQ (https://enzymead.github.io/Enzyme.jl/stable/faq/#Complex-numbers). I think for us we can treat the real and imaginary parts basically as separate numbers, as they would be for a real SPH (for which coefficients with m<0 correspond, sort of, to our imaginary part). I am not quite sure yet whether this works well out of the box with Enzyme or whether we have to be more careful somewhere. It's something we should think about.
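A tiny pure-Julia illustration of the "separate reals" convention (the function and numbers are made up for illustration): for a complex scalar output y of real inputs x, a reverse seed s̄ pulls back as Re(s̄)·∂Re(y)/∂xⱼ + Im(s̄)·∂Im(y)/∂xⱼ, so s̄ = 1 + 1im yields the gradient of Re(y) + Im(y).

```julia
c = [2.0 + 3.0im, -1.0 + 0.5im]
y(x) = sum(c .* x)                 # ℝ² → ℂ, linear toy function

s̄ = 1 + 1im                        # the seed, as in the tests of this PR
x̄ = real(s̄) .* real.(c) .+ imag(s̄) .* imag.(c)   # pullback under the convention

# finite-difference check of d(Re(y) + Im(y))/dx
g(x) = real(y(x)) + imag(y(x))
x0, h = [0.3, -0.7], 1e-6
fd = [(g(x0 .+ h .* (1:2 .== j)) - g(x0)) / h for j in 1:2]
@assert all(isapprox.(x̄, fd; atol=1e-5))
```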

@maximilian-gelbrecht added transform ⬅️ ➡️ Our spherical harmonic transform, grid to spectral, spectral to grid and differentiability 🤖 Making the model differentiable via AD labels Oct 17, 2024
@maximilian-gelbrecht changed the base branch from mk/transform to main October 22, 2024 12:25
@maximilian-gelbrecht
Member Author

The forward and backward transforms are differentiable via Enzyme!
The tests comparing the results to finite differences pass on my laptop, hopefully they will in the CI as well

@milankl This might be a good point for you to already have a brief look at it. Especially at what I wrote about how we want to organise this code and testing going forward

@maximilian-gelbrecht
Member Author

maximilian-gelbrecht commented Oct 22, 2024

There seems to be some issue running these tests in the CI that I don't have on my computer. All tests are passing on my laptop (except for the EnzymeTestUtils one).

@milankl
Member

milankl commented Oct 23, 2024

With the last commit: you think we can have all Enzyme functionality as an extension? Happy to try, but loading/precompiling it is not as expensive as CUDA, I believe?

@maximilian-gelbrecht
Member Author

maximilian-gelbrecht commented Oct 23, 2024

I think so, because the only thing we'd need Enzyme for directly in the package is custom rules, and I don't think we need that many. If we are lucky, those two might even be the only ones (because they call the only non-Julia code).

Enzyme.jl installs the C++ Enzyme library, so I thought it might be good to move it to an extension since it's not needed by every user.

ext/SpeedyWeatherEnzymeExt.jl (resolved)

# compute the actual vjp
dfnorthval = zero(f_north.val)
dfsouthval = zero(f_south.val)
Member

Because these allocate: Is that because we can't reuse the forward memory for the gradient?

ext/SpeedyWeatherEnzymeExt.jl (resolved)
test/runtests.jl (resolved)

# seed
dspecs = zero(specs)
fill!(dspecs, 1+1im)
Member

What does "seed" mean in this context and why is it 1+i?

Member Author
@maximilian-gelbrecht Oct 23, 2024

We do reverse-mode differentiation here: you propagate a value at the output backwards to determine the sensitivity of the output with respect to its inputs. The seed is an example output, so to say. Reverse mode computes a vector-Jacobian product, so with a seed of ones you directly get the gradient. That's only a very short explanation; should I send you some more resources about it?

As I wrote above in the big post, differentiating complex numbers is a bit of a special topic in itself. In our case we treat the imaginary and real part kind of like separate numbers, so setting both to one made sense to me intuitively. But differentiating complex numbers is something we should still think about a bit.
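To make the seed concrete, here is a small pure-Julia example (the function and its hand-written Jacobian are made up for illustration):

```julia
# Reverse mode computes the vector-Jacobian product v'J; seeding with ones
# therefore returns the gradient of x -> sum(f(x)).
f(x) = [x[1]^2 + x[2], x[1] * x[2]]
J(x) = [2x[1] 1.0; x[2] x[1]]        # Jacobian of f, written out by hand

x = [3.0, 4.0]
v = ones(2)                           # the "seed"
vjp = J(x)' * v                       # what a reverse pass returns

# equals the gradient of x -> x₁² + x₂ + x₁x₂
@assert vjp == [2x[1] + x[2], 1 + x[1]]
```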

Member

Yes please! Just checking, because I think of

julia> z = exp(im*π/4)
0.7071067811865476 + 0.7071067811865475im

julia> abs(z)
1.0

more as complex "1" than 1 + 1im because it's actually a unit vector

test/test_transforms_ad_rules.jl (resolved)

fd_jvp = FiniteDifferences.j′vp(central_fdm(5,1), x -> transform(x, S), dgrid2, specs)

@test isapprox(dspecs, fd_jvp[1])
Member

We could also test that the roundtrip transform is identity no?

Member Author

What roundtrip exactly? I am not sure I understand. You could differentiate through $\mathcal{S}^{-1}(\mathcal{S}(x))$, but that's just identity, so the derivative wrt the input would just vanish. Or is this exactly what you were thinking of?

Member

Vanish=1 no? It's d(S^-1(S(x)))/dx = dx/dx = 1 I thought? And we could test this starting in either space?

Member Author

Ah yes sure, you are right.

Member Author

I added this as tests.

  • Starting in spectral space, it works very well.
  • Starting in grid space, the result is sort of close to 1, but not quite. Currently I am not sure what goes wrong there.

Member Author
@maximilian-gelbrecht Oct 23, 2024

The test starting in grid space gives 0.961122 everywhere instead of 1. Looks like a normalisation issue; I'll look into this later.

Member

I am not surprised. This depends on what data you start with in grid space, which can hold more information than spectral space, so the first roundtrip is usually lossy. But subsequent ones should be fine. I would start in spectral space, transform to grid space, then start the grid->spectral->grid test there.
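The lossy first roundtrip can be illustrated with a toy linear-algebra analogue (pure Julia, LinearAlgebra only; the matrices are stand-ins, with synthesis B playing spectral→grid and its pseudoinverse playing grid→spectral):

```julia
using LinearAlgebra

# "grid" space ℝ⁵ holds more information than "spectral" space ℝ³
B = randn(5, 3)     # synthesis: spectral -> grid
A = pinv(B)         # analysis:  grid -> spectral

@assert isapprox(A * B, I(3); atol=1e-8)     # spectral->grid->spectral: identity
@assert !isapprox(B * A, I(5); atol=1e-2)    # grid->spectral->grid: only a projection

g = B * randn(3)                             # grid data that came from spectral space
@assert isapprox(B * (A * g), g; atol=1e-8)  # now the roundtrip is exact
```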

Member Author

We already do that in the tests. But in principle you are right, there's something wrong with the setup for this test, because the finite-difference differentiation yields the same gradient as Enzyme, so it's not an issue with the custom rule.

@maximilian-gelbrecht
Member Author

The CI is passing for Julia 1.10.

There are some problems with Enzyme for Julia 1.11, but they are working on it.

@milankl
Member

milankl commented Oct 23, 2024

@vchuravy we (well, Max!) are making some progress to differentiate through SpeedyWeather!

@milankl
Member

milankl commented Oct 23, 2024

@swilliamson7 this might be also relevant for you, let us know if you have any more insights / general wisdom!!

@swilliamson7

Still looking through the thread, but this is rather exciting for SpeedyWeather.jl! In addition to Valentin, maybe @wsmoses would be interested in seeing this?

@maximilian-gelbrecht
Member Author

I did a very quick check whether some (nonsensical) differentiation through a time step of a barotropic model is possible.

I get an error from Enzyme and the compiler in the _divergence! function. That's roughly what I suspected: after the transform itself, the spatial gradient functions are the other challenging bit. I'll look into differentiating them again when we also make steps towards GPU-ifying those parts; that should be next on the agenda after the transforms.
