Skip to content

Commit

Permalink
Test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaurav Arya committed Jun 9, 2022
1 parent a411b06 commit 3b2a754
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,11 @@ end
y = randn(size(x))
for dims in unique((1, 1:N, N))
P = plan_fft(x, dims)
@test AbstractFFTs.output_size(P) == size(P * x)
@test AbstractFFTs.output_size(P) == size(x)
@test AbstractFFTs.output_size(P') == size(x)
Pinv = plan_ifft(x)
@test AbstractFFTs.output_size(Pinv) == size(Pinv * x)
@test AbstractFFTs.output_size(Pinv) == size(x)
@test AbstractFFTs.output_size(Pinv') == size(x)
end
end
end
Expand All @@ -218,9 +220,11 @@ end
P = plan_rfft(x, dims)
Px_sz = size(P * x)
@test AbstractFFTs.output_size(P) == Px_sz
@test AbstractFFTs.output_size(P') == size(x)
y = randn(Px_sz) .+ randn(Px_sz) * im
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
@test AbstractFFTs.output_size(Pinv) == size(Pinv * y)
@test AbstractFFTs.output_size(Pinv') == size(y)
end
end
end
Expand All @@ -233,9 +237,11 @@ end
y = randn(size(x))
for dims in unique((1, 1:N, N))
P = plan_fft(x, dims)
@test (P')' * x == P * x # test adjoint of adjoint
@test dot(y, P * x) dot(P' * y, x)
@test_broken dot(y, P \ x) dot(P' \ y, x)
Pinv = plan_ifft(x)
Pinv = plan_ifft(y)
@test (Pinv')' * y == Pinv * y # test adjoint of adjoint
@test dot(x, Pinv * y) dot(Pinv' * x, y)
@test_broken dot(x, Pinv \ y) dot(Pinv' \ x, y)
end
Expand All @@ -246,12 +252,14 @@ end
N = ndims(x)
for dims in unique((1, 1:N, N))
P = plan_rfft(x, dims)
@test (P')' * x == P * x
y_real = randn(size(P * x))
y_imag = randn(size(P * x))
y = y_real .+ y_imag .* im
@test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) dot(P' * y, x)
@test_broken dot(y_real, real.(P \ x)) + dot(y_imag, imag.(P \ x)) dot(P' * y, x)
Pinv = plan_irfft(y, size(x)[first(dims)], dims)
@test (Pinv')' * y == Pinv * y
@test dot(x, Pinv * y) dot(y_real, real.(Pinv' * x)) + dot(y_imag, imag.(Pinv' * x))
@test_broken dot(x, Pinv \ y) dot(y_real, real.(Pinv' \ x)) + dot(y_imag, imag.(Pinv' \ x))
end
Expand Down Expand Up @@ -284,20 +292,27 @@ end
N = ndims(x)
complex_x = complex.(x)
for dims in unique((1, 1:N, N))
# fft, ifft, bfft
for f in (fft, ifft, bfft)
test_frule(f, x, dims)
test_rrule(f, x, dims)
test_frule(f, complex_x, dims)
test_rrule(f, complex_x, dims)
end

for pf in (plan_fft, plan_ifft, plan_bfft)
test_frule(*, pf(x, dims) NoTangent(), x)
test_rrule(*, pf(x, dims) NoTangent(), x)
test_frule(*, pf(complex_x, dims) NoTangent(), complex_x)
test_rrule(*, pf(complex_x, dims) NoTangent(), complex_x)
end

# rfft
test_frule(rfft, x, dims)
test_rrule(rfft, x, dims)
test_frule(*, plan_rfft(x, dims) NoTangent(), x)
test_rrule(*, plan_rfft(x, dims) NoTangent(), x)

# irfft, brfft
for f in (irfft, brfft)
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
test_frule(f, x, d, dims)
Expand All @@ -306,14 +321,12 @@ end
test_rrule(f, complex_x, d, dims)
end
end

for pf in (plan_irfft, plan_brfft)
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
test_frule(*, pf(complex_x, d, dims) NoTangent(), complex_x)
test_rrule(*, pf(complex_x, d, dims) NoTangent(), complex_x)
end
end

end
end
end
Expand Down

0 comments on commit 3b2a754

Please sign in to comment.