Skip to content

Commit

Permalink
Polish output_size
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav Arya committed Jun 9, 2022
1 parent c7efe8d commit a411b06
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
3 changes: 1 addition & 2 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)

size(p::ScaledPlan) = size(p.p)
output_size(p::ScaledPlan) = size(p)
output_size(p::ScaledPlan) = output_size(p.p)

region(p::ScaledPlan) = region(p.p)

Expand Down Expand Up @@ -587,7 +587,6 @@ const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInvers

function irfft_dim end

ProjectionStyle(p::Plan) = error("No projection style defined for plan")
output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), region(p))
Expand Down
32 changes: 30 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,34 @@ end
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
end

@testset "output size" begin
@testset "complex fft output size" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
N = ndims(x)
y = randn(size(x))
for dims in unique((1, 1:N, N))
P = plan_fft(x, dims)
@test AbstractFFTs.output_size(P) == size(P * x)
Pinv = plan_ifft(x)
@test AbstractFFTs.output_size(Pinv) == size(Pinv * x)
end
end
end
@testset "real fft output size" begin
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths
N = ndims(x)
for dims in unique((1, 1:N, N))
P = plan_rfft(x, dims)
Px_sz = size(P * x)
@test AbstractFFTs.output_size(P) == Px_sz
y = randn(Px_sz) .+ randn(Px_sz) * im
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
@test AbstractFFTs.output_size(Pinv) == size(Pinv * y)
end
end
end
end

@testset "adjoint" begin
@testset "complex fft adjoint" begin
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
Expand All @@ -217,13 +245,13 @@ end
for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths
N = ndims(x)
for dims in unique((1, 1:N, N))
P = plan_rfft(similar(x), dims)
P = plan_rfft(x, dims)
y_real = randn(size(P * x))
y_imag = randn(size(P * x))
y = y_real .+ y_imag .* im
@test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) dot(P' * y, x)
@test_broken dot(y_real, real.(P \ x)) + dot(y_imag, imag.(P \ x)) dot(P' * y, x)
Pinv = plan_irfft(similar(y), size(x)[first(dims)], dims)
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
@test dot(x, Pinv * y) dot(y_real, real.(Pinv' * x)) + dot(y_imag, imag.(Pinv' * x))
@test_broken dot(x, Pinv \ y) dot(y_real, real.(Pinv' \ x)) + dot(y_imag, imag.(Pinv' \ x))
end
Expand Down

0 comments on commit a411b06

Please sign in to comment.