Skip to content

Commit

Permalink
WIP rework conv interface
Browse files Browse the repository at this point in the history
  • Loading branch information
martinholters committed Feb 29, 2024
1 parent 56dd6bd commit addb53b
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/DSP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using LinearAlgebra: mul!, rmul!
using IterTools: subsets
using Compat: Compat

export conv, deconv, filt, filt!, xcorr
export conv, conv!, deconv, filt, filt!, xcorr

# This function has methods added in `periodograms` but is not exported,
# so we define it here so one can do `DSP.allocate_output` instead of
Expand Down
56 changes: 30 additions & 26 deletions src/dspbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,16 +483,16 @@ end
# Assumes u is larger than, or the same size as, v
# nfft should be greater than or equal to 2*sv-1
function unsafe_conv_kern_os!(out,
output_indices,
u::AbstractArray{<:Any, N},
v,
su,
sv,
sout,
nffts) where N
sout = size(out)
su = size(u)
sv = size(v)
u_start = first.(axes(u))
out_axes = axes(out)
out_start = first.(out_axes)
out_stop = last.(out_axes)
out_start = Tuple(first(output_indices))
out_stop = Tuple(last(output_indices))
ideal_save_blocksize = nffts .- sv .+ 1
# Number of samples that are "missing" if the output is smaller than the
# valid portion of the convolution
Expand Down Expand Up @@ -604,6 +604,7 @@ function unsafe_conv_kern_os!(out,
end

function _conv_kern_fft!(out,
output_indices,
u::AbstractArray{T, N},
v::AbstractArray{T, N}) where {T<:Real, N}
outsize = size(out)
Expand All @@ -616,11 +617,11 @@ function _conv_kern_fft!(out,
uf .*= vf
raw_out = irfft(uf, nffts[1])
copyto!(out,
CartesianIndices(out),
output_indices,
raw_out,
CartesianIndices(UnitRange.(1, outsize)))
end
function _conv_kern_fft!(out, u, v)
function _conv_kern_fft!(out, output_indices, u, v)
outsize = size(out)
nffts = nextfastfft(outsize)
upad = _zeropad(u, nffts)
Expand All @@ -632,40 +633,34 @@ function _conv_kern_fft!(out, u, v)
upad .*= vpad
ip! * upad
copyto!(out,
CartesianIndices(out),
output_indices,
upad,
CartesianIndices(UnitRange.(1, outsize)))
end

# May switch argument order
function _conv_fft!(out, u, v)
function _conv_fft!(out, output_indices, u, v)
su = size(u)
sv = size(v)
outsize = size(out)
if output_indices != CartesianIndices(out)
fill!(out, zero(eltype(out)))
end
os_nffts = su >= sv ? map(optimalfftfiltlength, sv, su) : map(optimalfftfiltlength, su, sv)
if any(os_nffts .< outsize)
# v should be smaller than u for good performance
if su >= sv
return unsafe_conv_kern_os!(out, u, v, su, sv, outsize, os_nffts)
return unsafe_conv_kern_os!(out, output_indices, u, v, os_nffts)
else
return unsafe_conv_kern_os!(out, v, u, sv, su, outsize, os_nffts)
return unsafe_conv_kern_os!(out, output_indices, v, u, os_nffts)
end
else
return _conv_kern_fft!(out, u, v)
return _conv_kern_fft!(out, output_indices, u, v)
end
end

function _conv_td!(out, u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N}
calc_index_offset(ao::Base.OneTo, au::Base.OneTo, av::Base.OneTo) = 1
calc_index_offset(ao::Base.OneTo, au, av) = # first(au) + first(av) - 1
throw(ArgumentError("output must have offset axes if the input has"))
calc_index_offset(ao, au::Base.OneTo, av::Base.OneTo) = # 2
throw(ArgumentError("output must not have offset axes if none of the inputs has"))
calc_index_offset(ao, au, av) = 0
index_offset = CartesianIndex(map(calc_index_offset, axes(out), axes(u), axes(v)))
output_indices = CartesianIndices(map(axes(u), axes(v)) do au, av
return (first(au)+first(av)):(last(au)+last(av))
end) .- index_offset
function _conv_td!(out, output_indices, u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N}
index_offset = first(CartesianIndices(u)) + first(CartesianIndices(v)) - first(output_indices)
checkbounds(out, output_indices)
fill!(out, zero(eltype(out)))
for m in CartesianIndices(u), n in CartesianIndices(v)
Expand All @@ -682,6 +677,15 @@ function conv!(
v::AbstractArray{<:Number, N};
algorithm=T <: FFTTypes ? :auto : :direct
) where {T<:Number, N}
calc_index_offset(ao::Base.OneTo, au::Base.OneTo, av::Base.OneTo) = 1
calc_index_offset(ao::Base.OneTo, au, av) = # first(au) + first(av) - 1

Check warning on line 681 in src/dspbase.jl

View check run for this annotation

Codecov / codecov/patch

src/dspbase.jl#L681

Added line #L681 was not covered by tests
throw(ArgumentError("output must have offset axes if the input has"))
calc_index_offset(ao, au::Base.OneTo, av::Base.OneTo) = # 2

Check warning on line 683 in src/dspbase.jl

View check run for this annotation

Codecov / codecov/patch

src/dspbase.jl#L683

Added line #L683 was not covered by tests
throw(ArgumentError("output must not have offset axes if none of the inputs has"))
calc_index_offset(ao, au, av) = 0
output_indices = CartesianIndices(map(axes(out), axes(u), axes(v)) do ao, au, av
return (first(au)+first(av) : last(au)+last(av)) .- calc_index_offset(ao, au, av)
end)
if algorithm===:auto
if length(u) * length(v) < 2^16 # TODO: better heuristic
algorithm = :direct
Expand All @@ -690,9 +694,9 @@ function conv!(
end
end
if algorithm===:direct
return _conv_td!(out, u, v)
return _conv_td!(out, output_indices, u, v)
elseif algorithm===:fft
return _conv_fft!(out, u, v)
return _conv_fft!(out, output_indices, u, v)
else
throw(ArgumentError("algorithm must be :auto, :direct, or :fft"))

Check warning on line 701 in src/dspbase.jl

View check run for this annotation

Codecov / codecov/patch

src/dspbase.jl#L701

Added line #L701 was not covered by tests
end
Expand Down
37 changes: 34 additions & 3 deletions test/dsp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ end
@test conv(f32a, b) ≈ fexp
@test conv(fb, a) ≈ fexp

u = rand(100)
# similar sizes so algorithm=:fft chooses direct fft
u = rand(190)
v = rand(200)
@test conv(u, v; algorithm=:direct) ≈ conv(u, v; algorithm=:fft)
# very different sizes so algorithm=:fft chooses overlap-save
u = rand(5)
v = rand(200)
@test conv(u, v; algorithm=:direct) ≈ conv(u, v; algorithm=:fft)

Expand All @@ -73,6 +78,32 @@ end
offset_arr_f = OffsetVector{Float64}(undef, -1:2)
offset_arr_f[:] = fa
@test conv(offset_arr_f, 1:3) ≈ OffsetVector(fexp, 0:5)

for M in [10, 200], N in [10, 200], T in [Float64, ComplexF64]
u = rand(T, M)
v = rand(T, N)
u_off = OffsetVector(u, 23)
v_off = OffsetVector(v, -42)
@test conv(u, v; algorithm=:direct) ≈ conv(u, v; algorithm=:fft)
@test conv(u_off, v_off; algorithm=:direct) ≈ conv(u_off, v_off; algorithm=:fft)
@test conv(u, v) == conv(u_off, v_off)[23-42+2:23-42+N+M]

for algorithm in [:direct, :fft]
# pre-allocated non-offset output larger than necessary
out = ones(T, M+N+10)
conv!(out, u, v; algorithm)
@test out[1:M+N-1] ≈ conv(u, v; algorithm) # why can't this be == ?
@test all(iszero, out[M+N:end])

# pre-allocated output with offset larger than necessary
out = OffsetVector(ones(T, M+N+10), 23-42-5)
conv!(out, u_off, v_off; algorithm)
@test out[23-42+2:23-42+N+M] ≈ conv(u, v; algorithm) # why can't this be == ?
@test all(iszero, out[begin:23-42+1])
@test all(iszero, out[23-42+N+M+1:end])
end
end

# Issue #352
@test conv([1//2, 1//3, 1//4], [1, 2]) ≈ [1//2, 4//3, 11//12, 1//2]
# Non-numerical arrays should not be convolved
Expand Down Expand Up @@ -206,9 +237,9 @@ end
sv, v = os_test_data(T, nv, N)
sout = su .+ sv .- 1
out = similar(u, T, sout)
unsafe_conv_kern_os!(out, u, v, su, sv, sout, nffts)
unsafe_conv_kern_os!(out, CartesianIndices(out), u, v, nffts)
os_out = copy(out)
_conv_kern_fft!(out, u, v)
_conv_kern_fft!(out, CartesianIndices(out), u, v)
@test out ≈ os_out
end
Ns = [1, 2, 3]
Expand Down

0 comments on commit addb53b

Please sign in to comment.