Skip to content

Commit

Permalink
Rewrite conv OffsetArrays support to use extension
Browse files Browse the repository at this point in the history
  • Loading branch information
martinholters committed Nov 4, 2024
1 parent 892ad62 commit 34f42b2
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 14 deletions.
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"

[extensions]
OffsetArraysExt = "OffsetArrays"

[compat]
Bessels = "0.2"
DelimitedFiles = "1.6"
Expand Down
7 changes: 7 additions & 0 deletions ext/OffsetArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module OffsetArraysExt
import DSP
import OffsetArrays

DSP.conv_with_offset(::OffsetArrays.IdOffsetRange) = true

end
31 changes: 17 additions & 14 deletions src/dspbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,10 @@ function _conv_td!(out, output_indices, u::AbstractArray{<:Number, N}, v::Abstra
return out
end

# whether the given axis are to be considered to carry an offset for `conv!` and `conv`
conv_with_offset(::Base.OneTo) = false
conv_with_offset(a::Any) = throw(ArgumentError("unsupported axis type $(typeof(a))"))

Check warning on line 664 in src/dspbase.jl

View check run for this annotation

Codecov / codecov/patch

src/dspbase.jl#L664

Added line #L664 was not covered by tests

const FFTTypes = Union{Float32, Float64, ComplexF32, ComplexF64}

"""
Expand Down Expand Up @@ -700,17 +704,14 @@ function conv!(
v::AbstractArray{<:Number, N};
algorithm=:auto
) where {T<:Number, N}
calc_index_offset(ao::Base.OneTo, au::Base.OneTo, av::Base.OneTo) = 1
calc_index_offset(ao::Base.OneTo, au::AbstractUnitRange, av::AbstractUnitRange) = # first(au) + first(av) - 1
throw(ArgumentError("output must have offset axes if the input has"))
calc_index_offset(ao::AbstractUnitRange, au::Base.OneTo, av::Base.OneTo) = # 2
throw(ArgumentError("output must not have offset axes if none of the inputs has"))
calc_index_offset(ao::AbstractUnitRange, au::AbstractUnitRange, av::AbstractUnitRange) = 0
output_indices = let calc_index_offset = calc_index_offset # prevent boxing
CartesianIndices(map(axes(out), axes(u), axes(v)) do ao, au, av
return (first(au)+first(av) : last(au)+last(av)) .- calc_index_offset(ao, au, av)
end)
end
output_indices = CartesianIndices(map(axes(out), axes(u), axes(v)) do ao, au, av
input_has_offset = conv_with_offset(au) || conv_with_offset(av)
if input_has_offset !== conv_with_offset(ao)
throw(ArgumentError("output must have offset axes if and only if the input has"))
end
offset = input_has_offset ? 0 : 1
return (first(au)+first(av) : last(au)+last(av)) .- offset
end)

if algorithm===:auto
algorithm = T <: FFTTypes ? :fast : :direct
Expand Down Expand Up @@ -751,6 +752,10 @@ function conv!(
end
end

conv_output_axis(au, av) =
conv_with_offset(au) || conv_with_offset(av) ?
(first(au)+first(av):last(au)+last(av)) : Base.OneTo(last(au) + last(av) - 1)

"""
conv(u, v; algorithm)
Expand All @@ -763,9 +768,7 @@ function conv(
u::AbstractArray{Tu, N}, v::AbstractArray{Tv, N}; kwargs...
) where {Tu<:Number, Tv<:Number, N}
T = promote_type(Tu, Tv)
out_axis(au, av) = (first(au)+first(av)):(last(au)+last(av))
out_axis(au::Base.OneTo, av::Base.OneTo) = Base.OneTo(last(au) + last(av) - 1)
out_axes = map(out_axis, axes(u), axes(v))
out_axes = map(conv_output_axis, axes(u), axes(v))
out = similar(u, T, out_axes)
return conv!(out, u, v; kwargs...)
end
Expand Down

0 comments on commit 34f42b2

Please sign in to comment.