Skip to content

Commit

Permalink
Implement time-domain convolution and use it for integers (#545)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinholters authored Nov 5, 2024
1 parent 1dae6a3 commit 7652e2d
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 97 deletions.
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"

[extensions]
OffsetArraysExt = "OffsetArrays"

[compat]
Bessels = "0.2"
DelimitedFiles = "1.6"
Expand Down
1 change: 1 addition & 0 deletions docs/src/convolutions.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

```@docs
conv
conv!
deconv
xcorr
```
7 changes: 7 additions & 0 deletions ext/OffsetArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module OffsetArraysExt
import DSP
import OffsetArrays

DSP.conv_with_offset(::OffsetArrays.IdOffsetRange) = true

end
2 changes: 1 addition & 1 deletion src/DSP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using FFTW
using LinearAlgebra: mul!, rmul!
using IterTools: subsets

export conv, deconv, filt, filt!, xcorr
export conv, conv!, deconv, filt, filt!, xcorr

# This function has methods added in `periodograms` but is not exported,
# so we define it here so one can do `DSP.allocate_output` instead of
Expand Down
227 changes: 138 additions & 89 deletions src/dspbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,16 +488,16 @@ end
# Assumes u is larger than, or the same size as, v
# nfft should be greater than or equal to 2*sv-1
function unsafe_conv_kern_os!(out,
output_indices,
u::AbstractArray{<:Any, N},
v,
su,
sv,
sout,
nffts) where N
sout = size(out)
su = size(u)
sv = size(v)
u_start = first.(axes(u))
out_axes = axes(out)
out_start = first.(out_axes)
out_stop = last.(out_axes)
out_start = Tuple(first(output_indices))
out_stop = Tuple(last(output_indices))
ideal_save_blocksize = nffts .- sv .+ 1
# Number of samples that are "missing" if the output is smaller than the
# valid portion of the convolution
Expand All @@ -507,7 +507,7 @@ function unsafe_conv_kern_os!(out,
nblocks = cld.(sout, save_blocksize)

# Pre-allocation
tdbuff, fdbuff, p, ip = os_prepare_conv(u, nffts)
tdbuff, fdbuff, p, ip = os_prepare_conv(out, nffts)
tdbuff_axes = axes(tdbuff)

# Transform the smaller filter
Expand Down Expand Up @@ -608,129 +608,178 @@ function unsafe_conv_kern_os!(out,
out
end

function _conv_kern_fft!(out,
u::AbstractArray{T, N},
v::AbstractArray{T, N},
su,
sv,
outsize,
nffts) where {T<:Real, N}
padded = _zeropad(u, nffts)
function _conv_kern_fft!(out::AbstractArray{T, N},
output_indices,
u::AbstractArray{<:Real, N},
v::AbstractArray{<:Real, N}) where {T<:Real, N}
outsize = size(output_indices)
nffts = nextfastfft(outsize)
padded = _zeropad!(similar(u, T, nffts), u)
p = plan_rfft(padded)
uf = p * padded
_zeropad!(padded, v)
vf = p * padded
uf .*= vf
raw_out = irfft(uf, nffts[1])
copyto!(out,
CartesianIndices(out),
output_indices,
raw_out,
CartesianIndices(UnitRange.(1, outsize)))
end
function _conv_kern_fft!(out, u, v, su, sv, outsize, nffts)
upad = _zeropad(u, nffts)
vpad = _zeropad(v, nffts)
function _conv_kern_fft!(out::AbstractArray{T}, output_indices, u, v) where {T}
outsize = size(output_indices)
nffts = nextfastfft(outsize)
upad = _zeropad!(similar(u, T, nffts), u)
vpad = _zeropad!(similar(v, T, nffts), v)
p! = plan_fft!(upad)
ip! = inv(p!)
p! * upad # Operates in place on upad
p! * vpad
upad .*= vpad
ip! * upad
copyto!(out,
CartesianIndices(out),
output_indices,
upad,
CartesianIndices(UnitRange.(1, outsize)))
end

# v should be smaller than u for good performance
function _conv_fft!(out, u, v, su, sv, outsize)
os_nffts = map(optimalfftfiltlength, sv, su)
if any(os_nffts .< outsize)
unsafe_conv_kern_os!(out, u, v, su, sv, outsize, os_nffts)
function _conv_td!(out, output_indices, u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N}
index_offset = first(CartesianIndices(u)) + first(CartesianIndices(v)) - first(output_indices)
checkbounds(out, output_indices)
fill!(out, zero(eltype(out)))
if size(u, 1) size(v, 1) # choose more efficient iteration order
for m in CartesianIndices(u), n in CartesianIndices(v)
@inbounds out[n+m - index_offset] = muladd(u[m], v[n], out[n+m - index_offset])
end
else
nffts = nextfastfft(outsize)
_conv_kern_fft!(out, u, v, su, sv, outsize, nffts)
for n in CartesianIndices(v), m in CartesianIndices(u)
@inbounds out[n+m - index_offset] = muladd(u[m], v[n], out[n+m - index_offset])
end
end
return out
end

# whether the given axis are to be considered to carry an offset for `conv!` and `conv`
conv_with_offset(::Base.OneTo) = false
conv_with_offset(a::Any) = throw(ArgumentError("unsupported axis type $(typeof(a))"))

# For arrays with weird offsets
function _conv_similar(u, outsize, axesu, axesv)
out_offsets = first.(axesu) .+ first.(axesv)
out_axes = UnitRange.(out_offsets, out_offsets .+ outsize .- 1)
similar(u, out_axes)
end
function _conv_similar(
u, outsize, ::NTuple{<:Any, Base.OneTo{Int}}, ::NTuple{<:Any, Base.OneTo{Int}}
)
similar(u, outsize)
end
_conv_similar(u, v, outsize) = _conv_similar(u, outsize, axes(u), axes(v))

# Does convolution, will not switch argument order
function _conv!(out, u, v, su, sv, outsize)
# TODO: Add spatial / time domain algorithm
_conv_fft!(out, u, v, su, sv, outsize)
end

# Does convolution, will not switch argument order
function _conv(u, v, su, sv)
outsize = su .+ sv .- 1
out = _conv_similar(u, v, outsize)
_conv!(out, u, v, su, sv, outsize)
end

# We use this type definition for clarity
const RealOrComplexFloat = Union{AbstractFloat, Complex{T} where T<:AbstractFloat}
const FFTTypes = Union{Float32, Float64, ComplexF32, ComplexF64}

# May switch argument order
"""
conv(u,v)
Convolution of two arrays. Uses either FFT convolution or overlap-save,
depending on the size of the input. `u` and `v` can be N-dimensional arrays,
with arbitrary indexing offsets, but their axes must be a `UnitRange`.
conv!(out, u, v; algorithm=:auto)
Convolution of two arrays `u` and `v` with the result stored in `out`. `out`
must be large enough to store the entire result; if it is even larger, the
excess entries will be zeroed.
`out`, `u`, and `v` can be N-dimensional arrays, with arbitrary indexing
offsets. If none of them has offset axes,
`size(out,d) ≥ size(u,d) + size(v,d) - 1` must hold. If both input and output
have offset axes, `firstindex(out,d) ≤ firstindex(u,d) + firstindex(v,d)` and
`lastindex(out,d) ≥ lastindex(u,d) + lastindex(v,d)` must hold (for d = 1,...,N).
A mix of offset and non-offset axes between input and output is not permitted.
The `algorithm` keyword allows choosing the algorithm to use:
* `:direct`: Evaluates the convolution sum in time domain.
* `:fft_simple`: Evaluates the convolution as a product in the frequency domain.
* `:fft_overlapsave`: Evaluates the convolution block-wise as a product in the
frequency domain, overlapping the resulting blocks.
* `:fft`: Selects the faster of `:fft_simple` and `:fft_overlapsave` (as
estimated from the input size).
* `:fast`: Selects the fastest of `:direct`, `:fft_simple` and
`:fft_overlapsave` (as estimated from the input size).
* `:auto` (default): Equivalent to `:fast` if the data type is known to be
suitable for FFT-based computation, equivalent to `:direct` otherwise.
!!! warning
The choices made by `:fft`, `:fast`, and `:auto` are based on performance
heuristics which may not result in the fastest algorithm in all cases. If
best performance for a certain size/type combination is required, it is
advised to do individual benchmarking and explicitly specify the desired
algorithm.
"""
function conv(u::AbstractArray{T, N},
v::AbstractArray{T, N}) where {T<:RealOrComplexFloat, N}
su = size(u)
sv = size(v)
if length(u) >= length(v)
_conv(u, v, su, sv)
function conv!(
out::AbstractArray{T, N},
u::AbstractArray{<:Number, N},
v::AbstractArray{<:Number, N};
algorithm=:auto
) where {T<:Number, N}
output_indices = CartesianIndices(map(axes(out), axes(u), axes(v)) do ao, au, av
input_has_offset = conv_with_offset(au) || conv_with_offset(av)
if input_has_offset !== conv_with_offset(ao)
throw(ArgumentError("output must have offset axes if and only if the input has"))
end
offset = input_has_offset ? 0 : 1
return (first(au)+first(av) : last(au)+last(av)) .- offset
end)

if algorithm===:auto
algorithm = T <: FFTTypes ? :fast : :direct
end
if algorithm===:fast
if length(u) * length(v) < 2^16 # TODO: better heuristic
algorithm = :direct
else
algorithm = :fft
end
end
if algorithm===:direct
return _conv_td!(out, output_indices, u, v)
else
_conv(v, u, sv, su)
if output_indices != CartesianIndices(out)
fill!(out, zero(eltype(out)))
end
os_nffts = length(u) >= length(v) ? map(optimalfftfiltlength, size(v), size(u)) : map(optimalfftfiltlength, size(u), size(v))
if algorithm===:fft
if any(os_nffts .< size(output_indices))
algorithm = :fft_overlapsave
else
algorithm = :fft_simple
end
end
if algorithm === :fft_overlapsave
# v should be smaller than u for good performance
if length(u) >= length(v)
return unsafe_conv_kern_os!(out, output_indices, u, v, os_nffts)
else
return unsafe_conv_kern_os!(out, output_indices, v, u, os_nffts)
end
elseif algorithm === :fft_simple
return _conv_kern_fft!(out, output_indices, u, v)
else
throw(ArgumentError("algorithm must be :auto, :fast, :direct, :fft, :fft_simple, or :fft_overlapsave"))
end
end
end

function conv(u::AbstractArray{<:RealOrComplexFloat, N},
v::AbstractArray{<:RealOrComplexFloat, N}) where N
fu, fv = promote(u, v)
conv(fu, fv)
end

conv(u::AbstractArray{<:Integer, N}, v::AbstractArray{<:Integer, N}) where {N} =
round.(Int, conv(float(u), float(v)))
conv_output_axis(au, av) =
conv_with_offset(au) || conv_with_offset(av) ?
(first(au)+first(av):last(au)+last(av)) : Base.OneTo(last(au) + last(av) - 1)

conv(u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N} =
conv(float(u), float(v))

function conv(u::AbstractArray{<:Number, N},
v::AbstractArray{<:RealOrComplexFloat, N}) where N
conv(float(u), v)
end
"""
conv(u, v; algorithm)
function conv(u::AbstractArray{<:RealOrComplexFloat, N},
v::AbstractArray{<:Number, N}) where N
conv(u, float(v))
Convolution of two arrays. A convolution algorithm is automatically chosen among
direct convolution, FFT, or FFT overlap-save, depending on the size of the
input, unless explicitly specified with the `algorithm` keyword argument; see
[`conv!`](@ref) for details.
"""
function conv(
u::AbstractArray{Tu, N}, v::AbstractArray{Tv, N}; kwargs...
) where {Tu<:Number, Tv<:Number, N}
T = promote_type(Tu, Tv)
out_axes = map(conv_output_axis, axes(u), axes(v))
out = similar(u, T, out_axes)
return conv!(out, u, v; kwargs...)
end

function conv(A::AbstractArray{<:Number, M},
B::AbstractArray{<:Number, N}) where {M, N}
B::AbstractArray{<:Number, N}; kwargs...) where {M, N}
if (M < N)
conv(cat(A, dims=N)::AbstractArray{eltype(A), N}, B)
conv(cat(A, dims=N)::AbstractArray{eltype(A), N}, B; kwargs...)
else
@assert M > N
conv(A, cat(B, dims=M)::AbstractArray{eltype(B), M})
conv(A, cat(B, dims=M)::AbstractArray{eltype(B), M}; kwargs...)
end
end

Expand Down
Loading

0 comments on commit 7652e2d

Please sign in to comment.