Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using weakdeps to reduce load time on Julia v1.9 #270

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
*.jl.*.cov
*.jl.mem
Manifest.toml

.vscode
15 changes: 14 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[extensions]
StructArraysGPUArraysCoreExt = "GPUArraysCore"
StructArraysStaticArraysCoreExt = "StaticArraysCore"
StructArraysTablesExt = "Tables"

[compat]
Adapt = "1, 2, 3"
DataAPI = "1"
Expand All @@ -20,14 +30,17 @@ julia = "1.6"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"

[targets]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays"]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "StaticArraysCore", "Tables"]
21 changes: 21 additions & 0 deletions ext/StructArraysGPUArraysCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module StructArraysGPUArraysCoreExt

using StructArrays
using StructArrays: map_params, array_types

using Base: tail

import GPUArraysCore

# for GPU broadcast
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backends = map_params(GPUArraysCore.backend, array_types(T))
backend, others = backends[1], tail(backends)
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
return backend
end
StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true

end # module
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
module StructArraysStaticArraysCoreExt

using StructArrays
using StructArrays: StructArrayStyle, createinstance, replace_structarray, isnonemptystructtype

using Base.Broadcast: Broadcasted

using StaticArraysCore: StaticArray, FieldArray, tuple_prod

"""
Expand Down Expand Up @@ -40,7 +47,7 @@ Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(rep

# StaticArrayStyle has no similar defined.
# Overload `Base.copy` instead.
@inline function try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
sa = copy(bc)
ET = eltype(sa)
isnonemptystructtype(ET) || return sa
Expand All @@ -66,3 +73,5 @@ end
return map(Base.Fix2(getfield, i), x)
end
end

end # module
7 changes: 7 additions & 0 deletions src/tables.jl → ext/StructArraysTablesExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module StructArraysTablesExt

using StructArrays
using StructArrays: components, hasfields, foreachfield, staticschema

import Tables

Tables.isrowtable(::Type{<:StructArray}) = true
Expand Down Expand Up @@ -38,3 +43,5 @@ for (f, g) in zip((:append!, :prepend!), (:push!, :pushfirst!))
end
Copy link
Member

@aplavin aplavin Jun 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When put into extension, these functions are kinda-piracy: loading Tables changes a completely independent method, eg Base.push!(::StructVector, ::Any).

Copy link
Member

@aplavin aplavin Jun 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe they aren't needed at all? Not sure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, that's a very good point, we really shouldn't do that. That code may be necessary - we may have to keep Tables as a hard dependency. But maybe Tables itself could be made more lightweight using Pkg extensions or so?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that's a very good point, I hadn't thought about that. Yes, we do need Tables for basic functionality like pushing or appending to a StructVector effectively, so that should definitely stay a hard dependency. Still, it is supposed to be an interface package, so that should be OK.

end
end

end # module
15 changes: 4 additions & 11 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ include("utils.jl")
include("collect.jl")
include("sort.jl")
include("lazy.jl")
include("tables.jl")
include("staticarrays_support.jl")

# Implement refarray and refvalue to deal with pooled arrays and weakrefstrings effectively
import DataAPI: refarray, refvalue
Expand All @@ -29,15 +27,10 @@ end
import Adapt
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)

# for GPU broadcast
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backends = map_params(GPUArraysCore.backend, array_types(T))
backend, others = backends[1], tail(backends)
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
return backend
@static if !isdefined(Base, :get_extension)
include("../ext/StructArraysGPUArraysCoreExt.jl")
include("../ext/StructArraysTablesExt.jl")
include("../ext/StructArraysStaticArraysCoreExt.jl")
end
always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true

end # module