Skip to content

Commit

Permalink
Rename ProjectionStyle -> AdjointStyle, _mul -> AdjointMul, improve docs
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Jul 18, 2023
1 parent d53f57d commit f3575aa
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 40 deletions.
4 changes: 4 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ AbstractFFTs.plan_brfft
AbstractFFTs.plan_irfft
AbstractFFTs.fftdims
Base.adjoint
AbstractFFTs.FFTAdjointStyle
AbstractFFTs.RFFTAdjointStyle
AbstractFFTs.IRFFTAdjointStyle
AbstractFFTs.UnitaryAdjointStyle
AbstractFFTs.fftshift
AbstractFFTs.fftshift!
AbstractFFTs.ifftshift
Expand Down
9 changes: 5 additions & 4 deletions docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ To define a new FFT implementation in your own module, you should

* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.

* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
* `AbstractFFTs.NoProjectionStyle()`,
* `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
* We offer an experimental `AdjointStyle` trait to enable automatic computation of adjoint plans via [`Base.adjoint`](@ref).
To support adjoints in a new plan, define the trait `AbstractFFTs.AdjointStyle(::MyPlan)`. This should return a subtype of `AS <: AbstractFFTs.AdjointStyle` supporting `AbstractFFTs.adjoint_mul(::Plan, ::AbstractArray, ::AS)` and
`AbstractFFTs._output_size(::Plan, ::AS)`.

`AbstractFFTs` pre-implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref).

The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.
86 changes: 55 additions & 31 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -583,35 +583,57 @@ plan_brfft

##############################################################################

abstract type ProjectionStyle end
abstract type AdjointStyle end

"""
NoProjectionStyle()
FFTAdjointStyle()
Projection style for complex to complex discrete Fourier transform
Projection style for complex to complex discrete Fourier transforms.
Since the Fourier transform is unitary up to a scaling, the adjoint simply applies
the transform's inverse with an appropriate scaling.
"""
struct NoProjectionStyle <: ProjectionStyle end
struct FFTAdjointStyle <: AdjointStyle end

"""
RealProjectionStyle()
RFFTAdjointStyle()
Projection style for complex to real discrete Fourier transform
Projection style for real to complex discrete Fourier transforms, for plans that
halve one of the output's dimensions analogously to [`rfft`](@ref).
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with additional logic to handle the fact that the output is projected
to exploit its conjugate symmetry (see [`rfft`](@ref)).
"""
struct RealProjectionStyle <: ProjectionStyle end
struct RFFTAdjointStyle <: AdjointStyle end

"""
RealInverseProjectionStyle()
IRFFTAdjointStyle(d::Dim)
Projection style for inverse of complex to real discrete Fourier transform
Projection style for complex to real discrete Fourier transforms, for plans that
expect an input with a halved dimension analogously to [`irfft`](@ref), where `d`
is the original length of the dimension.
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with additional logic to handle the fact that the input is projected
to exploit its conjugate symmetry (see [`irfft`](@ref)).
"""
struct RealInverseProjectionStyle <: ProjectionStyle
struct IRFFTAdjointStyle <: AdjointStyle
dim::Int
end

output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
"""
UnitaryAdjointStyle()
Projection style for unitary transforms, whose adjoint equals their inverse.
"""
struct UnitaryAdjointStyle <: AdjointStyle end

output_size(p::Plan) = _output_size(p, AdjointStyle(p))
_output_size(p::Plan, ::FFTAdjointStyle) = size(p)
_output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
_output_size(p::Plan, ::UnitaryAdjointStyle) = size(p)

Check warning on line 636 in src/definitions.jl

View check run for this annotation

Codecov / codecov/patch

src/definitions.jl#L636

Added line #L636 was not covered by tests

struct AdjointPlan{T,P<:Plan} <: Plan{T}
p::P
Expand All @@ -638,40 +660,42 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x, AdjointStyle(p.p))

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
return (p.p \ x) / N
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(T, size(p), dims)
return (p \ x) / N
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
dims = fftdims(p)
N = normalization(T, size(p), dims)
halfdim = first(dims)
d = size(p.p, halfdim)
n = output_size(p.p, halfdim)
d = size(p, halfdim)
n = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return p.p \ (x ./ convert(typeof(x), scale))
return p \ (x ./ convert(typeof(x), scale))
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(real(T), output_size(p.p), dims)
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(real(T), output_size(p), dims)
halfdim = first(dims)
n = size(p.p, halfdim)
d = output_size(p.p, halfdim)
n = size(p, halfdim)
d = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
return (convert(typeof(x), scale) ./ N) .* (p \ x)
end

adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x

Check warning on line 697 in src/definitions.jl

View check run for this annotation

Codecov / codecov/patch

src/definitions.jl#L697

Added line #L697 was not covered by tests

# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
inv(p::AdjointPlan) = adjoint(inv(p.p))
10 changes: 5 additions & 5 deletions test/testplans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
Base.size(p::InverseTestPlan) = p.sz
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N

AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle()
AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle()

function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
return TestPlan{T}(region, size(x))
Expand Down Expand Up @@ -110,8 +110,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}}
end
end

AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle()
AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d)
AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle()
AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d)

function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
return TestRPlan{T}(region, size(x))
Expand Down Expand Up @@ -241,7 +241,7 @@ end

Base.size(p::InplaceTestPlan) = size(p.plan)
Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan)
AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan)

function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
return InplaceTestPlan(plan_fft(x, region; kwargs...))
Expand Down

0 comments on commit f3575aa

Please sign in to comment.