From f3575aa3b402faac5abcff32a7b21424f78274bb Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Tue, 18 Jul 2023 12:44:04 -0400 Subject: [PATCH] Rename ProjectionStyle -> AdjointStyle, _mul -> AdjointMul, improve docs --- docs/src/api.md | 4 ++ docs/src/implementations.md | 9 ++-- src/definitions.jl | 86 ++++++++++++++++++++++++------------- test/testplans.jl | 10 ++--- 4 files changed, 69 insertions(+), 40 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index bb3b849..ef998f8 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -21,6 +21,10 @@ AbstractFFTs.plan_brfft AbstractFFTs.plan_irfft AbstractFFTs.fftdims Base.adjoint +AbstractFFTs.FFTAdjointStyle +AbstractFFTs.RFFTAdjointStyle +AbstractFFTs.IRFFTAdjointStyle +AbstractFFTs.UnitaryAdjointStyle AbstractFFTs.fftshift AbstractFFTs.fftshift! AbstractFFTs.ifftshift diff --git a/docs/src/implementations.md b/docs/src/implementations.md index 7367fd4..364d21b 100644 --- a/docs/src/implementations.md +++ b/docs/src/implementations.md @@ -32,10 +32,11 @@ To define a new FFT implementation in your own module, you should * You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs. -* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return: - * `AbstractFFTs.NoProjectionStyle()`, - * `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref), - * `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension. +* We offer an experimental `AdjointStyle` trait to enable automatic computation of adjoint plans via [`Base.adjoint`](@ref). + To support adjoints in a new plan, define the trait `AbstractFFTs.AdjointStyle(::MyPlan)`. 
This should return an instance of a subtype `AS <: AbstractFFTs.AdjointStyle` supporting `AbstractFFTs.adjoint_mul(::Plan, ::AbstractArray, ::AS)` and + `AbstractFFTs._output_size(::Plan, ::AS)`. + + `AbstractFFTs` pre-implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref). The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``. diff --git a/src/definitions.jl b/src/definitions.jl index 604329f..e495104 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -583,35 +583,57 @@ plan_brfft ############################################################################## -abstract type ProjectionStyle end +abstract type AdjointStyle end """ - NoProjectionStyle() + FFTAdjointStyle() -Projection style for complex to complex discrete Fourier transform +Adjoint style for complex to complex discrete Fourier transforms. + +Since the Fourier transform is unitary up to a scaling, the adjoint simply applies +the transform's inverse with an appropriate scaling. """ -struct NoProjectionStyle <: ProjectionStyle end +struct FFTAdjointStyle <: AdjointStyle end """ - RealProjectionStyle() + RFFTAdjointStyle() -Projection style for complex to real discrete Fourier transform +Adjoint style for real to complex discrete Fourier transforms, for plans that +halve one of the output's dimensions analogously to [`rfft`](@ref). + +Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's +inverse, but with additional logic to handle the fact that the output is projected +to exploit its conjugate symmetry (see [`rfft`](@ref)). 
""" -struct RealProjectionStyle <: ProjectionStyle end +struct RFFTAdjointStyle <: AdjointStyle end """ - RealInverseProjectionStyle() + IRFFTAdjointStyle(d::Int) -Projection style for inverse of complex to real discrete Fourier transform +Adjoint style for complex to real discrete Fourier transforms, for plans that +expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` +is the original length of the dimension. + +Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's +inverse, but with additional logic to handle the fact that the input is projected +to exploit its conjugate symmetry (see [`irfft`](@ref)). """ -struct RealInverseProjectionStyle <: ProjectionStyle +struct IRFFTAdjointStyle <: AdjointStyle dim::Int end -output_size(p::Plan) = _output_size(p, ProjectionStyle(p)) -_output_size(p::Plan, ::NoProjectionStyle) = size(p) -_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p)) -_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p)) +""" + UnitaryAdjointStyle() + +Adjoint style for unitary transforms, whose adjoint equals their inverse. 
+""" +struct UnitaryAdjointStyle <: AdjointStyle end + +output_size(p::Plan) = _output_size(p, AdjointStyle(p)) +_output_size(p::Plan, ::FFTAdjointStyle) = size(p) +_output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p)) +_output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p)) +_output_size(p::Plan, ::UnitaryAdjointStyle) = size(p) struct AdjointPlan{T,P<:Plan} <: Plan{T} p::P @@ -638,40 +660,42 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale) size(p::AdjointPlan) = output_size(p.p) output_size(p::AdjointPlan) = size(p.p) -Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p)) +Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x, AdjointStyle(p.p)) -function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T} - dims = fftdims(p.p) - N = normalization(T, size(p.p), dims) - return (p.p \ x) / N +function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T} + dims = fftdims(p) + N = normalization(T, size(p), dims) + return (p \ x) / N end -function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real} - dims = fftdims(p.p) - N = normalization(T, size(p.p), dims) +function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real} + dims = fftdims(p) + N = normalization(T, size(p), dims) halfdim = first(dims) - d = size(p.p, halfdim) - n = output_size(p.p, halfdim) + d = size(p, halfdim) + n = output_size(p, halfdim) scale = reshape( [(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n], ntuple(i -> i == halfdim ? 
n : 1, Val(ndims(x))) ) - return p.p \ (x ./ convert(typeof(x), scale)) + return p \ (x ./ convert(typeof(x), scale)) end -function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T} - dims = fftdims(p.p) - N = normalization(real(T), output_size(p.p), dims) +function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T} + dims = fftdims(p) + N = normalization(real(T), output_size(p), dims) halfdim = first(dims) - n = size(p.p, halfdim) - d = output_size(p.p, halfdim) + n = size(p, halfdim) + d = output_size(p, halfdim) scale = reshape( [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) ) - return (convert(typeof(x), scale) ./ N) .* (p.p \ x) + return (convert(typeof(x), scale) ./ N) .* (p \ x) end +adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x + # Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only). plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p)) inv(p::AdjointPlan) = adjoint(inv(p.p)) diff --git a/test/testplans.jl b/test/testplans.jl index 09b3f67..c6a76f9 100644 --- a/test/testplans.jl +++ b/test/testplans.jl @@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N Base.size(p::InverseTestPlan) = p.sz Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N -AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle() -AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle() +AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle() +AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle() function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) 
where {T} return TestPlan{T}(region, size(x)) @@ -110,8 +110,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}} end end -AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle() -AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d) +AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle() +AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d) function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real} return TestRPlan{T}(region, size(x)) @@ -241,7 +241,7 @@ end Base.size(p::InplaceTestPlan) = size(p.plan) Base.ndims(p::InplaceTestPlan) = ndims(p.plan) -AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan) +AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan) function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...) return InplaceTestPlan(plan_fft(x, region; kwargs...))