Skip to content

Commit

Permalink
Merge branch 'master' into backend-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Jul 27, 2023
2 parents e14c045 + 5c23f4b commit 6fec1b7
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 42 deletions.
19 changes: 18 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Public Interface

## FFT and FFT planning functions

```@docs
AbstractFFTs.fft
AbstractFFTs.fft!
Expand All @@ -20,11 +22,26 @@ AbstractFFTs.plan_rfft
AbstractFFTs.plan_brfft
AbstractFFTs.plan_irfft
AbstractFFTs.fftdims
Base.adjoint
AbstractFFTs.fftshift
AbstractFFTs.fftshift!
AbstractFFTs.ifftshift
AbstractFFTs.ifftshift!
AbstractFFTs.fftfreq
AbstractFFTs.rfftfreq
Base.size
```

## Adjoint functionality

The following API is supported by plans that support adjoint functionality.
It is also relevant to implementers of FFT plans that wish to support adjoints.
```@docs
Base.adjoint
AbstractFFTs.AdjointStyle
AbstractFFTs.output_size
AbstractFFTs.adjoint_mul
AbstractFFTs.FFTAdjointStyle
AbstractFFTs.RFFTAdjointStyle
AbstractFFTs.IRFFTAdjointStyle
AbstractFFTs.UnitaryAdjointStyle
```
10 changes: 5 additions & 5 deletions docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ To define a new FFT implementation in your own module, you should
inverse plan.

* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)`
(which defaults to `p.region`), and the input size `size(x)` should be accessible via `size(p::MyPlan)`.

* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` that computes the transform `p` of `x` and stores the result in `y`.

Expand All @@ -32,10 +33,9 @@ To define a new FFT implementation in your own module, you should

* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.

* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
* `AbstractFFTs.NoProjectionStyle()`,
* `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
* To support adjoints in a new plan, define the trait [`AbstractFFTs.AdjointStyle`](@ref).
`AbstractFFTs` implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref).
To define a new adjoint style, define the methods [`AbstractFFTs.adjoint_mul`](@ref) and [`AbstractFFTs.output_size`](@ref).

The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.
Expand Down
131 changes: 100 additions & 31 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ abstract type Plan{T} end

eltype(::Type{<:Plan{T}}) where {T} = T

# size(p) should return the size of the input array for p
size(p::Plan, d) = size(p)[d]
output_size(p::Plan, d) = output_size(p)[d]
"""
size(p::Plan, [dim])
Return the size of the input of a plan `p`, optionally at a specified dimenion `dim`.
"""
size(p::Plan, dim) = size(p)[dim]

Check warning on line 18 in src/definitions.jl

View check run for this annotation

Codecov / codecov/patch

src/definitions.jl#L18

Added line #L18 was not covered by tests
ndims(p::Plan) = length(size(p))
length(p::Plan) = prod(size(p))::Int

Expand Down Expand Up @@ -583,17 +586,73 @@ plan_brfft

##############################################################################

struct NoProjectionStyle end
struct RealProjectionStyle end
struct RealInverseProjectionStyle
"""
AbstractFFTs.AdjointStyle(::Plan)
Return the adjoint style of a plan, enabling automatic computation of adjoint plans via
[`Base.adjoint`](@ref). Instructions for supporting adjoint styles are provided in the
[implementation instructions](implementations.md#Defining-a-new-implementation).
"""
abstract type AdjointStyle end

"""
FFTAdjointStyle()
Adjoint style for complex to complex discrete Fourier transforms that normalize
the output analogously to [`fft`](@ref).
Since the Fourier transform is unitary up to a scaling, the adjoint simply applies
the transform's inverse with an appropriate scaling.
"""
struct FFTAdjointStyle <: AdjointStyle end

"""
RFFTAdjointStyle()
Adjoint style for real to complex discrete Fourier transforms that halve one of
the output's dimensions and normalize the output analogously to [`rfft`](@ref).
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with appropriate scaling and additional logic to handle the fact that the
output is projected to exploit its conjugate symmetry (see [`rfft`](@ref)).
"""
struct RFFTAdjointStyle <: AdjointStyle end

"""
IRFFTAdjointStyle(d::Dim)
Adjoint style for complex to real discrete Fourier transforms that expect an input
with a halved dimension and normalize the output analogously to [`irfft`](@ref),
where `d` is the original length of the dimension.
Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with appropriate scaling and additional logic to handle the fact that the
input is projected to exploit its conjugate symmetry (see [`irfft`](@ref)).
"""
struct IRFFTAdjointStyle <: AdjointStyle
dim::Int
end
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}

output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
"""
UnitaryAdjointStyle()
Adjoint style for unitary transforms, whose adjoint equals their inverse.
"""
struct UnitaryAdjointStyle <: AdjointStyle end

"""
output_size(p::Plan, [dim])
Return the size of the output of a plan `p`, optionally at a specified dimension `dim`.
Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define `output_size(::Plan, ::AS)`.
"""
output_size(p::Plan) = output_size(p, AdjointStyle(p))
output_size(p::Plan, dim) = output_size(p)[dim]
output_size(p::Plan, ::FFTAdjointStyle) = size(p)

Check warning on line 652 in src/definitions.jl

View check run for this annotation

Codecov / codecov/patch

src/definitions.jl#L652

Added line #L652 was not covered by tests
output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p))
output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
output_size(p::Plan, ::UnitaryAdjointStyle) = size(p)

Check warning on line 655 in src/definitions.jl

View check run for this annotation

Codecov / codecov/patch

src/definitions.jl#L655

Added line #L655 was not covered by tests

struct AdjointPlan{T,P<:Plan} <: Plan{T}
p::P
Expand All @@ -604,9 +663,7 @@ end
(p::Plan)'
adjoint(p::Plan)
Form the adjoint operator of an FFT plan. Returns a plan that performs the adjoint operation of
the original plan. Note that this differs from the corresponding backwards plan in the case of real
FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref).
Return a plan that performs the adjoint operation of the original plan.
!!! warning
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
Expand All @@ -621,40 +678,52 @@ size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)
fftdims(p::AdjointPlan) = fftdims(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x)

"""
adjoint_mul(p::Plan, x::AbstractArray)
Multiply an array `x` by the adjoint of a plan `p`. This is equivalent to `p' * x`.
Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define
`adjoint_mul(::Plan, ::AbstractArray, ::AS)`.
"""
adjoint_mul(p::Plan, x::AbstractArray) = adjoint_mul(p, x, AdjointStyle(p))

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
return (p.p \ x) / N
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(T, size(p), dims)
return (p \ x) / N
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
dims = fftdims(p)
N = normalization(T, size(p), dims)
halfdim = first(dims)
d = size(p.p, halfdim)
n = output_size(p.p, halfdim)
d = size(p, halfdim)
n = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return p.p \ (x ./ convert(typeof(x), scale))
return p \ (x ./ convert(typeof(x), scale))
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(real(T), output_size(p.p), dims)
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(real(T), output_size(p), dims)
halfdim = first(dims)
n = size(p.p, halfdim)
d = output_size(p.p, halfdim)
n = size(p, halfdim)
d = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
return (convert(typeof(x), scale) ./ N) .* (p \ x)
end

adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x

Check warning on line 725 in src/definitions.jl

View check run for this annotation

Codecov / codecov/patch

src/definitions.jl#L725

Added line #L725 was not covered by tests

# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
inv(p::AdjointPlan) = adjoint(inv(p.p))
10 changes: 5 additions & 5 deletions test/TestPlans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
Base.size(p::InverseTestPlan) = p.sz
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N

AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle()
AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle()

function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
return TestPlan{T}(region, size(x))
Expand Down Expand Up @@ -116,8 +116,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}}
end
end

AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle()
AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d)
AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle()
AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d)

function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
return TestRPlan{T}(region, size(x))
Expand Down Expand Up @@ -263,7 +263,7 @@ end
Base.size(p::InplaceTestPlan) = size(p.plan)
Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
AbstractFFTs.fftdims(p::InplaceTestPlan) = fftdims(p.plan)
AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan)
AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan)

function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
return InplaceTestPlan(plan_fft(x, region; kwargs...))
Expand Down

0 comments on commit 6fec1b7

Please sign in to comment.