From 5c23f4b52224bd6047b7caf53bb6c7826f1cf2c6 Mon Sep 17 00:00:00 2001 From: Vedant Puri Date: Thu, 27 Jul 2023 11:25:39 -0400 Subject: [PATCH] Rename ProjectionStyle -> AdjointStyle, _mul -> AdjointMul, and improve docs (#109) * make ProjectionStyle abstract type so we can subtype in downstream packages. add a few lines of docs * Rename ProjectionStyle -> AdjointStyle, _mul -> AdjointMul, improve docs * Clarify normalization * Clarify documentation, rename _output_size -> output_size * Remove unnecessary def * Remove confusing commas * Tweak docstring wording * Reposition and improve size/output_size docstrings * Note that size needs to be implemented in docs --------- Co-authored-by: Gaurav Arya --- docs/src/api.md | 19 +++++- docs/src/implementations.md | 10 +-- src/definitions.jl | 131 +++++++++++++++++++++++++++--------- test/testplans.jl | 10 +-- 4 files changed, 128 insertions(+), 42 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index bb3b849..713e62d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,5 +1,7 @@ # Public Interface +## FFT and FFT planning functions + ```@docs AbstractFFTs.fft AbstractFFTs.fft! @@ -20,11 +22,26 @@ AbstractFFTs.plan_rfft AbstractFFTs.plan_brfft AbstractFFTs.plan_irfft AbstractFFTs.fftdims -Base.adjoint AbstractFFTs.fftshift AbstractFFTs.fftshift! AbstractFFTs.ifftshift AbstractFFTs.ifftshift! AbstractFFTs.fftfreq AbstractFFTs.rfftfreq +Base.size +``` + +## Adjoint functionality + +The following API is supported by plans that support adjoint functionality. +It is also relevant to implementers of FFT plans that wish to support adjoints. +```@docs +Base.adjoint +AbstractFFTs.AdjointStyle +AbstractFFTs.output_size +AbstractFFTs.adjoint_mul +AbstractFFTs.FFTAdjointStyle +AbstractFFTs.RFFTAdjointStyle +AbstractFFTs.IRFFTAdjointStyle +AbstractFFTs.UnitaryAdjointStyle ``` diff --git a/docs/src/implementations.md b/docs/src/implementations.md index 7367fd4..b7e6751 100644 --- a/docs/src/implementations.md +++ b/docs/src/implementations.md @@ -18,7 +18,8 @@ To define a new FFT implementation in your own module, you should inverse plan. * Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of - `x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`). + `x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` + (which defaults to `p.region`), and the input size `size(x)` should be accessible via `size(p::MyPlan)`. * Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` that computes the transform `p` of `x` and stores the result in `y`. @@ -32,10 +33,9 @@ To define a new FFT implementation in your own module, you should * You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs. -* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return: - * `AbstractFFTs.NoProjectionStyle()`, - * `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref), - * `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension. +* To support adjoints in a new plan, define the trait [`AbstractFFTs.AdjointStyle`](@ref). + `AbstractFFTs` implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref). + To define a new adjoint style, define the methods [`AbstractFFTs.adjoint_mul`](@ref) and [`AbstractFFTs.output_size`](@ref). The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``. diff --git a/src/definitions.jl b/src/definitions.jl index 4ec176e..2ac28b7 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -10,9 +10,12 @@ abstract type Plan{T} end eltype(::Type{<:Plan{T}}) where {T} = T -# size(p) should return the size of the input array for p -size(p::Plan, d) = size(p)[d] -output_size(p::Plan, d) = output_size(p)[d] +""" + size(p::Plan, [dim]) + +Return the size of the input of a plan `p`, optionally at a specified dimenion `dim`. +""" +size(p::Plan, dim) = size(p)[dim] ndims(p::Plan) = length(size(p)) length(p::Plan) = prod(size(p))::Int @@ -583,17 +586,73 @@ plan_brfft ############################################################################## -struct NoProjectionStyle end -struct RealProjectionStyle end -struct RealInverseProjectionStyle +""" + AbstractFFTs.AdjointStyle(::Plan) + +Return the adjoint style of a plan, enabling automatic computation of adjoint plans via +[`Base.adjoint`](@ref). Instructions for supporting adjoint styles are provided in the +[implementation instructions](implementations.md#Defining-a-new-implementation). +""" +abstract type AdjointStyle end + +""" + FFTAdjointStyle() + +Adjoint style for complex to complex discrete Fourier transforms that normalize +the output analogously to [`fft`](@ref). + +Since the Fourier transform is unitary up to a scaling, the adjoint simply applies +the transform's inverse with an appropriate scaling. +""" +struct FFTAdjointStyle <: AdjointStyle end + +""" + RFFTAdjointStyle() + +Adjoint style for real to complex discrete Fourier transforms that halve one of +the output's dimensions and normalize the output analogously to [`rfft`](@ref). + +Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's +inverse, but with appropriate scaling and additional logic to handle the fact that the +output is projected to exploit its conjugate symmetry (see [`rfft`](@ref)). +""" +struct RFFTAdjointStyle <: AdjointStyle end + +""" + IRFFTAdjointStyle(d::Dim) + +Adjoint style for complex to real discrete Fourier transforms that expect an input +with a halved dimension and normalize the output analogously to [`irfft`](@ref), +where `d` is the original length of the dimension. + +Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's +inverse, but with appropriate scaling and additional logic to handle the fact that the +input is projected to exploit its conjugate symmetry (see [`irfft`](@ref)). +""" +struct IRFFTAdjointStyle <: AdjointStyle dim::Int end -const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle} -output_size(p::Plan) = _output_size(p, ProjectionStyle(p)) -_output_size(p::Plan, ::NoProjectionStyle) = size(p) -_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p)) -_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p)) +""" + UnitaryAdjointStyle() + +Adjoint style for unitary transforms, whose adjoint equals their inverse. +""" +struct UnitaryAdjointStyle <: AdjointStyle end + +""" + output_size(p::Plan, [dim]) + +Return the size of the output of a plan `p`, optionally at a specified dimension `dim`. + +Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define `output_size(::Plan, ::AS)`. +""" +output_size(p::Plan) = output_size(p, AdjointStyle(p)) +output_size(p::Plan, dim) = output_size(p)[dim] +output_size(p::Plan, ::FFTAdjointStyle) = size(p) +output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p)) +output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p)) +output_size(p::Plan, ::UnitaryAdjointStyle) = size(p) struct AdjointPlan{T,P<:Plan} <: Plan{T} p::P @@ -604,9 +663,7 @@ end (p::Plan)' adjoint(p::Plan) -Form the adjoint operator of an FFT plan. Returns a plan that performs the adjoint operation of -the original plan. Note that this differs from the corresponding backwards plan in the case of real -FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref). +Return a plan that performs the adjoint operation of the original plan. !!! note Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`, @@ -620,40 +677,52 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale) size(p::AdjointPlan) = output_size(p.p) output_size(p::AdjointPlan) = size(p.p) -Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p)) +Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x) + +""" + adjoint_mul(p::Plan, x::AbstractArray) + +Multiply an array `x` by the adjoint of a plan `p`. This is equivalent to `p' * x`. + +Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define +`adjoint_mul(::Plan, ::AbstractArray, ::AS)`. +""" +adjoint_mul(p::Plan, x::AbstractArray) = adjoint_mul(p, x, AdjointStyle(p)) -function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T} - dims = fftdims(p.p) - N = normalization(T, size(p.p), dims) - return (p.p \ x) / N +function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T} + dims = fftdims(p) + N = normalization(T, size(p), dims) + return (p \ x) / N end -function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real} - dims = fftdims(p.p) - N = normalization(T, size(p.p), dims) +function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real} + dims = fftdims(p) + N = normalization(T, size(p), dims) halfdim = first(dims) - d = size(p.p, halfdim) - n = output_size(p.p, halfdim) + d = size(p, halfdim) + n = output_size(p, halfdim) scale = reshape( [(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n], ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) ) - return p.p \ (x ./ convert(typeof(x), scale)) + return p \ (x ./ convert(typeof(x), scale)) end -function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T} - dims = fftdims(p.p) - N = normalization(real(T), output_size(p.p), dims) +function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T} + dims = fftdims(p) + N = normalization(real(T), output_size(p), dims) halfdim = first(dims) - n = size(p.p, halfdim) - d = output_size(p.p, halfdim) + n = size(p, halfdim) + d = output_size(p, halfdim) scale = reshape( [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) ) - return (convert(typeof(x), scale) ./ N) .* (p.p \ x) + return (convert(typeof(x), scale) ./ N) .* (p \ x) end +adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x + # Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only). plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p)) inv(p::AdjointPlan) = adjoint(inv(p.p)) diff --git a/test/testplans.jl b/test/testplans.jl index 09b3f67..c6a76f9 100644 --- a/test/testplans.jl +++ b/test/testplans.jl @@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N Base.size(p::InverseTestPlan) = p.sz Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N -AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle() -AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle() +AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle() +AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle() function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T} return TestPlan{T}(region, size(x)) @@ -110,8 +110,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}} end end -AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle() -AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d) +AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle() +AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d) function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real} return TestRPlan{T}(region, size(x)) @@ -241,7 +241,7 @@ end Base.size(p::InplaceTestPlan) = size(p.plan) Base.ndims(p::InplaceTestPlan) = ndims(p.plan) -AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan) +AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan) function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...) return InplaceTestPlan(plan_fft(x, region; kwargs...))