Chain rules for FFT plans via AdjointPlans (#67)
* Implement AdjointPlans

* Implement chain rules for FFT plans

* Test plan adjoints and AD rules

* Apply suggestions from adjoint plan code review

Co-authored-by: David Widmann <[email protected]>

* Include irrft_dim in RealInverseProjectionStyle

Co-authored-by: David Widmann <[email protected]>

* update to new fftdims interface

* fix broken tests

* Explicitly don't support mul! for adjoint plans

* Document adjoint plans

* remove incorrectly thrown error

* Update adjoint plan docs

* Update adjoint docs

* Fix typos

* tweak adjoint doc string

* Tweaks to adjoint description

* Immutable AdjointPlan

* Add rules and tests for ScaledPlan

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* More tweaks to address code review

* Restrict to T<:Real for rfft adjoint

* Get type T correct for test irfft

* Test complex input when appropriate for adjoint tests

* Add plan_inv implementation for adjoint plan and test it

* Apply suggestions from code review

Co-authored-by: Seth Axen <[email protected]>

* Apply suggestions from code review

* Test in-place plans

---------

Co-authored-by: David Widmann <[email protected]>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Seth Axen <[email protected]>
4 people authored Jul 5, 2023
1 parent b5109aa commit 8601a92
Showing 9 changed files with 349 additions and 55 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -19,9 +19,10 @@ julia = "^1.0"
 [extras]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
 
 [targets]
-test = ["ChainRulesCore", "ChainRulesTestUtils", "Random", "Test", "Unitful"]
+test = ["ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Random", "Test", "Unitful"]
23 changes: 1 addition & 22 deletions README.md
@@ -16,26 +16,5 @@ This allows multiple FFT packages to co-exist with the same underlying `fft(x)`
 
 ## Developer information
 
-To define a new FFT implementation in your own module, you should
+To define a new FFT implementation in your own module, see [defining a new implementation](https://juliamath.github.io/AbstractFFTs.jl/stable/implementations/#Defining-a-new-implementation).
 
-* Define a new subtype (e.g. `MyPlan`) of `AbstractFFTs.Plan{T}` for FFTs and related transforms on arrays of `T`.
-This must have a `pinv::Plan` field, initially undefined when a `MyPlan` is created, that is used for caching the
-inverse plan.
-
-* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
-`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).
-
-* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to
-0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`.
-
-* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` (or `A_mul_B!`) method.
-This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs.
-
-* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the
-inverse plan to `p`, and `plan_bfft(x, region; kws...)` for an unnormalized inverse ("backwards") transform of `x`.
-
-* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.
-
-The normalization convention for your FFT should be that it computes $y_k = \sum_j \exp\(-2 \pi i \cdot \frac{j k}{n}\) x_j$
-for a transform of length $n$, and the "backwards" (unnormalized inverse) transform computes the same thing but with
-$\exp\(+2 \pi i \cdot \frac{j k}{n}\)$.
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 
 [compat]
1 change: 1 addition & 0 deletions docs/src/api.md
@@ -20,6 +20,7 @@ AbstractFFTs.plan_rfft
 AbstractFFTs.plan_brfft
 AbstractFFTs.plan_irfft
 AbstractFFTs.fftdims
+Base.adjoint
 AbstractFFTs.fftshift
 AbstractFFTs.fftshift!
 AbstractFFTs.ifftshift
41 changes: 28 additions & 13 deletions docs/src/implementations.md
@@ -11,16 +11,31 @@ The following packages extend the functionality provided by AbstractFFTs:
 
 ## Defining a new implementation
 
-Implementations should implement `LinearAlgebra.mul!(Y, plan, X)` (or
-`A_mul_B!(y, p::MyPlan, x)` on Julia prior to 0.7.0-DEV.3204) so as to support
-pre-allocated output arrays.
-We don't define `*` in terms of `mul!` generically here, however, because
-of subtleties for in-place and real FFT plans.
-
-To support `inv`, `\`, and `ldiv!(y, plan, x)`, we require `Plan` subtypes
-to have a `pinv::Plan` field, which caches the inverse plan, and which should be
-initially undefined.
-They should also implement `plan_inv(p)` to construct the inverse of a plan `p`.
-
-Implementations only need to provide the unnormalized backwards FFT,
-similar to FFTW, and we do the scaling generically to get the inverse FFT.
+To define a new FFT implementation in your own module, you should
+
+* Define a new subtype (e.g. `MyPlan`) of `AbstractFFTs.Plan{T}` for FFTs and related transforms on arrays of `T`.
+This must have a `pinv::Plan` field, initially undefined when a `MyPlan` is created, that is used for caching the
+inverse plan.
+
+* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
+`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).
+
+* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` that computes the transform `p` of `x` and stores the result in `y`.
+
+* Define a method of `*(p::MyPlan, x)`, which can simply call your `mul!` method.
+This is not defined generically in this package due to subtleties that arise for in-place and real-input FFTs.
+
+* If the inverse transform is implemented, you should also define `plan_inv(p::MyPlan)`, which should construct the
+inverse plan to `p`, and `plan_bfft(x, region; kws...)` for an unnormalized inverse ("backwards") transform of `x`.
+Implementations only need to provide the unnormalized backwards FFT, similar to FFTW, and we do the scaling generically
+to get the inverse FFT.
+
+* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.
+
+* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
+    * `AbstractFFTs.NoProjectionStyle()`,
+    * `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
+    * `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
+
+The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
+length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.
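
The normalization convention documented above can be checked with a small dense-matrix sketch that uses only the `LinearAlgebra` standard library. The helper `dft_matrix` and the names `F` and `B` are made up for this illustration and are not part of AbstractFFTs. A real implementation would use a fast algorithm rather than an O(n²) matrix.

```julia
using LinearAlgebra

# Hypothetical helper (not part of AbstractFFTs): dense forward-DFT matrix
# for y_k = sum_j x_j * exp(-2πi j k / n), with zero-based j and k.
dft_matrix(n) = [exp(-2π * im * j * k / n) for k in 0:n-1, j in 0:n-1]

n = 8
F = dft_matrix(n)        # forward transform
B = conj.(F)             # "backwards" transform: same sum with exp(+2πi j k / n)

x = randn(ComplexF64, n)

# The backwards transform is unnormalized: applying it after F scales by n,
# so the inverse can be obtained generically by dividing by n.
@assert B * (F * x) ≈ n .* x
@assert (B * (F * x)) ./ n ≈ x

# For a plain complex transform, the adjoint (conjugate transpose) of F
# coincides with the backwards transform.
@assert F' ≈ B
```

The last assertion covers the case without projection only. For real-input transforms the adjoint and the backwards transform differ, which is the situation the `ProjectionStyle` trait distinguishes.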
50 changes: 50 additions & 0 deletions ext/AbstractFFTsChainRulesCoreExt.jl
@@ -159,4 +159,54 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
     return y, ifftshift_pullback
 end
 
+# plans
+function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
+    y = P * x
+    if Base.mightalias(y, x)
+        throw(ArgumentError("differentiation rules are not supported for in-place plans"))
+    end
+    Δy = P * Δx
+    return y, Δy
+end
+function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
+    y = P * x
+    if Base.mightalias(y, x)
+        throw(ArgumentError("differentiation rules are not supported for in-place plans"))
+    end
+    project_x = ChainRulesCore.ProjectTo(x)
+    Pt = P'
+    function mul_plan_pullback(ȳ)
+        x̄ = project_x(Pt * ȳ)
+        return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
+    end
+    return y, mul_plan_pullback
+end
+
+function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
+    y = P * x
+    if Base.mightalias(y, x)
+        throw(ArgumentError("differentiation rules are not supported for in-place plans"))
+    end
+    Δy = P * Δx .+ (ΔP.scale / P.scale) .* y
+    return y, Δy
+end
+function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
+    y = P * x
+    if Base.mightalias(y, x)
+        throw(ArgumentError("differentiation rules are not supported for in-place plans"))
+    end
+    Pt = P'
+    scale = P.scale
+    project_x = ChainRulesCore.ProjectTo(x)
+    project_scale = ChainRulesCore.ProjectTo(scale)
+    function mul_scaledplan_pullback(ȳ)
+        x̄ = ChainRulesCore.@thunk(project_x(Pt * ȳ))
+        scale_tangent = ChainRulesCore.@thunk(project_scale(AbstractFFTs.dot(y, ȳ) / conj(scale)))
+        plan_tangent = ChainRulesCore.Tangent{typeof(P)}(;p=ChainRulesCore.NoTangent(), scale=scale_tangent)
+        return ChainRulesCore.NoTangent(), plan_tangent, x̄
+    end
+    return y, mul_scaledplan_pullback
+end
+
 end # module
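
The plan rules in this file can be read off from a short derivation. This is a sketch under the assumption that a `ScaledPlan` acts as $y = s\,(p\,x)$, with scalar `scale` $s$ and a linear inner plan $p$. The symbols $s$, $\dot{s}$ and $\bar{s}$ are notation introduced here, not names from the code, and the `ProjectTo` steps are left out.

```latex
% Plain plan: y = P x with P linear
\dot{y} = P\,\dot{x}, \qquad \bar{x} = P^{H}\,\bar{y}

% Scaled plan: y = s\,(p\,x), so P = s\,p
\dot{y} = s\,p\,\dot{x} + \dot{s}\,(p\,x) = P\,\dot{x} + \frac{\dot{s}}{s}\,y

% Reverse mode, with \langle a, b \rangle = \sum_i \overline{a_i}\, b_i
\bar{x} = P^{H}\,\bar{y}, \qquad
\bar{s} = \langle p\,x,\ \bar{y} \rangle
        = \left\langle \frac{y}{s},\ \bar{y} \right\rangle
        = \frac{\langle y, \bar{y} \rangle}{\overline{s}}
```

The forward-mode line corresponds to `P * Δx .+ (ΔP.scale / P.scale) .* y`, and the last line to `dot(y, ȳ) / conj(scale)`, since `dot` conjugates its first argument.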

79 changes: 79 additions & 0 deletions src/definitions.jl
@@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T
 
 # size(p) should return the size of the input array for p
 size(p::Plan, d) = size(p)[d]
+output_size(p::Plan, d) = output_size(p)[d]
 ndims(p::Plan) = length(size(p))
 length(p::Plan) = prod(size(p))::Int
 
@@ -255,6 +256,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
 ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)
 
 size(p::ScaledPlan) = size(p.p)
+output_size(p::ScaledPlan) = output_size(p.p)
 
 fftdims(p::ScaledPlan) = fftdims(p.p)
 
@@ -578,3 +580,80 @@ Pre-plan an optimized real-input unnormalized transform, similar to
 the same as for [`brfft`](@ref).
 """
 plan_brfft
+
+##############################################################################
+
+struct NoProjectionStyle end
+struct RealProjectionStyle end
+struct RealInverseProjectionStyle
+    dim::Int
+end
+const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}
+
+output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
+_output_size(p::Plan, ::NoProjectionStyle) = size(p)
+_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
+_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
+
+struct AdjointPlan{T,P<:Plan} <: Plan{T}
+    p::P
+    AdjointPlan{T,P}(p) where {T,P} = new(p)
+end
+
+"""
+    (p::Plan)'
+    adjoint(p::Plan)
+
+Form the adjoint operator of an FFT plan. Returns a plan that performs the adjoint operation of
+the original plan. Note that this differs from the corresponding backwards plan in the case of real
+FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref).
+
+!!! note
+    Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
+    coverage of `Base.adjoint` in downstream implementations may be limited.
+"""
+Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T, typeof(p)}(p)
+Base.adjoint(p::AdjointPlan) = p.p
+# always have AdjointPlan inside ScaledPlan.
+Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
+
+size(p::AdjointPlan) = output_size(p.p)
+output_size(p::AdjointPlan) = size(p.p)
+
+Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
+
+function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
+    dims = fftdims(p.p)
+    N = normalization(T, size(p.p), dims)
+    return (p.p \ x) / N
+end
+
+function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
+    dims = fftdims(p.p)
+    N = normalization(T, size(p.p), dims)
+    halfdim = first(dims)
+    d = size(p.p, halfdim)
+    n = output_size(p.p, halfdim)
+    scale = reshape(
+        [(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
+        ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
+    )
+    return p.p \ (x ./ convert(typeof(x), scale))
+end
+
+function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
+    dims = fftdims(p.p)
+    N = normalization(real(T), output_size(p.p), dims)
+    halfdim = first(dims)
+    n = size(p.p, halfdim)
+    d = output_size(p.p, halfdim)
+    scale = reshape(
+        [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
+        ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
+    )
+    return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
+end
+
+# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
+plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
+inv(p::AdjointPlan) = adjoint(inv(p.p))
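
The per-bin weights in the `RealProjectionStyle` method can be motivated with a dense-matrix check that needs only `LinearAlgebra`. This is a one-dimensional sketch, not the package's implementation. The names `R`, `w`, `full` and `back` are introduced here for illustration, and it assumes that `normalization` returns `1/n` and that `p.p \ x` is the normalized inverse real transform, so that dividing by `N` or `2N` amounts to weights of 1 and 1/2 ahead of an unnormalized backwards transform.

```julia
using LinearAlgebra

n = 8                     # even length, so the half spectrum has a Nyquist bin
m = n ÷ 2 + 1             # length of the half spectrum (rfft-style output)

# Dense forward-DFT matrix, y_k = sum_j x_j * exp(-2πi j k / n)
F = [exp(-2π * im * j * k / n) for k in 0:n-1, j in 0:n-1]
R = F[1:m, :]             # real-input transform: keep the first m rows

x = randn(n)              # real input
y = randn(ComplexF64, m)  # cotangent in the half-spectrum space

# Adjoint of x -> R * x with respect to the real inner product Re<a, b>
adj = real(R' * y)
@assert real(dot(y, R * x)) ≈ dot(adj, x)

# Same result via weights followed by an unnormalized backwards real transform:
# weight 1 for the DC bin (and the Nyquist bin when n is even), 1/2 elsewhere.
w = [(i == 1 || (i == m && 2 * (i - 1) == n)) ? 1.0 : 0.5 for i in 1:m]
z = w .* y
full = vcat(z, conj.(z[(n - m + 1):-1:2]))   # Hermitian extension to length n
back = real(F' * full)                        # backwards DFT, real part kept
@assert back ≈ adj
```

The halving compensates for the Hermitian extension, which counts every bin except DC and Nyquist twice. This is why the adjoint of a real FFT differs from the corresponding backwards plan, as the docstring notes.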