Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement time-domain convolution and use it for integers #545

Merged
merged 11 commits into from
Nov 5, 2024
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"

[extensions]
OffsetArraysExt = "OffsetArrays"

[compat]
Bessels = "0.2"
DelimitedFiles = "1.6"
Expand Down
1 change: 1 addition & 0 deletions docs/src/convolutions.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

```@docs
conv
conv!
deconv
xcorr
```
7 changes: 7 additions & 0 deletions ext/OffsetArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Extension module adding a `DSP.conv_with_offset` method for the axis type
# provided by OffsetArrays.
module OffsetArraysExt

import DSP, OffsetArrays

# Axes of type `IdOffsetRange` are reported as carrying an offset.
function DSP.conv_with_offset(::OffsetArrays.IdOffsetRange)
    return true
end

end
2 changes: 1 addition & 1 deletion src/DSP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using FFTW
using LinearAlgebra: mul!, rmul!
using IterTools: subsets

export conv, deconv, filt, filt!, xcorr
export conv, conv!, deconv, filt, filt!, xcorr

# This function has methods added in `periodograms` but is not exported,
# so we define it here so one can do `DSP.allocate_output` instead of
Expand Down
227 changes: 138 additions & 89 deletions src/dspbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,16 +488,16 @@
# Assumes u is larger than, or the same size as, v
# nfft should be greater than or equal to 2*sv-1
function unsafe_conv_kern_os!(out,
output_indices,
u::AbstractArray{<:Any, N},
v,
su,
sv,
sout,
nffts) where N
sout = size(out)
su = size(u)
sv = size(v)
u_start = first.(axes(u))
out_axes = axes(out)
out_start = first.(out_axes)
out_stop = last.(out_axes)
out_start = Tuple(first(output_indices))
out_stop = Tuple(last(output_indices))
ideal_save_blocksize = nffts .- sv .+ 1
# Number of samples that are "missing" if the output is smaller than the
# valid portion of the convolution
Expand All @@ -507,7 +507,7 @@
nblocks = cld.(sout, save_blocksize)

# Pre-allocation
tdbuff, fdbuff, p, ip = os_prepare_conv(u, nffts)
tdbuff, fdbuff, p, ip = os_prepare_conv(out, nffts)
tdbuff_axes = axes(tdbuff)

# Transform the smaller filter
Expand Down Expand Up @@ -608,129 +608,178 @@
out
end

function _conv_kern_fft!(out,
u::AbstractArray{T, N},
v::AbstractArray{T, N},
su,
sv,
outsize,
nffts) where {T<:Real, N}
padded = _zeropad(u, nffts)
function _conv_kern_fft!(out::AbstractArray{T, N},
output_indices,
u::AbstractArray{<:Real, N},
v::AbstractArray{<:Real, N}) where {T<:Real, N}
outsize = size(output_indices)
nffts = nextfastfft(outsize)
padded = _zeropad!(similar(u, T, nffts), u)
p = plan_rfft(padded)
uf = p * padded
_zeropad!(padded, v)
vf = p * padded
uf .*= vf
raw_out = irfft(uf, nffts[1])
copyto!(out,
CartesianIndices(out),
output_indices,
raw_out,
CartesianIndices(UnitRange.(1, outsize)))
end
function _conv_kern_fft!(out, u, v, su, sv, outsize, nffts)
upad = _zeropad(u, nffts)
vpad = _zeropad(v, nffts)
function _conv_kern_fft!(out::AbstractArray{T}, output_indices, u, v) where {T}
outsize = size(output_indices)
nffts = nextfastfft(outsize)
upad = _zeropad!(similar(u, T, nffts), u)
vpad = _zeropad!(similar(v, T, nffts), v)
p! = plan_fft!(upad)
ip! = inv(p!)
p! * upad # Operates in place on upad
p! * vpad
upad .*= vpad
ip! * upad
copyto!(out,
CartesianIndices(out),
output_indices,
upad,
CartesianIndices(UnitRange.(1, outsize)))
end

# v should be smaller than u for good performance
function _conv_fft!(out, u, v, su, sv, outsize)
os_nffts = map(optimalfftfiltlength, sv, su)
if any(os_nffts .< outsize)
unsafe_conv_kern_os!(out, u, v, su, sv, outsize, os_nffts)
function _conv_td!(out, output_indices, u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N}
index_offset = first(CartesianIndices(u)) + first(CartesianIndices(v)) - first(output_indices)
checkbounds(out, output_indices)
fill!(out, zero(eltype(out)))
if size(u, 1) ≤ size(v, 1) # choose more efficient iteration order
for m in CartesianIndices(u), n in CartesianIndices(v)
@inbounds out[n+m - index_offset] = muladd(u[m], v[n], out[n+m - index_offset])
end
else
nffts = nextfastfft(outsize)
_conv_kern_fft!(out, u, v, su, sv, outsize, nffts)
for n in CartesianIndices(v), m in CartesianIndices(u)
@inbounds out[n+m - index_offset] = muladd(u[m], v[n], out[n+m - index_offset])
end
end
return out
end

# Predicate on an axis: is it to be treated as carrying an index offset by
# `conv!` and `conv`?

# `Base.OneTo` axes are reported as not carrying an offset.
function conv_with_offset(::Base.OneTo)
    return false
end

# Fallback: any axis type without a dedicated method is rejected.
function conv_with_offset(a)
    throw(ArgumentError("unsupported axis type $(typeof(a))"))
end

Check warning on line 664 in src/dspbase.jl

View check run for this annotation

Codecov / codecov/patch

src/dspbase.jl#L664

Added line #L664 was not covered by tests

# Allocate an output array for a convolution, based on `u` via `similar`.

# General case (axes that are not all `Base.OneTo{Int}`): along each dimension
# the output axis starts at the sum of the first indices of the two input axes
# and spans `outsize` entries.
function _conv_similar(u, outsize, axesu, axesv)
    lo = first.(axesu) .+ first.(axesv)
    hi = lo .+ outsize .- 1
    return similar(u, UnitRange.(lo, hi))
end

# Both inputs have only `Base.OneTo{Int}` axes: a plain size is sufficient.
function _conv_similar(
    u, outsize, ::NTuple{<:Any, Base.OneTo{Int}}, ::NTuple{<:Any, Base.OneTo{Int}}
)
    return similar(u, outsize)
end

# Convenience entry point: pick one of the methods above from the axes of the
# two inputs.
function _conv_similar(u, v, outsize)
    return _conv_similar(u, outsize, axes(u), axes(v))
end

# In-place convolution; `u` and `v` are handed on in the order given (no
# swapping). Currently always forwards to the FFT-based implementation.
# TODO: Add spatial / time domain algorithm
_conv!(out, u, v, su, sv, outsize) = _conv_fft!(out, u, v, su, sv, outsize)

# Allocating convolution; `u` and `v` are used in the order given (no swapping).
# `su` and `sv` are the sizes of `u` and `v` as supplied by the caller.
function _conv(u, v, su, sv)
    # size of the full convolution result along every dimension
    full_size = su .+ sv .- 1
    return _conv!(_conv_similar(u, v, full_size), u, v, su, sv, full_size)
end

# Named type unions, introduced for readability of the method signatures.
# `RealOrComplexFloat`: any real floating-point type or a complex number with
# floating-point components.
const RealOrComplexFloat = Union{AbstractFloat, Complex{<:AbstractFloat}}
# `FFTTypes`: the single- and double-precision real and complex float types.
const FFTTypes = Union{Float32, Float64, Complex{Float32}, Complex{Float64}}
wheeheee marked this conversation as resolved.
Show resolved Hide resolved

# May switch argument order
"""
conv(u,v)

Convolution of two arrays. Uses either FFT convolution or overlap-save,
depending on the size of the input. `u` and `v` can be N-dimensional arrays,
with arbitrary indexing offsets, but their axes must be a `UnitRange`.
conv!(out, u, v; algorithm=:auto)

Convolution of two arrays `u` and `v` with the result stored in `out`. `out`
must be large enough to store the entire result; if it is even larger, the
excess entries will be zeroed.

`out`, `u`, and `v` can be N-dimensional arrays, with arbitrary indexing
offsets. If none of them has offset axes,
`size(out,d) ≥ size(u,d) + size(v,d) - 1` must hold. If both input and output
have offset axes, `firstindex(out,d) ≤ firstindex(u,d) + firstindex(v,d)` and
`lastindex(out,d) ≥ lastindex(u,d) + lastindex(v,d)` must hold (for d = 1,...,N).
A mix of offset and non-offset axes between input and output is not permitted.

The `algorithm` keyword allows choosing the algorithm to use:
* `:direct`: Evaluates the convolution sum in time domain.
* `:fft_simple`: Evaluates the convolution as a product in the frequency domain.
* `:fft_overlapsave`: Evaluates the convolution block-wise as a product in the
frequency domain, overlapping the resulting blocks.
* `:fft`: Selects the faster of `:fft_simple` and `:fft_overlapsave` (as
estimated from the input size).
* `:fast`: Selects the fastest of `:direct`, `:fft_simple` and
`:fft_overlapsave` (as estimated from the input size).
* `:auto` (default): Equivalent to `:fast` if the data type is known to be
suitable for FFT-based computation, equivalent to `:direct` otherwise.
martinholters marked this conversation as resolved.
Show resolved Hide resolved

!!! warning
The choices made by `:fft`, `:fast`, and `:auto` are based on performance
heuristics which may not result in the fastest algorithm in all cases. If
best performance for a certain size/type combination is required, it is
advised to do individual benchmarking and explicitly specify the desired
algorithm.
"""
function conv(u::AbstractArray{T, N},
v::AbstractArray{T, N}) where {T<:RealOrComplexFloat, N}
su = size(u)
sv = size(v)
if length(u) >= length(v)
_conv(u, v, su, sv)
function conv!(
out::AbstractArray{T, N},
u::AbstractArray{<:Number, N},
v::AbstractArray{<:Number, N};
algorithm=:auto
) where {T<:Number, N}
output_indices = CartesianIndices(map(axes(out), axes(u), axes(v)) do ao, au, av
input_has_offset = conv_with_offset(au) || conv_with_offset(av)
if input_has_offset !== conv_with_offset(ao)
throw(ArgumentError("output must have offset axes if and only if the input has"))
end
offset = input_has_offset ? 0 : 1
return (first(au)+first(av) : last(au)+last(av)) .- offset
end)

if algorithm===:auto
algorithm = T <: FFTTypes ? :fast : :direct
end
if algorithm===:fast
if length(u) * length(v) < 2^16 # TODO: better heuristic
algorithm = :direct
else
algorithm = :fft
end
end
if algorithm===:direct
return _conv_td!(out, output_indices, u, v)
else
_conv(v, u, sv, su)
if output_indices != CartesianIndices(out)
fill!(out, zero(eltype(out)))
end
os_nffts = length(u) >= length(v) ? map(optimalfftfiltlength, size(v), size(u)) : map(optimalfftfiltlength, size(u), size(v))
if algorithm===:fft
if any(os_nffts .< size(output_indices))
algorithm = :fft_overlapsave
else
algorithm = :fft_simple
end
end
if algorithm === :fft_overlapsave
# v should be smaller than u for good performance
if length(u) >= length(v)
return unsafe_conv_kern_os!(out, output_indices, u, v, os_nffts)
else
return unsafe_conv_kern_os!(out, output_indices, v, u, os_nffts)
end
elseif algorithm === :fft_simple
return _conv_kern_fft!(out, output_indices, u, v)
else
throw(ArgumentError("algorithm must be :auto, :fast, :direct, :fft, :fft_simple, or :fft_overlapsave"))
end
end
end

# Inputs with (possibly different) real or complex floating-point element
# types: promote both arrays with `promote` and call `conv` again on the
# promoted pair.
function conv(
    u::AbstractArray{<:RealOrComplexFloat, N}, v::AbstractArray{<:RealOrComplexFloat, N}
) where N
    return conv(promote(u, v)...)
end

# Integer inputs: convolve the floating-point conversions of both arrays and
# round every entry of the result to `Int`.
function conv(u::AbstractArray{<:Integer, N}, v::AbstractArray{<:Integer, N}) where {N}
    float_result = conv(float(u), float(v))
    return round.(Int, float_result)
end
# Output axis of a convolution along one dimension, given the corresponding
# input axes `au` and `av`.
function conv_output_axis(au, av)
    if conv_with_offset(au) || conv_with_offset(av)
        # at least one offset axis: output indices run from the sum of the
        # first input indices to the sum of the last input indices
        return (first(au) + first(av)):(last(au) + last(av))
    else
        # no offset axes: `OneTo` axis ending at `last(au) + last(av) - 1`
        return Base.OneTo(last(au) + last(av) - 1)
    end
end

# Generic numeric inputs: convert both arrays with `float` and call `conv`
# again on the converted arrays.
function conv(u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N}
    fu = float(u)
    fv = float(v)
    return conv(fu, fv)
end

# Only `v` is already declared floating-point: convert `u` with `float` and
# call `conv` again.
conv(
    u::AbstractArray{<:Number, N}, v::AbstractArray{<:RealOrComplexFloat, N}
) where N = conv(float(u), v)
"""
conv(u, v; algorithm)

function conv(u::AbstractArray{<:RealOrComplexFloat, N},
v::AbstractArray{<:Number, N}) where N
conv(u, float(v))
Convolution of two arrays. A convolution algorithm is automatically chosen among
direct convolution, FFT, or FFT overlap-save, depending on the size of the
input, unless explicitly specified with the `algorithm` keyword argument; see
[`conv!`](@ref) for details.
"""
function conv(
    u::AbstractArray{Tu, N}, v::AbstractArray{Tv, N}; kwargs...
) where {Tu<:Number, Tv<:Number, N}
    # Allocate the result: element type is the promotion of both input element
    # types, axes are determined per dimension by `conv_output_axis`.
    result = similar(
        u, promote_type(Tu, Tv), map(conv_output_axis, axes(u), axes(v))
    )
    # All keyword arguments are forwarded unchanged to `conv!`.
    return conv!(result, u, v; kwargs...)
end

function conv(A::AbstractArray{<:Number, M},
B::AbstractArray{<:Number, N}) where {M, N}
B::AbstractArray{<:Number, N}; kwargs...) where {M, N}
if (M < N)
conv(cat(A, dims=N)::AbstractArray{eltype(A), N}, B)
conv(cat(A, dims=N)::AbstractArray{eltype(A), N}, B; kwargs...)
else
@assert M > N
conv(A, cat(B, dims=M)::AbstractArray{eltype(B), M})
conv(A, cat(B, dims=M)::AbstractArray{eltype(B), M}; kwargs...)
end
end

Expand Down
Loading
Loading