Skip to content

Commit

Permalink
Use weakdeps on Julia v1.9
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Apr 4, 2023
1 parent f23f3a8 commit baba2ca
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 13 deletions.
12 changes: 11 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[extensions]
StructArraysGPUArraysCoreExt = "GPUArraysCore"
StructArraysStaticArraysCoreExt = "StaticArraysCore"
StructArraysTablesExt = "Tables"

[compat]
Adapt = "1, 2, 3"
DataAPI = "1"
Expand All @@ -30,4 +40,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"

[targets]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays"]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "StaticArraysCore", "Tables"]
21 changes: 21 additions & 0 deletions ext/StructArraysGPUArraysCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module StructArraysGPUArraysCoreExt

using StructArrays
using StructArrays: map_params, array_types

using Base: tail

import GPUArraysCore

# for GPU broadcast
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backends = map_params(GPUArraysCore.backend, array_types(T))
backend, others = backends[1], tail(backends)
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
return backend
end
StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true

end # module
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
module StructArraysStaticArraysCoreExt

using StructArrays
using StructArrays: StructArrayStyle, createinstance, replace_structarray, isnonemptystructtype

using Base.Broadcast: Broadcasted

using StaticArraysCore: StaticArray, FieldArray, tuple_prod

"""
Expand Down Expand Up @@ -40,7 +47,7 @@ Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(rep

# StaticArrayStyle has no similar defined.
# Overload `Base.copy` instead.
@inline function try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
sa = copy(bc)
ET = eltype(sa)
isnonemptystructtype(ET) || return sa
Expand All @@ -66,3 +73,5 @@ end
return map(Base.Fix2(getfield, i), x)
end
end

end # module
7 changes: 7 additions & 0 deletions src/tables.jl → ext/StructArraysTablesExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module StructArraysTablesExt

using StructArrays
using StructArrays: components, hasfields, foreachfield, staticschema

import Tables

Tables.isrowtable(::Type{<:StructArray}) = true
Expand Down Expand Up @@ -38,3 +43,5 @@ for (f, g) in zip((:append!, :prepend!), (:push!, :pushfirst!))
end
end
end

end # module
15 changes: 4 additions & 11 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ include("utils.jl")
include("collect.jl")
include("sort.jl")
include("lazy.jl")
include("tables.jl")
include("staticarrays_support.jl")

# Implement refarray and refvalue to deal with pooled arrays and weakrefstrings effectively
import DataAPI: refarray, refvalue
Expand All @@ -29,15 +27,10 @@ end
import Adapt
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)

# for GPU broadcast
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backends = map_params(GPUArraysCore.backend, array_types(T))
backend, others = backends[1], tail(backends)
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
return backend
@static if !isdefined(Base, :get_extension)
include("../ext/StructArraysGPUArraysCoreExt.jl")
include("../ext/StructArraysTablesExt.jl")
include("../ext/StructArraysStaticArraysCoreExt.jl")
end
always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true

end # module

0 comments on commit baba2ca

Please sign in to comment.