diff --git a/Project.toml b/Project.toml index 173bd5c..9dcaad1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AbstractFFTs" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.1.0" +version = "1.2.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/README.md b/README.md index 89f7d48..5b33c59 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ To define a new FFT implementation in your own module, you should inverse plan. * Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of - `x` and some set of dimensions `region`. + `x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`). * Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to 0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`. diff --git a/docs/src/api.md b/docs/src/api.md index 147908a..5d8316b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -19,6 +19,7 @@ AbstractFFTs.brfft AbstractFFTs.plan_rfft AbstractFFTs.plan_brfft AbstractFFTs.plan_irfft +AbstractFFTs.fftdims AbstractFFTs.fftshift AbstractFFTs.fftshift! AbstractFFTs.ifftshift diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 734c7d4..56d7123 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -5,7 +5,7 @@ import ChainRulesCore export fft, ifft, bfft, fft!, ifft!, bfft!, plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!, rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft, - fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq + fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq include("definitions.jl") include("chainrules.jl") diff --git a/src/definitions.jl b/src/definitions.jl index 41df3e5..7901966 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -15,6 +15,18 @@ size(p::Plan, d) = size(p)[d] ndims(p::Plan) = length(size(p)) length(p::Plan) = prod(size(p))::Int +""" + fftdims(p::Plan) + +Return an iterable of the dimensions that are transformed by the FFT plan `p`. + +# Implementation + +For legacy reasons, the default definition of `fftdims` returns `p.region`. +Hence this method should be implemented only for `Plan` subtypes that do not store the transformed dimensions in a field named `region`. +""" +fftdims(p::Plan) = p.region + fftfloat(x) = _fftfloat(float(x)) _fftfloat(::Type{T}) where {T<:BlasReal} = T _fftfloat(::Type{Float16}) = Float32 @@ -243,6 +255,8 @@ ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α) size(p::ScaledPlan) = size(p.p) +fftdims(p::ScaledPlan) = fftdims(p.p) + show(io::IO, p::ScaledPlan) = print(io, p.scale, " * ", p.p) summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p)) diff --git a/test/runtests.jl b/test/runtests.jl index 95c7c5d..623d625 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,18 +60,21 @@ end @test eltype(P) === ComplexF64 @test P * x ≈ fftw_fft @test P \ (P * x) ≈ x + @test fftdims(P) == dims fftw_bfft = complex.(size(x, dims) .* x) @test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft P = plan_bfft(x, dims) @test P * y ≈ fftw_bfft @test P \ (P * y) ≈ y + @test fftdims(P) == dims fftw_ifft = complex.(x) @test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft P = plan_ifft(x, dims) @test P * y ≈ fftw_ifft @test P \ (P * y) ≈ y + @test fftdims(P) == dims # real FFT fftw_rfft = fftw_fft[ @@ -84,18 +87,21 @@ end @test eltype(P) === Int @test P * x ≈ fftw_rfft @test P \ (P * x) ≈ x + @test fftdims(P) == dims fftw_brfft = complex.(size(x, dims) .* x) @test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft P = plan_brfft(ry, size(x, dims), dims) @test P * ry ≈ fftw_brfft @test P \ (P * ry) ≈ ry - + @test fftdims(P) == dims + fftw_irfft = complex.(x) @test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft P = plan_irfft(ry, size(x, dims), dims) @test P * ry ≈ fftw_irfft @test P \ (P * ry) ≈ ry + @test fftdims(P) == dims end end @@ -187,7 +193,7 @@ end # normalization should be inferable even if region is only inferred as ::Any, # need to wrap in another function to test this (note that p.region::Any for # p::TestPlan) - f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, p.region) + f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p)) @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10 end