diff --git a/ext/StructArraysStaticArraysExt.jl b/ext/StructArraysStaticArraysExt.jl index a6b41d6..b3fc235 100644 --- a/ext/StructArraysStaticArraysExt.jl +++ b/ext/StructArraysStaticArraysExt.jl @@ -45,29 +45,43 @@ StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{ StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) # Broadcast overload -@loadext using StaticArrays: StaticArrayStyle, similar_type +@loadext using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo +@loadext using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype, __broadcast @loadext using StructArrays: isnonemptystructtype using Base.Broadcast: Broadcasted # StaticArrayStyle has no similar defined. # Overload `try_struct_copy` instead. @inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M} - sa = copy(bc) - ET = eltype(sa) - isnonemptystructtype(ET) || return sa - elements = Tuple(sa) - @static if VERSION >= v"1.7" - arrs = ntuple(Val(fieldcount(ET))) do i - similar_type(sa, fieldtype(ET, i))(_getfields(elements, i)) - end + flat = broadcast_flatten(bc); as = flat.args; f = flat.f + argsizes = broadcast_sizes(as...) + ax = axes(bc) + ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug.") + return _broadcast(f, Size(map(length, ax)), argsizes, as...) +end + +@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize} + first_staticarray = first_statictype(a...) + elements, ET = if prod(newsize) == 0 + # Use inference to get eltype in empty case (see also comments in _map) + eltys = Tuple{map(eltype, a)...} + (), Core.Compiler.return_type(f, eltys) else - _fieldtype(::Type{T}) where {T} = i -> fieldtype(T, i) - __fieldtype = _fieldtype(ET) - arrs = ntuple(Val(fieldcount(ET))) do i - similar_type(sa, __fieldtype(i))(_getfields(elements, i)) - end + temp = __broadcast(f, sz, s, a...) + temp, eltype(temp) + end + if isnonemptystructtype(ET) + @static if VERSION >= v"1.7" + arrs = ntuple(Val(fieldcount(ET))) do i + @inbounds similar_type(first_staticarray, fieldtype(ET, i), sz)(_getfields(elements, i)) + end + else + similarET(::Type{SA}, ::Type{T}) where {SA, T} = i -> @inbounds similar_type(SA, fieldtype(T, i), sz)(_getfields(elements, i)) + arrs = ntuple(similarET(first_staticarray, ET), Val(fieldcount(ET))) + end + return StructArray{ET}(arrs) end - return StructArray{ET}(arrs) + @inbounds return similar_type(first_staticarray, ET, sz)(elements) end @inline function _getfields(x::Tuple, i::Int) diff --git a/test/runtests.jl b/test/runtests.jl index c144311..b591c7c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1297,8 +1297,10 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @testset "allocation test" begin a = StructArray{ComplexF64}(undef, 1) + sa = StructArray{ComplexF64}((SizedVector{1}(a.re), SizedVector{1}(a.re))) allocated(a) = @allocated a .+ 1 @test allocated(a) == 2allocated(a.re) + @test allocated(sa) == 2allocated(sa.re) allocated2(a) = @allocated a .= complex.(a.im, a.re) @test allocated2(a) == 0 end