diff --git a/src/DSP.jl b/src/DSP.jl index 66e0f1b12..f66f64c91 100644 --- a/src/DSP.jl +++ b/src/DSP.jl @@ -5,7 +5,7 @@ using LinearAlgebra: mul!, rmul! using IterTools: subsets using Compat: Compat -export conv, deconv, filt, filt!, xcorr +export conv, conv!, deconv, filt, filt!, xcorr # This function has methods added in `periodograms` but is not exported, # so we define it here so one can do `DSP.allocate_output` instead of diff --git a/src/dspbase.jl b/src/dspbase.jl index 915cc525e..31b9f1c2d 100644 --- a/src/dspbase.jl +++ b/src/dspbase.jl @@ -483,16 +483,16 @@ end # Assumes u is larger than, or the same size as, v # nfft should be greater than or equal to 2*sv-1 function unsafe_conv_kern_os!(out, + output_indices, u::AbstractArray{<:Any, N}, v, - su, - sv, - sout, nffts) where N + sout = size(out) + su = size(u) + sv = size(v) u_start = first.(axes(u)) - out_axes = axes(out) - out_start = first.(out_axes) - out_stop = last.(out_axes) + out_start = Tuple(first(output_indices)) + out_stop = Tuple(last(output_indices)) ideal_save_blocksize = nffts .- sv .+ 1 # Number of samples that are "missing" if the output is smaller than the # valid portion of the convolution @@ -604,6 +604,7 @@ function unsafe_conv_kern_os!(out, end function _conv_kern_fft!(out, + output_indices, u::AbstractArray{T, N}, v::AbstractArray{T, N}) where {T<:Real, N} outsize = size(out) @@ -616,11 +617,11 @@ function _conv_kern_fft!(out, uf .*= vf raw_out = irfft(uf, nffts[1]) copyto!(out, - CartesianIndices(out), + output_indices, raw_out, CartesianIndices(UnitRange.(1, outsize))) end -function _conv_kern_fft!(out, u, v) +function _conv_kern_fft!(out, output_indices, u, v) outsize = size(out) nffts = nextfastfft(outsize) upad = _zeropad(u, nffts) @@ -632,40 +633,34 @@ function _conv_kern_fft!(out, u, v) upad .*= vpad ip! * upad copyto!(out, - CartesianIndices(out), + output_indices, upad, CartesianIndices(UnitRange.(1, outsize))) end # May switch argument order -function _conv_fft!(out, u, v) +function _conv_fft!(out, output_indices, u, v) su = size(u) sv = size(v) outsize = size(out) + if output_indices != CartesianIndices(out) + fill!(out, zero(eltype(out))) + end os_nffts = su >= sv ? map(optimalfftfiltlength, sv, su) : map(optimalfftfiltlength, su, sv) if any(os_nffts .< outsize) # v should be smaller than u for good performance if su >= sv - return unsafe_conv_kern_os!(out, u, v, su, sv, outsize, os_nffts) + return unsafe_conv_kern_os!(out, output_indices, u, v, os_nffts) else - return unsafe_conv_kern_os!(out, v, u, sv, su, outsize, os_nffts) + return unsafe_conv_kern_os!(out, output_indices, v, u, os_nffts) end else - return _conv_kern_fft!(out, u, v) + return _conv_kern_fft!(out, output_indices, u, v) end end -function _conv_td!(out, u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N} - calc_index_offset(ao::Base.OneTo, au::Base.OneTo, av::Base.OneTo) = 1 - calc_index_offset(ao::Base.OneTo, au, av) = # first(au) + first(av) - 1 - throw(ArgumentError("output must have offset axes if the input has")) - calc_index_offset(ao, au::Base.OneTo, av::Base.OneTo) = # 2 - throw(ArgumentError("output must not have offset axes if none of the inputs has")) - calc_index_offset(ao, au, av) = 0 - index_offset = CartesianIndex(map(calc_index_offset, axes(out), axes(u), axes(v))) - output_indices = CartesianIndices(map(axes(u), axes(v)) do au, av - return (first(au)+first(av)):(last(au)+last(av)) - end) .- index_offset +function _conv_td!(out, output_indices, u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N} + index_offset = first(CartesianIndices(u)) + first(CartesianIndices(v)) - first(output_indices) checkbounds(out, output_indices) fill!(out, zero(eltype(out))) for m in CartesianIndices(u), n in CartesianIndices(v) @@ -682,6 +677,15 @@ function conv!( v::AbstractArray{<:Number, N}; algorithm=T <: FFTTypes ? :auto : :direct ) where {T<:Number, N} + calc_index_offset(ao::Base.OneTo, au::Base.OneTo, av::Base.OneTo) = 1 + calc_index_offset(ao::Base.OneTo, au, av) = # first(au) + first(av) - 1 + throw(ArgumentError("output must have offset axes if the input has")) + calc_index_offset(ao, au::Base.OneTo, av::Base.OneTo) = # 2 + throw(ArgumentError("output must not have offset axes if none of the inputs has")) + calc_index_offset(ao, au, av) = 0 + output_indices = CartesianIndices(map(axes(out), axes(u), axes(v)) do ao, au, av + return (first(au)+first(av) : last(au)+last(av)) .- calc_index_offset(ao, au, av) + end) if algorithm===:auto if length(u) * length(v) < 2^16 # TODO: better heuristic algorithm = :direct @@ -690,9 +694,9 @@ function conv!( end end if algorithm===:direct - return _conv_td!(out, u, v) + return _conv_td!(out, output_indices, u, v) elseif algorithm===:fft - return _conv_fft!(out, u, v) + return _conv_fft!(out, output_indices, u, v) else throw(ArgumentError("algorithm must be :auto, :direct, or :fft")) end diff --git a/test/dsp.jl b/test/dsp.jl index 7fd98d0af..409902c2e 100644 --- a/test/dsp.jl +++ b/test/dsp.jl @@ -59,7 +59,12 @@ end @test conv(f32a, b) ≈ fexp @test conv(fb, a) ≈ fexp - u = rand(100) + # similar sizes so algorithm=:fft chooses direct fft + u = rand(190) + v = rand(200) + @test conv(u, v; algorithm=:direct) ≈ conv(u, v; algorithm=:fft) + # very different sizes so algorithm=:fft chooses overlap-save + u = rand(5) v = rand(200) @test conv(u, v; algorithm=:direct) ≈ conv(u, v; algorithm=:fft) @@ -73,6 +78,32 @@ end offset_arr_f = OffsetVector{Float64}(undef, -1:2) offset_arr_f[:] = fa @test conv(offset_arr_f, 1:3) ≈ OffsetVector(fexp, 0:5) + + for M in [10, 200], N in [10, 200], T in [Float64, ComplexF64] + u = rand(T, M) + v = rand(T, N) + u_off = OffsetVector(u, 23) + v_off = OffsetVector(v, -42) + @test conv(u, v; algorithm=:direct) ≈ conv(u, v; algorithm=:fft) + @test conv(u_off, v_off; algorithm=:direct) ≈ conv(u_off, v_off; algorithm=:fft) + @test conv(u, v) == conv(u_off, v_off)[23-42+2:23-42+N+M] + + for algorithm in [:direct, :fft] + # pre-allocated non-offset output larger than necessary + out = ones(T, M+N+10) + conv!(out, u, v; algorithm) + @test out[1:M+N-1] ≈ conv(u, v; algorithm) # why can't this be == ? + @test all(iszero, out[M+N:end]) + + # pre-allocated output with offset larger than necessary + out = OffsetVector(ones(T, M+N+10), 23-42-5) + conv!(out, u_off, v_off; algorithm) + @test out[23-42+2:23-42+N+M] ≈ conv(u, v; algorithm) # why can't this be == ? + @test all(iszero, out[begin:23-42+1]) + @test all(iszero, out[23-42+N+M+1:end]) + end + end + # Issue #352 @test conv([1//2, 1//3, 1//4], [1, 2]) ≈ [1//2, 4//3, 11//12, 1//2] # Non-numerical arrays should not be convolved @@ -206,9 +237,9 @@ end sv, v = os_test_data(T, nv, N) sout = su .+ sv .- 1 out = similar(u, T, sout) - unsafe_conv_kern_os!(out, u, v, su, sv, sout, nffts) + unsafe_conv_kern_os!(out, CartesianIndices(out), u, v, nffts) os_out = copy(out) - _conv_kern_fft!(out, u, v) + _conv_kern_fft!(out, CartesianIndices(out), u, v) @test out ≈ os_out end Ns = [1, 2, 3]