Skip to content

Commit

Permalink
Put tests of FFT backends into TestUtils submodule (#78)
Browse files Browse the repository at this point in the history
* Add TestUtils submodule/extension

* Fix typo

* Support Julia 1.0

* Add missing test deps

* Add adjoint testing to test utilities

* Remove mul! method from inplace test plan (consistent with fftw)

* Fix typo

* Document test utilities

* Apply code review suggestions and refactor TestUtils

* Support Julia 1.0

* Reorder kwargs in doc string

* Also explicitly test AbstractFFTs.plan_inv

* Lift isdefined checks out of __init__

* Update src/definitions.jl

Co-authored-by: David Widmann <[email protected]>

* Note TestUtils is a weak extension

* Update function names in error handler

* Add missing test_adjoint's for BRFFT, IRFFT

* Collect x_rfft so as to not hit #112

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
gaurav-arya and devmotion authored Jul 29, 2023
1 parent 5c23f4b commit 313a180
Show file tree
Hide file tree
Showing 8 changed files with 384 additions and 199 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ version = "1.4.0"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[extensions]
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
AbstractFFTsTestExt = "Test"

[compat]
ChainRulesCore = "1"
Expand Down
15 changes: 15 additions & 0 deletions docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,18 @@ To define a new FFT implementation in your own module, you should

The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.

## Testing implementations

`AbstractFFTs.jl` provides an experimental `TestUtils` module to help with testing downstream implementations,
available as a [weak extension](https://pkgdocs.julialang.org/v1.9/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) of `Test`.
The following functions test that all FFT functionality has been correctly implemented:
```@docs
AbstractFFTs.TestUtils.test_complex_ffts
AbstractFFTs.TestUtils.test_real_ffts
```
`TestUtils` also exposes lower level functions for generically testing particular plans:
```@docs
AbstractFFTs.TestUtils.test_plan
AbstractFFTs.TestUtils.test_plan_adjoint
```
232 changes: 232 additions & 0 deletions ext/AbstractFFTsTestExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license

module AbstractFFTsTestExt

using AbstractFFTs
using AbstractFFTs: TestUtils
using AbstractFFTs.LinearAlgebra
using Test

# Ground truth x_fft computed using FFTW library
const TEST_CASES = (
(; x = collect(1:7), dims = 1,
x_fft = [28.0 + 0.0im,
-3.5 + 7.267824888003178im,
-3.5 + 2.7911568610884143im,
-3.5 + 0.7988521603655248im,
-3.5 - 0.7988521603655248im,
-3.5 - 2.7911568610884143im,
-3.5 - 7.267824888003178im]),
(; x = collect(1:8), dims = 1,
x_fft = [36.0 + 0.0im,
-4.0 + 9.65685424949238im,
-4.0 + 4.0im,
-4.0 + 1.6568542494923806im,
-4.0 + 0.0im,
-4.0 - 1.6568542494923806im,
-4.0 - 4.0im,
-4.0 - 9.65685424949238im]),
(; x = collect(reshape(1:8, 2, 4)), dims = 2,
x_fft = [16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im;
20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]),
(; x = collect(reshape(1:9, 3, 3)), dims = 2,
x_fft = [12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]),
(; x = collect(reshape(1:8, 2, 2, 2)), dims = 1:2,
x_fft = cat([10.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im],
[26.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im],
dims=3)),
(; x = collect(1:7) + im * collect(8:14), dims = 1,
x_fft = [28.0 + 77.0im,
-10.76782488800318 + 3.767824888003175im,
-6.291156861088416 - 0.7088431389115883im,
-4.298852160365525 - 2.7011478396344746im,
-2.7011478396344764 - 4.298852160365524im,
-0.7088431389115866 - 6.291156861088417im,
3.767824888003177 - 10.76782488800318im]),
(; x = collect(reshape(1:8, 2, 2, 2)) + im * reshape(9:16, 2, 2, 2), dims = 1:2,
x_fft = cat([10.0 + 42.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im],
[26.0 + 58.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im],
dims=3)),
)


function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false)
_copy = copy_input ? copy : identity
if !inplace_plan
@test P * _copy(x) x_transformed
@test P \ (P * _copy(x)) x
_x_out = similar(P * _copy(x))
@test mul!(_x_out, P, _copy(x)) x_transformed
@test _x_out x_transformed
else
_x = copy(x)
@test P * _copy(_x) x_transformed
@test _x x_transformed
@test P \ _copy(_x) x
@test _x x
end
end

function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false, copy_input=false)
_copy = copy_input ? copy : identity
y = rand(eltype(P * _copy(x)), size(P * _copy(x)))
# test basic properties
@test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110)
@test (P')' === P # test adjoint of adjoint
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
# test correctness of adjoint and its inverse via the dot test
if !real_plan
@test dot(y, P * _copy(x)) dot(P' * _copy(y), x)
@test dot(y, P \ _copy(x)) dot(P' \ _copy(y), x)
else
_component_dot(x, y) = dot(real.(x), real.(y)) + dot(imag.(x), imag.(y))
@test _component_dot(y, P * _copy(x)) _component_dot(P' * _copy(y), x)
@test _component_dot(x, P \ _copy(y)) _component_dot(P' \ _copy(x), y)
end
@test_throws MethodError mul!(x, P', y)
end

function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)
@testset "correctness of fft, bfft, ifft" begin
for test_case in TEST_CASES
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
x = convert(ArrayType, _x) # dummy array that will be passed to plans
x_complexf = convert(ArrayType, complex.(float.(x))) # for testing mutating complex FFTs
x_fft = convert(ArrayType, _x_fft)

# FFT
@test fft(x, dims) x_fft
if test_inplace
_x_complexf = copy(x_complexf)
@test fft!(_x_complexf, dims) x_fft
@test _x_complexf x_fft
end
# test OOP plans, checking plan_fft and also inv and plan_inv of plan_ifft,
# which should give functionally identical plans
for P in (plan_fft(similar(x_complexf), dims),
(_inv(plan_ifft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_complexf, x_fft)
if test_adjoint
@test fftdims(P') == fftdims(P)
TestUtils.test_plan_adjoint(P, x_complexf)
end
end
if test_inplace
# test IIP plans
for P in (plan_fft!(similar(x_complexf), dims),
(_inv(plan_ifft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true)
end
end

# BFFT
x_scaled = prod(size(x, d) for d in dims) .* x
@test bfft(x_fft, dims) x_scaled
if test_inplace
_x_fft = copy(x_fft)
@test bfft!(_x_fft, dims) x_scaled
@test _x_fft x_scaled
end
# test OOP plans. Just 1 plan to test, but we use a for loop for consistent style
for P in (plan_bfft(similar(x_fft), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x_scaled)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_fft)
end
end
# test IIP plans
for P in (plan_bfft!(similar(x_fft), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true)
end

# IFFT
@test ifft(x_fft, dims) x
if test_inplace
_x_fft = copy(x_fft)
@test ifft!(_x_fft, dims) x
@test _x_fft x
end
# test OOP plans
for P in (plan_ifft(similar(x_complexf), dims),
(_inv(plan_fft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_fft)
end
end
# test IIP plans
if test_inplace
for P in (plan_ifft!(similar(x_complexf), dims),
(_inv(plan_fft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_fft, x; inplace_plan=true)
end
end
end
end
end

function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
@testset "correctness of rfft, brfft, irfft" begin
for test_case in TEST_CASES
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
x = convert(ArrayType, _x) # dummy array that will be passed to plans
x_real = float.(x) # for testing mutating real FFTs
x_fft = convert(ArrayType, _x_fft)
x_rfft = collect(selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1)))

if !(eltype(x) <: Real)
continue
end

# RFFT
@test rfft(x, dims) x_rfft
for P in (plan_rfft(similar(x_real), dims),
(_inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Real
@test fftdims(P) == dims
TestUtils.test_plan(P, x_real, x_rfft; copy_input=copy_input)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input=copy_input)
end
end

# BRFFT
x_scaled = prod(size(x, d) for d in dims) .* x
@test brfft(x_rfft, size(x, first(dims)), dims) x_scaled
for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_rfft, x_scaled; copy_input=copy_input)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input)
end
end

# IRFFT
@test irfft(x_rfft, size(x, first(dims)), dims) x
for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims),
(_inv(plan_rfft(similar(x_real), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
@test eltype(P) <: Complex
@test fftdims(P) == dims
TestUtils.test_plan(P, x_rfft, x; copy_input=copy_input)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input)
end
end
end
end
end

end
2 changes: 2 additions & 0 deletions src/AbstractFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ export fft, ifft, bfft, fft!, ifft!, bfft!,
fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq

include("definitions.jl")
include("TestUtils.jl")

if !isdefined(Base, :get_extension)
include("../ext/AbstractFFTsChainRulesCoreExt.jl")
include("../ext/AbstractFFTsTestExt.jl")
end

end # module
73 changes: 73 additions & 0 deletions src/TestUtils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
module TestUtils

"""
TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)
Run tests to verify correctness of FFT, BFFT, and IFFT functionality using a particular backend plan implementation.
The backend implementation is assumed to be loaded prior to calling this function.
# Arguments
- `ArrayType`: determines the `AbstractArray` implementation for
which the correctness tests are run. Arrays are constructed via
`convert(ArrayType, ...)`.
- `test_inplace=true`: whether to test in-place plans.
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
"""
function test_complex_ffts end

"""
TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
Run tests to verify correctness of RFFT, BRFFT, and IRFFT functionality using a particular backend plan implementation.
The backend implementation is assumed to be loaded prior to calling this function.
# Arguments
- `ArrayType`: determines the `AbstractArray` implementation for
which the correctness tests are run. Arrays are constructed via
`convert(ArrayType, ...)`.
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
- `copy_input=false`: whether to copy the input before applying the plan in tests, to accomodate for
[input-mutating behaviour of real FFTW plans](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101).
"""
function test_real_ffts end

# Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101)
"""
TestUtils.test_plan(P::Plan, x::AbstractArray, x_transformed::AbstractArray;
inplace_plan=false, copy_input=false)
Test basic properties of a plan `P` given an input array `x` and expected output `x_transformed`.
Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
"""
function test_plan end

"""
TestUtils.test_plan_adjoint(P::Plan, x::AbstractArray; real_plan=false, copy_input=false)
Test basic properties of the [adjoint](api.md#Base.adjoint) `P'` of a particular plan given an input array `x`,
including its accuracy via the dot test.
Real-to-complex and complex-to-real plans require a slightly modified dot test, in which case `real_plan=true` should be provided.
The plan is assumed out-of-place, as adjoints are not yet supported for in-place plans.
Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
"""
function test_plan_adjoint end

if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_error_hint)
function __init__()
# Better error message if users forget to load Test
Base.Experimental.register_error_hint(MethodError) do io, exc, _, _
if any(f -> (f === exc.f), (test_real_ffts, test_complex_ffts, test_plan, test_plan_adjoint)) &&
(Base.get_extension(AbstractFFTs, :AbstractFFTsTestExt) === nothing)
print(io, "\nDid you forget to load Test?")
end
end
end
end

end
3 changes: 2 additions & 1 deletion src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ end
Return a plan that performs the adjoint operation of the original plan.
!!! note
!!! warning
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
coverage of `Base.adjoint` in downstream implementations may be limited.
"""
Expand All @@ -676,6 +676,7 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)

size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)
fftdims(p::AdjointPlan) = fftdims(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x)

Expand Down
Loading

0 comments on commit 313a180

Please sign in to comment.