Skip to content

Commit

Permalink
Extend to arbitrary structures
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 5, 2024
1 parent 8d1b76a commit f169142
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 10 deletions.
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.20"
version = "0.1.21"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -24,6 +25,7 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[extensions]
LuxDeviceUtilsAMDGPUExt = "AMDGPU"
Expand All @@ -32,15 +34,17 @@ LuxDeviceUtilsFillArraysExt = "FillArrays"
LuxDeviceUtilsGPUArraysExt = "GPUArrays"
LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU"
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"]
LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"]
LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools"
LuxDeviceUtilsSparseArraysExt = "SparseArrays"
LuxDeviceUtilsZygoteExt = "Zygote"
LuxDeviceUtilsoneAPIExt = "oneAPI"

[compat]
AMDGPU = "0.8.4, 0.9"
Adapt = "4"
Aqua = "0.8.4"
ArgCheck = "2.3"
CUDA = "5.2"
ChainRulesCore = "1.20"
ComponentArrays = "0.15.8"
Expand All @@ -63,6 +67,7 @@ Test = "1.10"
TestSetExtensions = "3"
Zygote = "0.6.69"
julia = "1.10"
oneAPI = "1.5"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand All @@ -80,6 +85,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[targets]
test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote"]
test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote", "oneAPI"]
9 changes: 4 additions & 5 deletions ext/LuxDeviceUtilsFillArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
module LuxDeviceUtilsFillArraysExt

using Adapt: Adapt
using FillArrays: FillArrays
using LuxDeviceUtils: LuxDeviceUtils, LuxCPUAdaptor
using FillArrays: FillArrays, AbstractFill
using LuxDeviceUtils: LuxDeviceUtils, LuxCPUAdaptor, AbstractLuxDeviceAdaptor

Adapt.adapt_structure(::LuxCPUAdaptor, x::FillArrays.AbstractFill) = x
Adapt.adapt_structure(::LuxCPUAdaptor, x::AbstractFill) = x

function Adapt.adapt_structure(
to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, x::FillArrays.AbstractFill)
function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::AbstractFill)
return Adapt.adapt(to, collect(x))
end

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module LuxDeviceUtilsMetalGPUArraysExt
module LuxDeviceUtilsMetalExt

using Adapt: Adapt
using GPUArrays: GPUArrays
Expand Down
3 changes: 3 additions & 0 deletions ext/LuxDeviceUtilsoneAPIExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module LuxDeviceUtilsoneAPIExt

end
33 changes: 32 additions & 1 deletion src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using Adapt: Adapt
using ArgCheck: @argcheck
using ChainRulesCore: ChainRulesCore, NoTangent
using FastClosures: @closure
using Functors: Functors, fmap
Expand Down Expand Up @@ -326,7 +327,8 @@ end
Returns the device of the array `x`. Trigger Packages must be loaded for this to return the
correct device.
"""
function get_device(x::AbstractArray)
function get_device(x::AbstractArray{T}) where {T}
!isbitstype(T) && __combine_devices(get_device.(x))
if hasmethod(parent, Tuple{typeof(x)})
parent_x = parent(x)
parent_x === x && return LuxCPUDevice()
Expand All @@ -335,8 +337,37 @@ function get_device(x::AbstractArray)
return LuxCPUDevice()
end

"""
get_device(x) -> AbstractLuxDevice | Exception | Nothing
If all arrays (on the leaves of the structure) are on the same device, we return that
device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`.
"""
function get_device(x)
dev = Ref{Union{AbstractLuxDevice, Nothing}}(nothing)
_get_device(x) = (dev[] = __combine_devices(dev[], get_device(x)))
fmap(_get_device, x)
return dev[]
end
for T in (Number, AbstractRNG, Val)
@eval get_device(::$(T)) = nothing
end
get_device(x::Tuple) = __combine_devices(get_device.(x)...)
get_device(x::NamedTuple) = __combine_devices(get_device.(values(x))...)

CRC.@non_differentiable get_device(::Any...)

__combine_devices(dev1) = dev1
function __combine_devices(dev1, dev2)
dev1 === nothing && return dev2
dev2 === nothing && return dev1
@argcheck dev1 == dev2
return dev1
end
function __combine_devices(dev1, dev2, rem_devs...)
return foldl(__combine_devices, (dev1, dev2, rem_devs...))
end

# Set the device
const SET_DEVICE_DOCS = """
Set the device for the given type. This is a no-op for `LuxCPUDevice`. For `LuxCUDADevice`
Expand Down

0 comments on commit f169142

Please sign in to comment.