Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add region(::Plan) for accessing transformed region #65

Merged
merged 8 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AbstractFFTs"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.1.0"
version = "1.2.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ To define a new FFT implementation in your own module, you should
inverse plan.

* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
`x` and some set of dimensions `region`.
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).

* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` (or `A_mul_B!(y, p::MyPlan, x)` on Julia prior to
0.7.0-DEV.3204) that computes the transform `p` of `x` and stores the result in `y`.
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ AbstractFFTs.brfft
AbstractFFTs.plan_rfft
AbstractFFTs.plan_brfft
AbstractFFTs.plan_irfft
AbstractFFTs.fftdims
AbstractFFTs.fftshift
AbstractFFTs.ifftshift
AbstractFFTs.fftfreq
Expand Down
2 changes: 1 addition & 1 deletion src/AbstractFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import ChainRulesCore
export fft, ifft, bfft, fft!, ifft!, bfft!,
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq
fftdims, fftshift, ifftshift, fftshift!, ifftshift!, Frequencies, fftfreq, rfftfreq

include("definitions.jl")
include("chainrules.jl")
Expand Down
14 changes: 14 additions & 0 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ size(p::Plan, d) = size(p)[d]
ndims(p::Plan) = length(size(p))
length(p::Plan) = prod(size(p))::Int

"""
fftdims(p::Plan)

Return an iterable of the dimensions that are transformed by the FFT plan `p`.

# Implementation

For legacy reasons, the default definition of `fftdims` returns `p.region`.
Hence this method should be implemented only for `Plan` subtypes that do not store the transformed dimensions in a field named `region`.
"""
fftdims(p::Plan) = p.region

fftfloat(x) = _fftfloat(float(x))
_fftfloat(::Type{T}) where {T<:BlasReal} = T
_fftfloat(::Type{Float16}) = Float32
Expand Down Expand Up @@ -243,6 +255,8 @@ ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)

size(p::ScaledPlan) = size(p.p)

fftdims(p::ScaledPlan) = fftdims(p.p)

show(io::IO, p::ScaledPlan) = print(io, p.scale, " * ", p.p)
summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p))

Expand Down
10 changes: 8 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,21 @@ end
@test eltype(P) === ComplexF64
@test P * x ≈ fftw_fft
@test P \ (P * x) ≈ x
@test fftdims(P) == dims

fftw_bfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft
P = plan_bfft(x, dims)
@test P * y ≈ fftw_bfft
@test P \ (P * y) ≈ y
@test fftdims(P) == dims

fftw_ifft = complex.(x)
@test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft
P = plan_ifft(x, dims)
@test P * y ≈ fftw_ifft
@test P \ (P * y) ≈ y
@test fftdims(P) == dims

# real FFT
fftw_rfft = fftw_fft[
Expand All @@ -84,18 +87,21 @@ end
@test eltype(P) === Int
@test P * x ≈ fftw_rfft
@test P \ (P * x) ≈ x
@test fftdims(P) == dims

fftw_brfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft
P = plan_brfft(ry, size(x, dims), dims)
@test P * ry ≈ fftw_brfft
@test P \ (P * ry) ≈ ry

@test fftdims(P) == dims

fftw_irfft = complex.(x)
@test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft
P = plan_irfft(ry, size(x, dims), dims)
@test P * ry ≈ fftw_irfft
@test P \ (P * ry) ≈ ry
@test fftdims(P) == dims
end
end

Expand Down Expand Up @@ -187,7 +193,7 @@ end
# normalization should be inferable even if region is only inferred as ::Any,
# need to wrap in another function to test this (note that p.region::Any for
# p::TestPlan)
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, p.region)
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p))
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
end

Expand Down