diff --git a/Project.toml b/Project.toml index 3a6a88b..6404c36 100644 --- a/Project.toml +++ b/Project.toml @@ -5,12 +5,15 @@ version = "1.4.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" +AbstractFFTsTestExt = "Test" [compat] ChainRulesCore = "1" diff --git a/docs/src/implementations.md b/docs/src/implementations.md index b7e6751..81deb76 100644 --- a/docs/src/implementations.md +++ b/docs/src/implementations.md @@ -39,3 +39,18 @@ To define a new FFT implementation in your own module, you should The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``. + +## Testing implementations + +`AbstractFFTs.jl` provides an experimental `TestUtils` module to help with testing downstream implementations, +available as a [weak extension](https://pkgdocs.julialang.org/v1.9/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) of `Test`. +The following functions test that all FFT functionality has been correctly implemented: +```@docs +AbstractFFTs.TestUtils.test_complex_ffts +AbstractFFTs.TestUtils.test_real_ffts +``` +`TestUtils` also exposes lower level functions for generically testing particular plans: +```@docs +AbstractFFTs.TestUtils.test_plan +AbstractFFTs.TestUtils.test_plan_adjoint +``` diff --git a/ext/AbstractFFTsTestExt.jl b/ext/AbstractFFTsTestExt.jl new file mode 100644 index 0000000..ccea93a --- /dev/null +++ b/ext/AbstractFFTsTestExt.jl @@ -0,0 +1,232 @@ +# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license + +module AbstractFFTsTestExt + +using AbstractFFTs +using AbstractFFTs: TestUtils +using AbstractFFTs.LinearAlgebra +using Test + +# Ground truth x_fft computed using FFTW library +const TEST_CASES = ( + (; x = collect(1:7), dims = 1, + x_fft = [28.0 + 0.0im, + -3.5 + 7.267824888003178im, + -3.5 + 2.7911568610884143im, + -3.5 + 0.7988521603655248im, + -3.5 - 0.7988521603655248im, + -3.5 - 2.7911568610884143im, + -3.5 - 7.267824888003178im]), + (; x = collect(1:8), dims = 1, + x_fft = [36.0 + 0.0im, + -4.0 + 9.65685424949238im, + -4.0 + 4.0im, + -4.0 + 1.6568542494923806im, + -4.0 + 0.0im, + -4.0 - 1.6568542494923806im, + -4.0 - 4.0im, + -4.0 - 9.65685424949238im]), + (; x = collect(reshape(1:8, 2, 4)), dims = 2, + x_fft = [16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im; + 20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]), + (; x = collect(reshape(1:9, 3, 3)), dims = 2, + x_fft = [12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; + 15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; + 18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]), + (; x = collect(reshape(1:8, 2, 2, 2)), dims = 1:2, + x_fft = cat([10.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im], + [26.0 + 0.0im -4.0 + 0.0im; -2.0 + 0.0im 0.0 + 0.0im], + dims=3)), + (; x = collect(1:7) + im * collect(8:14), dims = 1, + x_fft = [28.0 + 77.0im, + -10.76782488800318 + 3.767824888003175im, + -6.291156861088416 - 0.7088431389115883im, + -4.298852160365525 - 2.7011478396344746im, + -2.7011478396344764 - 4.298852160365524im, + -0.7088431389115866 - 6.291156861088417im, + 3.767824888003177 - 10.76782488800318im]), + (; x = collect(reshape(1:8, 2, 2, 2)) + im * reshape(9:16, 2, 2, 2), dims = 1:2, + x_fft = cat([10.0 + 42.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im], + [26.0 + 58.0im -4.0 - 4.0im; -2.0 - 2.0im 0.0 + 0.0im], + dims=3)), + ) + + +function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false) + _copy = copy_input ? copy : identity + if !inplace_plan + @test P * _copy(x) ≈ x_transformed + @test P \ (P * _copy(x)) ≈ x + _x_out = similar(P * _copy(x)) + @test mul!(_x_out, P, _copy(x)) ≈ x_transformed + @test _x_out ≈ x_transformed + else + _x = copy(x) + @test P * _copy(_x) ≈ x_transformed + @test _x ≈ x_transformed + @test P \ _copy(_x) ≈ x + @test _x ≈ x + end +end + +function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false, copy_input=false) + _copy = copy_input ? copy : identity + y = rand(eltype(P * _copy(x)), size(P * _copy(x))) + # test basic properties + @test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110) + @test (P')' === P # test adjoint of adjoint + @test size(P') == AbstractFFTs.output_size(P) # test size of adjoint + # test correctness of adjoint and its inverse via the dot test + if !real_plan + @test dot(y, P * _copy(x)) ≈ dot(P' * _copy(y), x) + @test dot(y, P \ _copy(x)) ≈ dot(P' \ _copy(y), x) + else + _component_dot(x, y) = dot(real.(x), real.(y)) + dot(imag.(x), imag.(y)) + @test _component_dot(y, P * _copy(x)) ≈ _component_dot(P' * _copy(y), x) + @test _component_dot(x, P \ _copy(y)) ≈ _component_dot(P' \ _copy(x), y) + end + @test_throws MethodError mul!(x, P', y) +end + +function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true) + @testset "correctness of fft, bfft, ifft" begin + for test_case in TEST_CASES + _x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft) + x = convert(ArrayType, _x) # dummy array that will be passed to plans + x_complexf = convert(ArrayType, complex.(float.(x))) # for testing mutating complex FFTs + x_fft = convert(ArrayType, _x_fft) + + # FFT + @test fft(x, dims) ≈ x_fft + if test_inplace + _x_complexf = copy(x_complexf) + @test fft!(_x_complexf, dims) ≈ x_fft + @test _x_complexf ≈ x_fft + end + # test OOP plans, checking plan_fft and also inv and plan_inv of plan_ifft, + # which should give functionally identical plans + for P in (plan_fft(similar(x_complexf), dims), + (_inv(plan_ifft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...) + @test eltype(P) <: Complex + @test fftdims(P) == dims + TestUtils.test_plan(P, x_complexf, x_fft) + if test_adjoint + @test fftdims(P') == fftdims(P) + TestUtils.test_plan_adjoint(P, x_complexf) + end + end + if test_inplace + # test IIP plans + for P in (plan_fft!(similar(x_complexf), dims), + (_inv(plan_ifft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...) + TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true) + end + end + + # BFFT + x_scaled = prod(size(x, d) for d in dims) .* x + @test bfft(x_fft, dims) ≈ x_scaled + if test_inplace + _x_fft = copy(x_fft) + @test bfft!(_x_fft, dims) ≈ x_scaled + @test _x_fft ≈ x_scaled + end + # test OOP plans. Just 1 plan to test, but we use a for loop for consistent style + for P in (plan_bfft(similar(x_fft), dims),) + @test eltype(P) <: Complex + @test fftdims(P) == dims + TestUtils.test_plan(P, x_fft, x_scaled) + if test_adjoint + TestUtils.test_plan_adjoint(P, x_fft) + end + end + # test IIP plans + for P in (plan_bfft!(similar(x_fft), dims),) + @test eltype(P) <: Complex + @test fftdims(P) == dims + TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true) + end + + # IFFT + @test ifft(x_fft, dims) ≈ x + if test_inplace + _x_fft = copy(x_fft) + @test ifft!(_x_fft, dims) ≈ x + @test _x_fft ≈ x + end + # test OOP plans + for P in (plan_ifft(similar(x_complexf), dims), + (_inv(plan_fft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...) + @test eltype(P) <: Complex + @test fftdims(P) == dims + TestUtils.test_plan(P, x_fft, x) + if test_adjoint + TestUtils.test_plan_adjoint(P, x_fft) + end + end + # test IIP plans + if test_inplace + for P in (plan_ifft!(similar(x_complexf), dims), + (_inv(plan_fft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...) + @test eltype(P) <: Complex + @test fftdims(P) == dims + TestUtils.test_plan(P, x_fft, x; inplace_plan=true) + end + end + end + end +end + +function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false) + @testset "correctness of rfft, brfft, irfft" begin + for test_case in TEST_CASES + _x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft) + x = convert(ArrayType, _x) # dummy array that will be passed to plans + x_real = float.(x) # for testing mutating real FFTs + x_fft = convert(ArrayType, _x_fft) + x_rfft = collect(selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1))) + + if !(eltype(x) <: Real) + continue + end + + # RFFT + @test rfft(x, dims) ≈ x_rfft + for P in (plan_rfft(similar(x_real), dims), + (_inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...) + @test eltype(P) <: Real + @test fftdims(P) == dims + TestUtils.test_plan(P, x_real, x_rfft; copy_input=copy_input) + if test_adjoint + TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input=copy_input) + end + end + + # BRFFT + x_scaled = prod(size(x, d) for d in dims) .* x + @test brfft(x_rfft, size(x, first(dims)), dims) ≈ x_scaled + for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),) + @test eltype(P) <: Complex + @test fftdims(P) == dims + TestUtils.test_plan(P, x_rfft, x_scaled; copy_input=copy_input) + if test_adjoint + TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input) + end + end + + # IRFFT + @test irfft(x_rfft, size(x, first(dims)), dims) ≈ x + for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims), + (_inv(plan_rfft(similar(x_real), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...) + @test eltype(P) <: Complex + @test fftdims(P) == dims + TestUtils.test_plan(P, x_rfft, x; copy_input=copy_input) + if test_adjoint + TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input) + end + end + end + end +end + +end diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 00f6dc2..3225916 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -6,9 +6,11 @@ export fft, ifft, bfft, fft!, ifft!, bfft!, fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq include("definitions.jl") +include("TestUtils.jl") if !isdefined(Base, :get_extension) include("../ext/AbstractFFTsChainRulesCoreExt.jl") + include("../ext/AbstractFFTsTestExt.jl") end end # module diff --git a/src/TestUtils.jl b/src/TestUtils.jl new file mode 100644 index 0000000..adfffea --- /dev/null +++ b/src/TestUtils.jl @@ -0,0 +1,73 @@ +module TestUtils + +""" + TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true) + +Run tests to verify correctness of FFT, BFFT, and IFFT functionality using a particular backend plan implementation. +The backend implementation is assumed to be loaded prior to calling this function. + +# Arguments + +- `ArrayType`: determines the `AbstractArray` implementation for + which the correctness tests are run. Arrays are constructed via + `convert(ArrayType, ...)`. +- `test_inplace=true`: whether to test in-place plans. +- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint). +""" +function test_complex_ffts end + +""" + TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false) + +Run tests to verify correctness of RFFT, BRFFT, and IRFFT functionality using a particular backend plan implementation. +The backend implementation is assumed to be loaded prior to calling this function. + +# Arguments + +- `ArrayType`: determines the `AbstractArray` implementation for + which the correctness tests are run. Arrays are constructed via + `convert(ArrayType, ...)`. +- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint). +- `copy_input=false`: whether to copy the input before applying the plan in tests, to accomodate for + [input-mutating behaviour of real FFTW plans](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101). +""" +function test_real_ffts end + + # Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101) +""" + TestUtils.test_plan(P::Plan, x::AbstractArray, x_transformed::AbstractArray; + inplace_plan=false, copy_input=false) + +Test basic properties of a plan `P` given an input array `x` and expected output `x_transformed`. + +Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101), +we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan. +""" +function test_plan end + +""" + TestUtils.test_plan_adjoint(P::Plan, x::AbstractArray; real_plan=false, copy_input=false) + +Test basic properties of the [adjoint](api.md#Base.adjoint) `P'` of a particular plan given an input array `x`, +including its accuracy via the dot test. + +Real-to-complex and complex-to-real plans require a slightly modified dot test, in which case `real_plan=true` should be provided. +The plan is assumed out-of-place, as adjoints are not yet supported for in-place plans. +Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101), +we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan. +""" +function test_plan_adjoint end + +if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_error_hint) + function __init__() + # Better error message if users forget to load Test + Base.Experimental.register_error_hint(MethodError) do io, exc, _, _ + if any(f -> (f === exc.f), (test_real_ffts, test_complex_ffts, test_plan, test_plan_adjoint)) && + (Base.get_extension(AbstractFFTs, :AbstractFFTsTestExt) === nothing) + print(io, "\nDid you forget to load Test?") + end + end + end +end + +end diff --git a/src/definitions.jl b/src/definitions.jl index 2ac28b7..5dc703f 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -665,7 +665,7 @@ end Return a plan that performs the adjoint operation of the original plan. -!!! note +!!! warning Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`, coverage of `Base.adjoint` in downstream implementations may be limited. """ @@ -676,6 +676,7 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale) size(p::AdjointPlan) = output_size(p.p) output_size(p::AdjointPlan) = size(p.p) +fftdims(p::AdjointPlan) = fftdims(p.p) Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x) diff --git a/test/testplans.jl b/test/TestPlans.jl similarity index 90% rename from test/testplans.jl rename to test/TestPlans.jl index c6a76f9..1961113 100644 --- a/test/testplans.jl +++ b/test/TestPlans.jl @@ -1,3 +1,9 @@ +module TestPlans + +using LinearAlgebra +using AbstractFFTs +using AbstractFFTs: Plan + mutable struct TestPlan{T,N,G} <: Plan{T} region::G sz::NTuple{N,Int} @@ -76,13 +82,13 @@ function dft!( return y end -function mul!( +function LinearAlgebra.mul!( y::AbstractArray{<:Complex,N}, p::TestPlan, x::AbstractArray{<:Union{Complex,Real},N} ) where {N} size(y) == size(p) == size(x) || throw(DimensionMismatch()) dft!(y, x, p.region, -1) end -function mul!( +function LinearAlgebra.mul!( y::AbstractArray{<:Complex,N}, p::InverseTestPlan, x::AbstractArray{<:Union{Complex,Real},N} ) where {N} size(y) == size(p) == size(x) || throw(DimensionMismatch()) @@ -194,22 +200,17 @@ end to_real!(x::AbstractArray) = map!(real, x, x) -function Base.:*(p::TestRPlan, x::AbstractArray) +function LinearAlgebra.mul!(y::AbstractArray{<:Complex, N}, p::TestRPlan, x::AbstractArray{<:Real, N}) where {N} size(p) == size(x) || error("array and plan are not consistent") - # create output array - firstdim = first(p.region)::Int - d = size(x, firstdim) - firstdim_size = d ÷ 2 + 1 - T = complex(float(eltype(x))) - sz = ntuple(i -> i == firstdim ? firstdim_size : size(x, i), Val(ndims(x))) - y = similar(x, T, sz) - # compute DFT dft!(y, x, p.region, -1) # we clean the output a bit to make sure that we return real values # whenever the output is mathematically guaranteed to be a real number + firstdim = first(p.region)::Int + d = size(x, firstdim) + firstdim_size = d ÷ 2 + 1 to_real!(selectdim(y, firstdim, 1)) if iseven(d) to_real!(selectdim(y, firstdim, firstdim_size)) @@ -218,29 +219,50 @@ function Base.:*(p::TestRPlan, x::AbstractArray) return y end -function Base.:*(p::InverseTestRPlan, x::AbstractArray) +function Base.:*(p::TestRPlan, x::AbstractArray) + # create output array + firstdim = first(p.region)::Int + d = size(x, firstdim) + firstdim_size = d ÷ 2 + 1 + T = complex(float(eltype(x))) + sz = ntuple(i -> i == firstdim ? firstdim_size : size(x, i), Val(ndims(x))) + y = similar(x, T, sz) + + # run in-place mul! + mul!(y, p, x) + + return y +end + +function LinearAlgebra.mul!(y::AbstractArray{<:Real, N}, p::InverseTestRPlan, x::AbstractArray{<:Complex, N}) where {N} size(p) == size(x) || error("array and plan are not consistent") + # compute DFT + real_invdft!(y, x, p.region) +end + +function Base.:*(p::InverseTestRPlan, x::AbstractArray) # create output array firstdim = first(p.region)::Int d = p.d sz = ntuple(i -> i == firstdim ? d : size(x, i), Val(ndims(x))) y = similar(x, real(float(eltype(x))), sz) - # compute DFT - real_invdft!(y, x, p.region) + # run in-place mul! + mul!(y, p, x) return y end # In-place plans -# (simple wrapper of out-of-place plans that does not support inverses) +# (simple wrapper of OOP plans) struct InplaceTestPlan{T,P<:Plan{T}} <: Plan{T} plan::P end Base.size(p::InplaceTestPlan) = size(p.plan) Base.ndims(p::InplaceTestPlan) = ndims(p.plan) +AbstractFFTs.fftdims(p::InplaceTestPlan) = fftdims(p.plan) AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan) function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...) @@ -250,7 +272,10 @@ function AbstractFFTs.plan_bfft!(x::AbstractArray, region; kwargs...) return InplaceTestPlan(plan_bfft(x, region; kwargs...)) end -function LinearAlgebra.mul!(y::AbstractArray, p::InplaceTestPlan, x::AbstractArray) - return mul!(y, p.plan, x) -end Base.:*(p::InplaceTestPlan, x::AbstractArray) = copyto!(x, p.plan * x) + +AbstractFFTs.plan_inv(p::InplaceTestPlan) = InplaceTestPlan(AbstractFFTs.plan_inv(p.plan)) +# Don't cache inverse of inplace wrapper plan (only inverse of inner plan) +Base.inv(p::InplaceTestPlan) = InplaceTestPlan(inv(p.plan)) + +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index c5f0659..fe74897 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,20 +1,20 @@ -# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license - -using AbstractFFTs -using AbstractFFTs: Plan, ScaledPlan -using ChainRulesTestUtils -using FiniteDifferences -import ChainRulesCore - -using LinearAlgebra using Random using Test - +using AbstractFFTs +using ChainRulesTestUtils import Unitful +using LinearAlgebra +using ChainRulesCore +using FiniteDifferences Random.seed!(1234) -include("testplans.jl") +# Load example plan implementation. +include("TestPlans.jl") + +# Run interface tests for TestPlans +AbstractFFTs.TestUtils.test_complex_ffts(Array) +AbstractFFTs.TestUtils.test_real_ffts(Array) @testset "rfft sizes" begin A = rand(11, 10) @@ -26,124 +26,6 @@ include("testplans.jl") @test_throws AssertionError AbstractFFTs.brfft_output_size(A1, 10, 2) end -@testset "Custom Plan" begin - # DFT along last dimension, results computed using FFTW - for (x, fftw_fft) in ( - (collect(1:7), - [28.0 + 0.0im, - -3.5 + 7.267824888003178im, - -3.5 + 2.7911568610884143im, - -3.5 + 0.7988521603655248im, - -3.5 - 0.7988521603655248im, - -3.5 - 2.7911568610884143im, - -3.5 - 7.267824888003178im]), - (collect(1:8), - [36.0 + 0.0im, - -4.0 + 9.65685424949238im, - -4.0 + 4.0im, - -4.0 + 1.6568542494923806im, - -4.0 + 0.0im, - -4.0 - 1.6568542494923806im, - -4.0 - 4.0im, - -4.0 - 9.65685424949238im]), - (collect(reshape(1:8, 2, 4)), - [16.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im; - 20.0+0.0im -4.0+4.0im -4.0+0.0im -4.0-4.0im]), - (collect(reshape(1:9, 3, 3)), - [12.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; - 15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; - 18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]), - ) - # FFT - dims = ndims(x) - y = AbstractFFTs.fft(x, dims) - @test y ≈ fftw_fft - # test plan_fft and also inv and plan_inv of plan_ifft, which should all give - # functionally identical plans - for P in [plan_fft(x, dims), inv(plan_ifft(x, dims)), - AbstractFFTs.plan_inv(plan_ifft(x, dims))] - @test eltype(P) === ComplexF64 - @test P * x ≈ fftw_fft - @test P \ (P * x) ≈ x - @test fftdims(P) == dims - end - - # in-place plan - P = plan_fft!(x, dims) - @test eltype(P) === ComplexF64 - xc64 = ComplexF64.(x) - @test P * xc64 ≈ fftw_fft - @test xc64 ≈ fftw_fft - - fftw_bfft = complex.(size(x, dims) .* x) - @test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft - P = plan_bfft(x, dims) - @test P * y ≈ fftw_bfft - @test P \ (P * y) ≈ y - @test fftdims(P) == dims - - # in-place plan - P = plan_bfft!(x, dims) - @test eltype(P) === ComplexF64 - yc64 = ComplexF64.(y) - @test P * yc64 ≈ fftw_bfft - @test yc64 ≈ fftw_bfft - - fftw_ifft = complex.(x) - @test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft - # test plan_ifft and also inv and plan_inv of plan_fft, which should all give - # functionally identical plans - for P in [plan_ifft(x, dims), inv(plan_fft(x, dims)), - AbstractFFTs.plan_inv(plan_fft(x, dims))] - @test P * y ≈ fftw_ifft - @test P \ (P * y) ≈ y - @test fftdims(P) == dims - end - - # in-place plan - P = plan_ifft!(x, dims) - @test eltype(P) === ComplexF64 - yc64 = ComplexF64.(y) - @test P * yc64 ≈ fftw_ifft - @test yc64 ≈ fftw_ifft - - # real FFT - fftw_rfft = fftw_fft[ - (Colon() for _ in 1:(ndims(fftw_fft) - 1))..., - 1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1) - ] - ry = AbstractFFTs.rfft(x, dims) - @test ry ≈ fftw_rfft - # test plan_rfft and also inv and plan_inv of plan_irfft, which should all give - # functionally identical plans - for P in [plan_rfft(x, dims), inv(plan_irfft(ry, size(x, dims), dims)), - AbstractFFTs.plan_inv(plan_irfft(ry, size(x, dims), dims))] - @test eltype(P) <: Real - @test P * x ≈ fftw_rfft - @test P \ (P * x) ≈ x - @test fftdims(P) == dims - end - - fftw_brfft = complex.(size(x, dims) .* x) - @test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft - P = plan_brfft(ry, size(x, dims), dims) - @test P * ry ≈ fftw_brfft - @test P \ (P * ry) ≈ ry - @test fftdims(P) == dims - - fftw_irfft = complex.(x) - @test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft - # test plan_rfft and also inv and plan_inv of plan_irfft, which should all give - # functionally identical plans - for P in [plan_irfft(ry, size(x, dims), dims), inv(plan_rfft(x, dims)), - AbstractFFTs.plan_inv(plan_rfft(x, dims))] - @test P * ry ≈ fftw_irfft - @test P \ (P * ry) ≈ ry - @test fftdims(P) == dims - end - end -end - @testset "Shift functions" begin @test @inferred(AbstractFFTs.fftshift([1 2 3])) == [3 1 2] @test @inferred(AbstractFFTs.fftshift([1, 2, 3])) == [3, 1, 2] @@ -232,7 +114,7 @@ end # normalization should be inferable even if region is only inferred as ::Any, # need to wrap in another function to test this (note that p.region::Any for # p::TestPlan) - f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p)) + f9(p::AbstractFFTs.Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p)) @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end @@ -271,55 +153,6 @@ end end end -@testset "adjoint" begin - @testset "complex fft adjoint" begin - for x_shape in ((3,), (3, 4), (3, 4, 5)) - N = length(x_shape) - real_x = randn(x_shape) - complex_x = randn(ComplexF64, x_shape) - y = randn(ComplexF64, x_shape) - for x in (real_x, complex_x) - for dims in unique((1, 1:N, N)) - P = plan_fft(x, dims) - @test (P')' === P # test adjoint of adjoint - @test size(P') == AbstractFFTs.output_size(P) # test size of adjoint - @test dot(y, P * x) ≈ dot(P' * y, x) # test validity of adjoint - @test dot(y, P \ x) ≈ dot(P' \ y, x) # test inv of adjoint - @test dot(y, P \ x) ≈ dot(AbstractFFTs.plan_inv(P') * y, x) # test plan_inv of adjoint - Pinv = plan_ifft(y) - @test (Pinv')' * y == Pinv * y - @test size(Pinv') == AbstractFFTs.output_size(Pinv) - @test dot(x, Pinv * y) ≈ dot(Pinv' * x, y) - @test dot(x, Pinv \ y) ≈ dot(Pinv' \ x, y) - @test dot(x, Pinv \ y) ≈ dot(AbstractFFTs.plan_inv(Pinv') * x, y) - @test_throws MethodError mul!(x, P', y) - end - end - end - end - @testset "real fft adjoint" begin - for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths - N = ndims(x) - for dims in unique((1, 1:N, N)) - P = plan_rfft(x, dims) - y = randn(ComplexF64, size(P * x)) - @test (P')' * x == P * x - @test size(P') == AbstractFFTs.output_size(P) - @test dot(real.(y), real.(P * x)) + dot(imag.(y), imag.(P * x)) ≈ dot(P' * y, x) - @test dot(real.(y), real.(P' \ x)) + dot(imag.(y), imag.(P' \ x)) ≈ dot(P \ y, x) - @test dot(real.(y), real.(AbstractFFTs.plan_inv(P') * x)) + - dot(imag.(y), imag.(AbstractFFTs.plan_inv(P') * x)) ≈ dot(P \ y, x) - Pinv = plan_irfft(y, size(x)[first(dims)], dims) - @test (Pinv')' * y == Pinv * y - @test size(Pinv') == AbstractFFTs.output_size(Pinv) - @test dot(x, Pinv * y) ≈ dot(real.(y), real.(Pinv' * x)) + dot(imag.(y), imag.(Pinv' * x)) - @test dot(x, Pinv' \ y) ≈ dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x)) - @test dot(x, AbstractFFTs.plan_inv(Pinv') * y) ≈ dot(real.(y), real.(Pinv \ x)) + dot(imag.(y), imag.(Pinv \ x)) - end - end - end -end - # Test that dims defaults to 1:ndims for fft-like functions @testset "Default dims" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5)) @@ -369,7 +202,7 @@ end @testset "fft" begin # Overloads to allow ChainRulesTestUtils to test rules w.r.t. ScaledPlan's. See https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/256 - InnerPlan = Union{TestPlan, InverseTestPlan, TestRPlan, InverseTestRPlan} + InnerPlan = Union{TestPlans.TestPlan, TestPlans.InverseTestPlan, TestPlans.TestRPlan, TestPlans.InverseTestRPlan} function FiniteDifferences.to_vec(x::InnerPlan) function FFTPlan_from_vec(x_vec::Vector) return x @@ -427,3 +260,4 @@ end end end end +