From a411b0602ecceee18cd09ff3e592d1c4e38d4bde Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Wed, 8 Jun 2022 22:48:26 -0700 Subject: [PATCH] Polish output_size --- src/definitions.jl | 3 +-- test/runtests.jl | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/definitions.jl b/src/definitions.jl index 3ba69cd..a119878 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -255,7 +255,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale) ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α) size(p::ScaledPlan) = size(p.p) -output_size(p::ScaledPlan) = size(p) +output_size(p::ScaledPlan) = output_size(p.p) region(p::ScaledPlan) = region(p.p) @@ -587,7 +587,6 @@ const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInvers function irfft_dim end -ProjectionStyle(p::Plan) = error("No projection style defined for plan") output_size(p::Plan) = _output_size(p, ProjectionStyle(p)) _output_size(p::Plan, ::NoProjectionStyle) = size(p) _output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), region(p)) diff --git a/test/runtests.jl b/test/runtests.jl index c635754..1c22893 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -198,6 +198,34 @@ end @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end +@testset "output size" begin + @testset "complex fft output size" begin + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) + N = ndims(x) + y = randn(size(x)) + for dims in unique((1, 1:N, N)) + P = plan_fft(x, dims) + @test AbstractFFTs.output_size(P) == size(P * x) + Pinv = plan_ifft(x) + @test AbstractFFTs.output_size(Pinv) == size(Pinv * x) + end + end + end + @testset "real fft output size" begin + for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths + N = ndims(x) + for dims in unique((1, 1:N, N)) + P = plan_rfft(x, dims) + Px_sz = size(P * x) + @test AbstractFFTs.output_size(P) == Px_sz + y = randn(Px_sz) .+ randn(Px_sz) * im + Pinv = plan_irfft(y, size(x)[first(dims)], dims) + @test AbstractFFTs.output_size(Pinv) == size(Pinv * y) + end + end + end +end + @testset "adjoint" begin @testset "complex fft adjoint" begin for x in (randn(3), randn(3, 4), randn(3, 4, 5)) @@ -217,13 +245,13 @@ end for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths N = ndims(x) for dims in unique((1, 1:N, N)) - P = plan_rfft(similar(x), dims) + P = plan_rfft(x, dims) y_real = randn(size(P * x)) y_imag = randn(size(P * x)) y = y_real .+ y_imag .* im @test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) ≈ dot(P' * y, x) @test_broken dot(y_real, real.(P \ x)) + dot(y_imag, imag.(P \ x)) ≈ dot(P' * y, x) - Pinv = plan_irfft(similar(y), size(x)[first(dims)], dims) + Pinv = plan_irfft(y, size(x)[first(dims)], dims) @test dot(x, Pinv * y) ≈ dot(y_real, real.(Pinv' * x)) + dot(y_imag, imag.(Pinv' * x)) @test_broken dot(x, Pinv \ y) ≈ dot(y_real, real.(Pinv' \ x)) + dot(y_imag, imag.(Pinv' \ x)) end