SpeedyTransforms: Differentiable transforms via custom Enzyme rule #589

Draft
wants to merge 58 commits into
base: main
58 commits
366ea11
Do not interweave transform
milankl Oct 11, 2024
42f7e9c
propagate nlayers argument from inputs
milankl Oct 11, 2024
628f529
Do not store recursion factors
milankl Oct 11, 2024
c3a340d
SpectralTransform.Grid now keyword argument
milankl Oct 11, 2024
41eec88
introduce scratch memory for FFT output because strides don't match
milankl Oct 11, 2024
394da96
Legendre shortcut generalised
milankl Oct 14, 2024
56a82e5
with fused_oddeven_dot
milankl Oct 14, 2024
823b66f
Fourier/Legendre into functions, smaller scratch output for FFTW
milankl Oct 15, 2024
e8c01fd
first version of a fused matvec
milankl Oct 15, 2024
41433a3
Batched Legendre transform
milankl Oct 15, 2024
4c47e15
docstrings added
milankl Oct 15, 2024
1b80380
move even +- odd into kernel
milankl Oct 15, 2024
b057f66
ismatching with nlayers check
milankl Oct 15, 2024
873f6d0
grid to spectral versions v0
milankl Oct 15, 2024
7dd5811
ad scale implemented
maximilian-gelbrecht Oct 16, 2024
b5d04c2
SpectralTransform constructors overhauled
milankl Oct 16, 2024
9b8363f
Forward _legendre!
milankl Oct 16, 2024
57539cf
Legendre transforms nlayers flexible
milankl Oct 16, 2024
e14aca9
variables renamed to match harmonics meaning of odd/even
milankl Oct 16, 2024
bd27856
function barrier for fourier_batched, fourier_serial
milankl Oct 16, 2024
594fc8d
tests adapted
milankl Oct 16, 2024
9721f6c
start writing rules
maximilian-gelbrecht Oct 17, 2024
3269b6e
Merge remote-tracking branch 'origin/mk/transform' into mg/transform-…
maximilian-gelbrecht Oct 17, 2024
35055ee
autodiff computing without error
maximilian-gelbrecht Oct 17, 2024
89d355f
docs and changelog
maximilian-gelbrecht Oct 17, 2024
67be3a3
started full transform test (still error)
maximilian-gelbrecht Oct 17, 2024
5702659
autodiff on transform! w/o error but no correctness check
maximilian-gelbrecht Oct 17, 2024
9148ed7
adjust tests (still stuck though)
maximilian-gelbrecht Oct 17, 2024
227c98f
start FiniteDifferencesExtension
maximilian-gelbrecht Oct 18, 2024
df1142e
jvp with FiniteDifferences.jl working
maximilian-gelbrecht Oct 18, 2024
585889f
separate test env added
maximilian-gelbrecht Oct 18, 2024
58d7e77
size(::LTA, ::Integer)
milankl Oct 18, 2024
39858d0
attempts to fix correctness
maximilian-gelbrecht Oct 18, 2024
6633ccc
fix FFTW alignment problem
milankl Oct 21, 2024
21831d2
Merge branch 'main' into mk/transform
milankl Oct 21, 2024
5a9bac7
speedy transforms docs, recompute legendre removed
milankl Oct 21, 2024
876900e
Merge branch 'mk/transform' of https://github.com/SpeedyWeather/Speed…
milankl Oct 21, 2024
628c860
improve boundscheck error messages
milankl Oct 21, 2024
db4b46c
boundschecks fixed
milankl Oct 21, 2024
450f636
update rules
maximilian-gelbrecht Oct 22, 2024
2cc71d7
JLArray dependency 0.1 instead of 0.1.4
maximilian-gelbrecht Oct 22, 2024
1783662
transform test sort of working
maximilian-gelbrecht Oct 22, 2024
94b760f
inverse transform differentiable
maximilian-gelbrecht Oct 22, 2024
97ddc43
Merge branch 'mk/transform' into mg/transform-enzyme
maximilian-gelbrecht Oct 22, 2024
da88939
Merge branch 'main' into mg/transform-enzyme
maximilian-gelbrecht Oct 22, 2024
a69bfea
add KA to tests dep
maximilian-gelbrecht Oct 22, 2024
b3f7a43
test typo in ad tests
maximilian-gelbrecht Oct 22, 2024
46eb4c8
add NCDatasets to test env
maximilian-gelbrecht Oct 22, 2024
d803e59
update tests for EnzymeTestUtils
maximilian-gelbrecht Oct 22, 2024
ecf76bf
Add abstractffts to test env
maximilian-gelbrecht Oct 22, 2024
7339f1c
Enzyme rules as extension
maximilian-gelbrecht Oct 23, 2024
acee592
temporarily deactivate 1.11 CI
maximilian-gelbrecht Oct 23, 2024
f9d444a
add transform derivative identity test
maximilian-gelbrecht Oct 23, 2024
660c675
deactivate stalling EnzymeTestUtils
maximilian-gelbrecht Oct 23, 2024
ec3af2e
add additional truncation
maximilian-gelbrecht Oct 23, 2024
bf27c17
update to tests
maximilian-gelbrecht Oct 23, 2024
e5a6bd1
only one grid for CI
maximilian-gelbrecht Oct 23, 2024
fc4a793
additional FD test for identity
maximilian-gelbrecht Oct 24, 2024
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
@@ -14,7 +14,6 @@ jobs:
matrix:
version:
- '1.10'
- '1.11'
os:
- ubuntu-latest
arch:
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@

## Unreleased

+- Added custom EnzymeRules for the SpectralTransform and an extension for compatibility with FiniteDifferences.jl [#589](https://github.com/SpeedyWeather/SpeedyWeather.jl/pull/589)
- Also allow SpectralGrid as positional argument to model constructors [#593](https://github.com/SpeedyWeather/SpeedyWeather.jl/pull/593)
- De-interweave SpectralTransform [#587](https://github.com/SpeedyWeather/SpeedyWeather.jl/pull/587)
- Rossby-Haurwitz wave initial conditions [#591](https://github.com/SpeedyWeather/SpeedyWeather.jl/pull/591)
17 changes: 8 additions & 9 deletions Project.toml
@@ -11,8 +11,11 @@ BitInformation = "de688a37-743e-4ac2-a6f0-bd62414d1aa7"
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GenericFFT = "a8297547-1b15-4a5a-a998-a2ac5f1cef28"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
@@ -29,11 +32,15 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"

[extensions]
SpeedyWeatherCUDAExt = "CUDA"
+SpeedyWeatherEnzymeExt = "Enzyme"
+SpeedyWeatherFiniteDifferencesExt = "FiniteDifferences"
SpeedyWeatherJLArraysExt = "JLArrays"
SpeedyWeatherMakieExt = "Makie"

@@ -50,7 +57,7 @@ FFTW = "1"
FastGaussQuadrature = "0.4, 0.5, 1"
GPUArrays = "10"
GenericFFT = "0.1"
-JLArrays = "0.1.4"
+JLArrays = "0.1"
JLD2 = "0.4, 0.5"
KernelAbstractions = "0.9"
LinearAlgebra = "1.10"
@@ -64,11 +71,3 @@ Statistics = "1.10"
TOML = "1"
UnicodePlots = "3.3"
julia = "1.10"
-
-[extras]
-CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
-JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-
-[targets]
-test = ["Test", "JLArrays", "CUDA"]
115 changes: 115 additions & 0 deletions ext/SpeedyWeatherEnzymeExt.jl
@@ -0,0 +1,115 @@
module SpeedyWeatherEnzymeExt

using SpeedyWeather
using Enzyme
import .EnzymeRules: reverse, augmented_primal
using .EnzymeRules

# import all functions for which we define rules
import SpeedyWeather.SpeedyTransforms: _fourier!

# Rules for SpeedyTransforms

# _fourier!

# Computes the scale for the adjoint/pullback of all discrete Fourier transforms.
function adjoint_scale(S::SpectralTransform)
(; nlat_half, nlons, rfft_plans) = S
nfreqs = [rfft_plan.osz[1] for rfft_plan in rfft_plans] # TODO: This works with FFTW, but does it with cuFFT as well?

scale = zeros(Int, maximum(nfreqs), nlat_half)

for i=1:nlat_half
scale[1:nfreqs[i],i] = rfft_adjoint_scale(nfreqs[i], nlons[i])
end

# TODO: transfer array to GPU in case we are on GPU
return reshape(scale, maximum(nfreqs), 1, nlat_half) # the scratch memory is (Freq x lvl x lat), so we insert
# an additional dimension here for easier matrix multiply
end

# Computes the scale for the adjoint/pullback of a real discrete Fourier transform.
function rfft_adjoint_scale(n_freq::Int, n_real::Int)
if iseven(n_real)
return [1; [2 for i=2:(n_freq-1)]; 1]
else
return [1; [2 for i=2:n_freq]]
end
end

### Custom rule for _fourier!(f_north, f_south, grids, S)
function augmented_primal(config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(_fourier!)}, ::Type{<:Const},
f_north::Duplicated, f_south::Duplicated, grids::Duplicated{<:AbstractGridArray}, S::Union{Const, MixedDuplicated})

func.val(f_north.val, f_south.val, grids.val, S.val) # forward pass

# save grids in tape if grids will be overwritten
if overwritten(config)[4] # TODO: Not sure this is really necessary because grids won't ever get overwritten by this _fourier!
tape = copy(grids.val)
else
tape = nothing
end

return AugmentedReturn(nothing, nothing, tape) # because the function actually returns nothing

end

function reverse(config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(_fourier!)}, ::Type{<:Const}, tape,
f_north::Duplicated, f_south::Duplicated, grids::Duplicated{<:AbstractGridArray}, S::Union{Const, MixedDuplicated})

# adjoint/vjp of FFT has a different scaling, compute it, apply it later to f_north, f_south
scale = adjoint_scale(S.val)

# retrieve grids value, either from original grids or from tape if grids may have been overwritten.
gridsval = overwritten(config)[4] ? tape : grids.val

# compute the adjoint
dgridval = zero(gridsval)
_fourier!(dgridval, f_north.dval ./ scale, f_south.dval ./ scale, S.val) # inverse FFT (w/o normalization)
grids.dval .+= dgridval

# no derivative wrt the f_north and f_south that were input because they are overwritten
make_zero!(f_north.dval)
make_zero!(f_south.dval)

# the function has no return values, so we also return nothing here
return (nothing, nothing, nothing, nothing)
end

### Custom rule for _fourier!(grid, f_north, f_south, S)
function augmented_primal(config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(_fourier!)}, ::Type{<:Const},
grids::Duplicated{<:AbstractGridArray}, f_north::Duplicated, f_south::Duplicated, S::Union{Const, MixedDuplicated})

func.val(grids.val, f_north.val, f_south.val, S.val) # forward pass

# TODO: make an overwritten check here?

return AugmentedReturn(nothing, nothing, nothing) # because the function actually returns nothing

end

function reverse(config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(_fourier!)}, ::Type{<:Const}, tape,
grids::Duplicated{<:AbstractGridArray}, f_north::Duplicated, f_south::Duplicated, S::Union{Const, MixedDuplicated})

# adjoint/vjp of FFT has a different scaling, compute it, apply it later to f_north, f_south
scale = adjoint_scale(S.val)

# TODO: retrieve from tape here if overwritten?

# compute the adjoint # TODO: could we reuse the f_north.val for that as well? and not allocate here
dfnorthval = zero(f_north.val)
dfsouthval = zero(f_south.val)
Review comment (Member):
Because these allocate: Is that because we can't reuse the forward memory for the gradient?


_fourier!(dfnorthval, dfsouthval, grids.dval, S.val) # inverse FFT (w/o normalization)

f_north.dval .+= scale .* dfnorthval
f_south.dval .+= scale .* dfsouthval

# no derivative wrt the grids that were input because they are overwritten
make_zero!(grids.dval)

# the function has no return values, so we also return nothing here
return (nothing, nothing, nothing, nothing)
end

end
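
The scaling used in the two `_fourier!` rules above can be checked without FFTW or Enzyme. The following is a minimal sketch, not the PR's test code: `naive_rfft`, `naive_brfft` and `check_adjoints` are hypothetical helpers written only for this illustration, while `rfft_adjoint_scale` is copied from the diff. Assuming the forward transform behaves like an unnormalized real FFT and the inverse like an unnormalized backward real FFT, the adjoint of the former should be the latter applied to `y ./ scale`, and the adjoint of the latter should be `scale .*` the former.

```julia
using LinearAlgebra  # for dot

# Hypothetical helper: naive real-to-complex DFT, returning the
# n÷2 + 1 non-redundant frequencies (same layout as an rfft).
function naive_rfft(x::AbstractVector{<:Real})
    n = length(x)
    return [sum(x[j+1] * cispi(-2 * j * k / n) for j in 0:n-1) for k in 0:(n ÷ 2)]
end

# Hypothetical helper: naive unnormalized complex-to-real inverse.
# Missing frequencies are filled in by Hermitian symmetry.
function naive_brfft(X::AbstractVector{<:Complex}, n::Int)
    full = zeros(ComplexF64, n)
    full[1:length(X)] .= X
    for k in length(X):(n - 1)            # 0-based frequency index
        full[k + 1] = conj(X[n - k + 1])
    end
    return [real(sum(full[k+1] * cispi(2 * j * k / n) for k in 0:n-1)) for j in 0:n-1]
end

# Copied from the diff above.
function rfft_adjoint_scale(n_freq::Int, n_real::Int)
    if iseven(n_real)
        return [1; [2 for i=2:(n_freq-1)]; 1]
    else
        return [1; [2 for i=2:n_freq]]
    end
end

# Checks both adjoint identities under the real inner product
# <a, b> = real(dot(a, b)) for a ring with n longitudes.
function check_adjoints(n::Int)
    n_freq = n ÷ 2 + 1
    scale = rfft_adjoint_scale(n_freq, n)
    x = [sin(0.7 * j) + 0.1 * j for j in 1:n]                    # real "grid" data
    y = [complex(cos(1.3 * k), 0.5 * k - 1) for k in 1:n_freq]   # complex "spectral" data

    # grid -> Fourier rule: adjoint of rfft is brfft applied to y ./ scale
    forward_ok = real(dot(naive_rfft(x), y)) ≈ dot(x, naive_brfft(y ./ scale, n))
    # Fourier -> grid rule: adjoint of brfft is scale .* rfft
    inverse_ok = dot(x, naive_brfft(y, n)) ≈ real(dot(scale .* naive_rfft(x), y))
    return forward_ok && inverse_ok
end

check_adjoints(8)   # even number of longitudes
check_adjoints(7)   # odd number of longitudes
```

The even and odd cases differ only in whether a Nyquist frequency exists, which is why `rfft_adjoint_scale` ends in `1` for even `n_real` and in `2` otherwise.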
30 changes: 30 additions & 0 deletions ext/SpeedyWeatherFiniteDifferencesExt.jl
@@ -0,0 +1,30 @@
module SpeedyWeatherFiniteDifferencesExt

using SpeedyWeather
import FiniteDifferences
import FiniteDifferences: to_vec

# FiniteDifferences needs to be able to convert data structures to Vectors and back
# This doesn't work out of the box with our data types, so we'll define those
# conversions here.
function FiniteDifferences.to_vec(x::Grid) where Grid <: AbstractGridArray
x_vec, from_vec = FiniteDifferences.to_vec(Array(x))

function GridArray_from_vec(x_vec)
return Grid(reshape(from_vec(x_vec), size(x)), x.nlat_half)
end

return x_vec, GridArray_from_vec
end

function FiniteDifferences.to_vec(x::LTA) where LTA <: LowerTriangularArray
x_vec, from_vec = FiniteDifferences.to_vec(x.data)

function LowerTriangularArray_from_vec(x_vec)
return LowerTriangularArray(reshape(from_vec(x_vec), size(x)), x.m, x.n)
end

return x_vec, LowerTriangularArray_from_vec
end

end
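
The `to_vec` methods above follow a flatten-and-rebuild pattern: return a plain vector plus a closure that restores the original structure. Below is a package-free sketch of that pattern. It does not call FiniteDifferences.jl, and `ToyGrid`, `toy_to_vec` and `toy_fd_gradient` are hypothetical names invented for this illustration, so it only shows the idea rather than the extension's actual behaviour.

```julia
# Hypothetical stand-in for a grid type: data plus one piece of metadata.
struct ToyGrid{T}
    data::Matrix{T}
    nlat_half::Int
end

# Same shape as a `to_vec` method: return the flat vector and a closure
# that rebuilds the struct (metadata is captured from the original).
function toy_to_vec(x::ToyGrid)
    x_vec = vec(copy(x.data))
    from_vec(v) = ToyGrid(reshape(copy(v), size(x.data)), x.nlat_half)
    return x_vec, from_vec
end

# Central finite-difference gradient of a scalar function of a ToyGrid,
# computed entirely through the flat-vector view.
function toy_fd_gradient(f, x::ToyGrid; h = 1e-6)
    x_vec, from_vec = toy_to_vec(x)
    g = similar(x_vec)
    for i in eachindex(x_vec)
        up = copy(x_vec); up[i] += h
        dn = copy(x_vec); dn[i] -= h
        g[i] = (f(from_vec(up)) - f(from_vec(dn))) / (2 * h)
    end
    return from_vec(g)   # gradient comes back in the original structure
end

grid = ToyGrid([1.0 2.0; 3.0 4.0], 1)
sumsq(g) = sum(abs2, g.data)
grad = toy_fd_gradient(sumsq, grid)   # analytically 2 .* grid.data
```

Returning the gradient through `from_vec` is the point of the pattern: a finite-difference result can then be compared element by element against an automatic-differentiation result of the same type.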
13 changes: 13 additions & 0 deletions test/Project.toml
@@ -0,0 +1,13 @@
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -2,6 +2,7 @@ using SpeedyWeather
using Test

# GENERAL
+include("test_transforms_ad_rules.jl") # TODO: put the test somewhere else (and next to the transforms)
include("utility_functions.jl")
include("dates.jl")
include("lower_triangular_matrix.jl")