From baba2ca2507022fbea06bc24ee612f3c8d6aa7d4 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 4 Apr 2023 09:07:49 +0200 Subject: [PATCH] Use weakdeps on Julia v1.9 --- Project.toml | 12 ++++++++++- ext/StructArraysGPUArraysCoreExt.jl | 21 +++++++++++++++++++ .../StructArraysStaticArraysCoreExt.jl | 11 +++++++++- src/tables.jl => ext/StructArraysTablesExt.jl | 7 +++++++ src/StructArrays.jl | 15 ++++--------- 5 files changed, 53 insertions(+), 13 deletions(-) create mode 100644 ext/StructArraysGPUArraysCoreExt.jl rename src/staticarrays_support.jl => ext/StructArraysStaticArraysCoreExt.jl (89%) rename src/tables.jl => ext/StructArraysTablesExt.jl (92%) diff --git a/Project.toml b/Project.toml index c672e08d..ae7cc770 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,16 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +[weakdeps] +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" + +[extensions] +StructArraysGPUArraysCoreExt = "GPUArraysCore" +StructArraysStaticArraysCoreExt = "StaticArraysCore" +StructArraysTablesExt = "Tables" + [compat] Adapt = "1, 2, 3" DataAPI = "1" @@ -30,4 +40,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" [targets] -test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays"] +test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "StaticArraysCore", "Tables"] diff --git a/ext/StructArraysGPUArraysCoreExt.jl b/ext/StructArraysGPUArraysCoreExt.jl new file mode 100644 index 00000000..b05d3082 --- /dev/null +++ b/ext/StructArraysGPUArraysCoreExt.jl @@ -0,0 +1,21 @@ +module StructArraysGPUArraysCoreExt + +using StructArrays +using StructArrays: map_params, array_types + +using Base: tail + +import GPUArraysCore + +# for GPU broadcast +import GPUArraysCore +function GPUArraysCore.backend(::Type{T}) where {T<:StructArray} + backends = map_params(GPUArraysCore.backend, array_types(T)) + backend, others = backends[1], tail(backends) + isconsistent = mapfoldl(isequal(backend), &, others; init=true) + isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend")) + return backend +end +StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true + +end # module diff --git a/src/staticarrays_support.jl b/ext/StructArraysStaticArraysCoreExt.jl similarity index 89% rename from src/staticarrays_support.jl rename to ext/StructArraysStaticArraysCoreExt.jl index 1af186e8..28e73762 100644 --- a/src/staticarrays_support.jl +++ b/ext/StructArraysStaticArraysCoreExt.jl @@ -1,3 +1,10 @@ +module StructArraysStaticArraysCoreExt + +using StructArrays +using StructArrays: StructArrayStyle, createinstance, replace_structarray, isnonemptystructtype + +using Base.Broadcast: Broadcasted + using StaticArraysCore: StaticArray, FieldArray, tuple_prod """ @@ -40,7 +47,7 @@ Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(rep # StaticArrayStyle has no similar defined. # Overload `Base.copy` instead. -@inline function try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M} +@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M} sa = copy(bc) ET = eltype(sa) isnonemptystructtype(ET) || return sa @@ -66,3 +73,5 @@ end return map(Base.Fix2(getfield, i), x) end end + +end # module diff --git a/src/tables.jl b/ext/StructArraysTablesExt.jl similarity index 92% rename from src/tables.jl rename to ext/StructArraysTablesExt.jl index d6ac2248..14420879 100644 --- a/src/tables.jl +++ b/ext/StructArraysTablesExt.jl @@ -1,3 +1,8 @@ +module StructArraysTablesExt + +using StructArrays +using StructArrays: components, hasfields, foreachfield, staticschema + import Tables Tables.isrowtable(::Type{<:StructArray}) = true @@ -38,3 +43,5 @@ for (f, g) in zip((:append!, :prepend!), (:push!, :pushfirst!)) end end end + +end # module diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 129dcd82..b91007d2 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -12,8 +12,6 @@ include("utils.jl") include("collect.jl") include("sort.jl") include("lazy.jl") -include("tables.jl") -include("staticarrays_support.jl") # Implement refarray and refvalue to deal with pooled arrays and weakrefstrings effectively import DataAPI: refarray, refvalue @@ -29,15 +27,10 @@ end import Adapt Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s) -# for GPU broadcast -import GPUArraysCore -function GPUArraysCore.backend(::Type{T}) where {T<:StructArray} - backends = map_params(GPUArraysCore.backend, array_types(T)) - backend, others = backends[1], tail(backends) - isconsistent = mapfoldl(isequal(backend), &, others; init=true) - isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend")) - return backend +@static if !isdefined(Base, :get_extension) + include("../ext/StructArraysGPUArraysCoreExt.jl") + include("../ext/StructArraysTablesExt.jl") + include("../ext/StructArraysStaticArraysCoreExt.jl") end -always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true end # module