Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chain rules for FFT plans via AdjointPlans #67

Merged
merged 28 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
ad71816
Implement AdjointPlans
gaurav-arya Jun 9, 2022
c91ad50
Implement chain rules for FFT plans
gaurav-arya Jun 9, 2022
061eef9
Test plan adjoints and AD rules
gaurav-arya Jun 9, 2022
497ff4d
Apply suggestions from adjoint plan code review
gaurav-arya Jun 9, 2022
5d5c06c
Include irrft_dim in RealInverseProjectionStyle
gaurav-arya Jun 9, 2022
ef84edf
update to new fftdims interface
gaurav-arya Jul 1, 2022
d7ff394
fix broken tests
gaurav-arya Jul 1, 2022
aa8e575
Explicitly don't support mul! for adjoint plans
gaurav-arya Jul 1, 2022
9d99886
Document adjoint plans
gaurav-arya Jul 1, 2022
ac7c78c
remove incorrectly thrown error
gaurav-arya Jul 1, 2022
8474141
Update adjoint plan docs
gaurav-arya Jul 14, 2022
769c090
Update adjoint docs
gaurav-arya Jul 14, 2022
3ed83df
Fix typos
gaurav-arya Jul 14, 2022
552d49f
tweak adjoint doc string
gaurav-arya Jul 14, 2022
1e9ece2
Tweaks to adjoint description
gaurav-arya Jul 15, 2022
8ddfa97
Immutable AdjointPlan
gaurav-arya Jul 16, 2022
87758c8
Add rules and tests for ScaledPlan
gaurav-arya Aug 6, 2022
09b8b38
Apply suggestions from code review
gaurav-arya Aug 16, 2022
d967aa2
More tweaks to address code review
gaurav-arya Aug 16, 2022
2a423e2
Restrict to T<:Real for rfft adjoint
gaurav-arya Aug 16, 2022
eedba14
Get type T correct for test irfft
gaurav-arya Aug 16, 2022
25bb86b
Test complex input when appropriate for adjoint tests
gaurav-arya Aug 16, 2022
2a2d685
Merge remote-tracking branch 'origin/master' into adjoint
gaurav-arya Aug 28, 2022
fe3b06a
Add plan_inv implementation for adjoint plan and test it
gaurav-arya Aug 28, 2022
266c88f
Merge branch 'master' into adjoint
devmotion Jun 30, 2023
403ce47
Apply suggestions from code review
devmotion Jul 4, 2023
e137ae3
Apply suggestions from code review
devmotion Jul 5, 2023
e601347
Test in-place plans
devmotion Jul 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ julia = "^1.0"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["ChainRulesTestUtils", "Random", "Test", "Unitful"]
test = ["ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"]
23 changes: 1 addition & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,4 @@ This allows multiple FFT packages to co-exist with the same underlying `fft(x)`

## Developer information

To define a new FFT implementation in your own module, you should

* Define a new subtype (e.g. `MyPlan`) of `AbstractFFTs.Plan{T}` for FFTs and related transforms on arrays of `T`.
This must have a `pinv::Plan` field, initially undefined when a `MyPlan` is created, that is used for caching the
inverse plan.

* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).

* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to
0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`.

* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` (or `A_mul_B!`) method.
This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs.

* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the
inverse plan to `p`, and `plan_bfft(x, region; kws...)` for an unnormalized inverse ("backwards") transform of `x`.

* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.

The normalization convention for your FFT should be that it computes yₖ = ∑ⱼ xⱼ exp(-2πi jk/n) for a transform of
length n, and the "backwards" (unnormalized inverse) transform computes the same thing but with exp(+2πi jk/n).
To define a new FFT implementation in your own module, see [defining a new implementation](https://juliamath.github.io/AbstractFFTs.jl/stable/implementations/#Defining-a-new-implementation).
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

[compat]
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ AbstractFFTs.plan_rfft
AbstractFFTs.plan_brfft
AbstractFFTs.plan_irfft
AbstractFFTs.fftdims
Base.adjoint
AbstractFFTs.fftshift
AbstractFFTs.fftshift!
AbstractFFTs.ifftshift
Expand Down
41 changes: 28 additions & 13 deletions docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,31 @@ The following packages extend the functionality provided by AbstractFFTs:

## Defining a new implementation

Implementations should implement `LinearAlgebra.mul!(Y, plan, X)` (or
`A_mul_B!(y, p::MyPlan, x)` on Julia prior to 0.7.0-DEV.3204) so as to support
pre-allocated output arrays.
We don't define `*` in terms of `mul!` generically here, however, because
of subtleties for in-place and real FFT plans.

To support `inv`, `\`, and `ldiv!(y, plan, x)`, we require `Plan` subtypes
to have a `pinv::Plan` field, which caches the inverse plan, and which should be
initially undefined.
They should also implement `plan_inv(p)` to construct the inverse of a plan `p`.

Implementations only need to provide the unnormalized backwards FFT,
similar to FFTW, and we do the scaling generically to get the inverse FFT.
To define a new FFT implementation in your own module, you should

* Define a new subtype (e.g. `MyPlan`) of `AbstractFFTs.Plan{T}` for FFTs and related transforms on arrays of `T`.
This must have a `pinv::Plan` field, initially undefined when a `MyPlan` is created, that is used for caching the
inverse plan.

* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).

* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` that computes the transform `p` of `x` and stores the result in `y`.

* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` method.
This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs.

* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the
inverse plan to `p`, and `plan_bfft(x, region; kws...)` for an unnormalized inverse ("backwards") transform of `x`.
Implementations only need to provide the unnormalized backwards FFT, similar to FFTW, and we do the scaling generically
to get the inverse FFT.

* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.

* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can take values:
devmotion marked this conversation as resolved.
Show resolved Hide resolved
* `AbstractFFTs.NoProjectionStyle()`,
* `AbstractFFTs.RealProjectionStyle()`, for plans which halve one of the output's dimensions analogously to [`rfft`](@ref),
devmotion marked this conversation as resolved.
Show resolved Hide resolved
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans which expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
devmotion marked this conversation as resolved.
Show resolved Hide resolved

The normalization convention for your FFT should be that it computes yₖ = ∑ⱼ xⱼ exp(-2πi jk/n) for a transform of
devmotion marked this conversation as resolved.
Show resolved Hide resolved
length n, and the "backwards" (unnormalized inverse) transform computes the same thing but with exp(+2πi jk/n).
devmotion marked this conversation as resolved.
Show resolved Hide resolved
37 changes: 37 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,40 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
end
return y, ifftshift_pullback
end

# plans
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::Plan, x::AbstractArray)
y = P * x
Δy = P * Δx
return y, Δy
end
function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray)
y = P * x
project_x = ChainRulesCore.ProjectTo(x)
Pt = P'
function mul_plan_pullback(ȳ)
x̄ = project_x(Pt * ȳ)
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
devmotion marked this conversation as resolved.
Show resolved Hide resolved
end
return y, mul_plan_pullback
end

function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::ScaledPlan, x::AbstractArray)
y = P * x
Δy = P * Δx .+ (ΔP.scale / P.scale) .* y
return y, Δy
end
function ChainRulesCore.rrule(::typeof(*), P::ScaledPlan, x::AbstractArray)
y = P * x
Pt = P'
scale = P.scale
project_x = ChainRulesCore.ProjectTo(x)
project_scale = ChainRulesCore.ProjectTo(scale)
function mul_scaledplan_pullback(ȳ)
x̄ = ChainRulesCore.@thunk(project_x(Pt * ȳ))
scale_tangent = ChainRulesCore.@thunk(project_scale(dot(y, ȳ) / conj(scale)))
plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent)
return ChainRulesCore.NoTangent(), plan_tangent, x̄
end
return y, mul_scaledplan_pullback
end
77 changes: 77 additions & 0 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T

# size(p) should return the size of the input array for p
size(p::Plan, d) = size(p)[d]
output_size(p::Plan, d) = output_size(p)[d]
ndims(p::Plan) = length(size(p))
length(p::Plan) = prod(size(p))::Int

Expand Down Expand Up @@ -255,6 +256,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)

size(p::ScaledPlan) = size(p.p)
output_size(p::ScaledPlan) = output_size(p.p)

fftdims(p::ScaledPlan) = fftdims(p.p)

Expand Down Expand Up @@ -576,3 +578,78 @@ Pre-plan an optimized real-input unnormalized transform, similar to
the same as for [`brfft`](@ref).
"""
plan_brfft

##############################################################################

struct NoProjectionStyle end
struct RealProjectionStyle end
struct RealInverseProjectionStyle
dim::Int
end
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}

output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))

struct AdjointPlan{T,P<:Plan} <: Plan{T}
p::P
AdjointPlan{T,P}(p) where {T,P} = new(p)
end

"""
p'
devmotion marked this conversation as resolved.
Show resolved Hide resolved
adjoint(p::Plan)

Form the adjoint operator of an FFT plan. Returns a plan which performs the adjoint operation
devmotion marked this conversation as resolved.
Show resolved Hide resolved
the original plan. Note that this differs from the corresponding backwards plan in the case of real
FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref).

!!! note
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
coverage of `Base.adjoint` in downstream implementations may be limited.
"""
Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p)
Base.adjoint(p::AdjointPlan) = p.p
# always have AdjointPlan inside ScaledPlan.
Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)

size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
return (p.p \ x) / N
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need real here, in contrast to RealInverseProjectionStyle below?

Copy link
Contributor Author

@gaurav-arya gaurav-arya Aug 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's right, since we should only expect an rfft plan to operate on real arrays. I've added T<:Real to make this clear.

Also, regarding the use of real(T) for the RealInverseProjectionStyle: I could possibly just match AdjointPlan{Complex{T}} and then use T instead of real(T), since we probably should expect an irfft to operate on complex arrays. (The test plans were actually getting T wrong, i.e. T<:Real for the inverse of a rfft, but I've fixed that in eedba14). However, real(T) seems a little safer in case someone ever wants to write a specialized irfft plan that accepts only real inputs.

halfdim = first(dims)
d = size(p.p, halfdim)
n = output_size(p.p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return p.p \ (x ./ scale)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(real(T), output_size(p.p), dims)
halfdim = first(dims)
n = size(p.p, halfdim)
d = output_size(p.p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return scale ./ N .* (p.p \ x)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
end

inv(p::AdjointPlan) = adjoint(inv(p.p))
118 changes: 114 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license

using AbstractFFTs
using AbstractFFTs: Plan
using AbstractFFTs: Plan, ScaledPlan
using ChainRulesTestUtils
using FiniteDifferences
import ChainRulesCore

using LinearAlgebra
using Random
Expand Down Expand Up @@ -197,6 +199,85 @@ end
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
end

@testset "output size" begin
@testset "complex fft output size" begin
for x_shape in ((3,), (3, 4), (3, 4, 5))
N = length(x_shape)
real_x = randn(x_shape)
complex_x = randn(ComplexF64, x_shape)
for x in (real_x, complex_x)
for dims in unique((1, 1:N, N))
P = plan_fft(x, dims)
@test @inferred(AbstractFFTs.output_size(P)) == size(x)
@test AbstractFFTs.output_size(P') == size(x)
Pinv = plan_ifft(x)
@test AbstractFFTs.output_size(Pinv) == size(x)
@test AbstractFFTs.output_size(Pinv') == size(x)
end
end
end
end
@testset "real fft output size" begin
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths
N = ndims(x)
for dims in unique((1, 1:N, N))
P = plan_rfft(x, dims)
Px_sz = size(P * x)
@test AbstractFFTs.output_size(P) == Px_sz
@test AbstractFFTs.output_size(P') == size(x)
y = randn(ComplexF64, Px_sz)
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
@test AbstractFFTs.output_size(Pinv) == size(Pinv * y)
@test AbstractFFTs.output_size(Pinv') == size(y)
end
end
end
end

@testset "adjoint" begin
@testset "complex fft adjoint" begin
for x_shape in ((3,), (3, 4), (3, 4, 5))
N = length(x_shape)
real_x = randn(x_shape)
complex_x = randn(ComplexF64, x_shape)
y = randn(ComplexF64, x_shape)
for x in (real_x, complex_x)
for dims in unique((1, 1:N, N))
P = plan_fft(x, dims)
@test (P')' === P # test adjoint of adjoint
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
@test dot(y, P * x) ≈ dot(P' * y, x) # test validity of adjoint
@test dot(y, P \ x) ≈ dot(P' \ y, x)
Pinv = plan_ifft(y)
@test (Pinv')' * y == Pinv * y
@test size(Pinv') == AbstractFFTs.output_size(Pinv)
@test dot(x, Pinv * y) ≈ dot(Pinv' * x, y)
@test dot(x, Pinv \ y) ≈ dot(Pinv' \ x, y)
@test_throws MethodError mul!(x, P', y)
end
end
end
end
@testset "real fft adjoint" begin
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths
N = ndims(x)
for dims in unique((1, 1:N, N))
P = plan_rfft(x, dims)
y = randn(ComplexF64, size(P * x))
@test (P')' * x == P * x
@test size(P') == AbstractFFTs.output_size(P)
@test dot(real.(y), real.(P * x)) + dot(imag.(y), imag.(P * x)) ≈ dot(P' * y, x)
@test dot(real.(y), real.(P' \ x)) + dot(imag.(y), imag.(P' \ x)) ≈ dot(P \ y, x)
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
@test (Pinv')' * y == Pinv * y
@test size(Pinv') == AbstractFFTs.output_size(Pinv)
@test dot(x, Pinv * y) ≈ dot(real.(y), real.(Pinv' * x)) + dot(imag.(y), imag.(Pinv' * x))
@test dot(x, Pinv' \ y) ≈ dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x))
end
end
end
end

@testset "ChainRules" begin
@testset "shift functions" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
Expand All @@ -218,20 +299,43 @@ end
end

@testset "fft" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
N = ndims(x)
complex_x = complex.(x)
# Overloads to allow ChainRulesTestUtils to test rules w.r.t. ScaledPlan's. See https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/256
InnerPlan = Union{TestPlan, InverseTestPlan, TestRPlan, InverseTestRPlan}
function FiniteDifferences.to_vec(x::InnerPlan)
function FFTPlan_from_vec(x_vec::Vector)
return x
end
return Bool[], FFTPlan_from_vec
end
ChainRulesTestUtils.test_approx(::ChainRulesCore.AbstractZero, x::InnerPlan, msg=""; kwargs...) = true
ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::InnerPlan) = ChainRulesCore.NoTangent()

for x_shape in ((2,), (2, 3), (3, 4, 5))
N = length(x_shape)
x = randn(x_shape)
complex_x = randn(ComplexF64, x_shape)
for dims in unique((1, 1:N, N))
# fft, ifft, bfft
for f in (fft, ifft, bfft)
test_frule(f, x, dims)
test_rrule(f, x, dims)
test_frule(f, complex_x, dims)
test_rrule(f, complex_x, dims)
end
for pf in (plan_fft, plan_ifft, plan_bfft)
test_frule(*, pf(x, dims), x)
test_rrule(*, pf(x, dims), x)
test_frule(*, pf(complex_x, dims), complex_x)
test_rrule(*, pf(complex_x, dims), complex_x)
end

# rfft
test_frule(rfft, x, dims)
test_rrule(rfft, x, dims)
test_frule(*, plan_rfft(x, dims), x)
test_rrule(*, plan_rfft(x, dims), x)

# irfft, brfft
for f in (irfft, brfft)
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
test_frule(f, x, d, dims)
Expand All @@ -240,6 +344,12 @@ end
test_rrule(f, complex_x, d, dims)
end
end
for pf in (plan_irfft, plan_brfft)
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
test_frule(*, pf(complex_x, d, dims), complex_x)
test_rrule(*, pf(complex_x, d, dims), complex_x)
end
end
end
end
end
Expand Down
Loading