Skip to content

Commit

Permalink
Fully omit extra allocation in staticstructbroadcast.
Browse files Browse the repository at this point in the history
Now we get elements via `StaticArrays.__broadcast`
  • Loading branch information
N5N3 committed Mar 3, 2023
1 parent 95da1c6 commit b131762
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 15 deletions.
44 changes: 29 additions & 15 deletions ext/StructArraysStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,29 +45,43 @@ StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)

# Broadcast overload
@loadext using StaticArrays: StaticArrayStyle, similar_type
@loadext using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo
@loadext using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype, __broadcast
@loadext using StructArrays: isnonemptystructtype
using Base.Broadcast: Broadcasted

# StaticArrayStyle has no similar defined.
# Overload `try_struct_copy` instead.
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
sa = copy(bc)
ET = eltype(sa)
isnonemptystructtype(ET) || return sa
elements = Tuple(sa)
@static if VERSION >= v"1.7"
arrs = ntuple(Val(fieldcount(ET))) do i
similar_type(sa, fieldtype(ET, i))(_getfields(elements, i))
end
flat = broadcast_flatten(bc); as = flat.args; f = flat.f
argsizes = broadcast_sizes(as...)
ax = axes(bc)
ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug.")
return _broadcast(f, Size(map(length, ax)), argsizes, as...)
end

@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize}
first_staticarray = first_statictype(a...)
elements, ET = if prod(newsize) == 0
# Use inference to get eltype in empty case (see also comments in _map)
eltys = Tuple{map(eltype, a)...}
(), Core.Compiler.return_type(f, eltys)

Check warning on line 68 in ext/StructArraysStaticArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/StructArraysStaticArraysExt.jl#L67-L68

Added lines #L67 - L68 were not covered by tests
else
_fieldtype(::Type{T}) where {T} = i -> fieldtype(T, i)
__fieldtype = _fieldtype(ET)
arrs = ntuple(Val(fieldcount(ET))) do i
similar_type(sa, __fieldtype(i))(_getfields(elements, i))
end
temp = __broadcast(f, sz, s, a...)
temp, eltype(temp)
end
if isnonemptystructtype(ET)
@static if VERSION >= v"1.7"

Check warning on line 74 in ext/StructArraysStaticArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/StructArraysStaticArraysExt.jl#L74

Added line #L74 was not covered by tests
arrs = ntuple(Val(fieldcount(ET))) do i
@inbounds similar_type(first_staticarray, fieldtype(ET, i), sz)(_getfields(elements, i))
end
else
similarET(::Type{SA}, ::Type{T}) where {SA, T} = i -> @inbounds similar_type(SA, fieldtype(T, i), sz)(_getfields(elements, i))
arrs = ntuple(similarET(first_staticarray, ET), Val(fieldcount(ET)))
end
return StructArray{ET}(arrs)
end
return StructArray{ET}(arrs)
@inbounds return similar_type(first_staticarray, ET, sz)(elements)
end

@inline function _getfields(x::Tuple, i::Int)
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1297,8 +1297,10 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS

@testset "allocation test" begin
a = StructArray{ComplexF64}(undef, 1)
sa = StructArray{ComplexF64}((SizedVector{1}(a.re), SizedVector{1}(a.re)))
allocated(a) = @allocated a .+ 1
@test allocated(a) == 2allocated(a.re)
@test allocated(sa) == 2allocated(sa.re)
allocated2(a) = @allocated a .= complex.(a.im, a.re)
@test allocated2(a) == 0
end
Expand Down

0 comments on commit b131762

Please sign in to comment.