From f16914221dbff74a24e4bf77042479f2fa0ccd22 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 4 Jun 2024 22:41:59 -0700 Subject: [PATCH] Extend to arbitrary structures --- Project.toml | 12 +++++-- ext/LuxDeviceUtilsFillArraysExt.jl | 9 +++-- ...ArraysExt.jl => LuxDeviceUtilsMetalExt.jl} | 2 +- ext/LuxDeviceUtilsoneAPIExt.jl | 3 ++ src/LuxDeviceUtils.jl | 33 ++++++++++++++++++- 5 files changed, 49 insertions(+), 10 deletions(-) rename ext/{LuxDeviceUtilsMetalGPUArraysExt.jl => LuxDeviceUtilsMetalExt.jl} (95%) create mode 100644 ext/LuxDeviceUtilsoneAPIExt.jl diff --git a/Project.toml b/Project.toml index aadadd7..8d556c3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.20" +version = "0.1.21" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -24,6 +25,7 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] LuxDeviceUtilsAMDGPUExt = "AMDGPU" @@ -32,15 +34,17 @@ LuxDeviceUtilsFillArraysExt = "FillArrays" LuxDeviceUtilsGPUArraysExt = "GPUArrays" LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" -LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"] +LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"] LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" LuxDeviceUtilsSparseArraysExt = "SparseArrays" LuxDeviceUtilsZygoteExt = "Zygote" +LuxDeviceUtilsoneAPIExt = "oneAPI" [compat] AMDGPU = "0.8.4, 0.9" Adapt = "4" Aqua = "0.8.4" +ArgCheck = "2.3" CUDA = "5.2" 
ChainRulesCore = "1.20" ComponentArrays = "0.15.8" @@ -63,6 +67,7 @@ Test = "1.10" TestSetExtensions = "3" Zygote = "0.6.69" julia = "1.10" +oneAPI = "1.5" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" @@ -80,6 +85,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote"] +test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote", "oneAPI"] diff --git a/ext/LuxDeviceUtilsFillArraysExt.jl b/ext/LuxDeviceUtilsFillArraysExt.jl index 879d380..ecf44f3 100644 --- a/ext/LuxDeviceUtilsFillArraysExt.jl +++ b/ext/LuxDeviceUtilsFillArraysExt.jl @@ -1,13 +1,12 @@ module LuxDeviceUtilsFillArraysExt using Adapt: Adapt -using FillArrays: FillArrays -using LuxDeviceUtils: LuxDeviceUtils, LuxCPUAdaptor +using FillArrays: FillArrays, AbstractFill +using LuxDeviceUtils: LuxDeviceUtils, LuxCPUAdaptor, AbstractLuxDeviceAdaptor -Adapt.adapt_structure(::LuxCPUAdaptor, x::FillArrays.AbstractFill) = x +Adapt.adapt_structure(::LuxCPUAdaptor, x::AbstractFill) = x -function Adapt.adapt_structure( - to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, x::FillArrays.AbstractFill) +function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::AbstractFill) return Adapt.adapt(to, collect(x)) end diff --git a/ext/LuxDeviceUtilsMetalGPUArraysExt.jl b/ext/LuxDeviceUtilsMetalExt.jl similarity index 95% rename from ext/LuxDeviceUtilsMetalGPUArraysExt.jl rename to ext/LuxDeviceUtilsMetalExt.jl index 5cdd530..2d81b59 100644 
--- a/ext/LuxDeviceUtilsMetalGPUArraysExt.jl
+++ b/ext/LuxDeviceUtilsMetalExt.jl
@@ -1,4 +1,4 @@
-module LuxDeviceUtilsMetalGPUArraysExt
+module LuxDeviceUtilsMetalExt
 
 using Adapt: Adapt
 using GPUArrays: GPUArrays
diff --git a/ext/LuxDeviceUtilsoneAPIExt.jl b/ext/LuxDeviceUtilsoneAPIExt.jl
new file mode 100644
index 0000000..0bb7e89
--- /dev/null
+++ b/ext/LuxDeviceUtilsoneAPIExt.jl
@@ -0,0 +1,3 @@
+module LuxDeviceUtilsoneAPIExt
+
+end
diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl
index 775439c..a1e6596 100644
--- a/src/LuxDeviceUtils.jl
+++ b/src/LuxDeviceUtils.jl
@@ -4,6 +4,7 @@ using PrecompileTools: @recompile_invalidations
 
 @recompile_invalidations begin
     using Adapt: Adapt
+    using ArgCheck: @argcheck
     using ChainRulesCore: ChainRulesCore, NoTangent
     using FastClosures: @closure
     using Functors: Functors, fmap
@@ -326,7 +327,9 @@ end
 Returns the device of the array `x`. Trigger Packages must be loaded for this to return the
 correct device.
 """
-function get_device(x::AbstractArray)
+function get_device(x::AbstractArray{T}) where {T}
+    # Non-bits element types: the device is decided by the elements, so combine theirs.
+    !isbitstype(T) && !isempty(x) && return mapreduce(get_device, __combine_devices, x)
     if hasmethod(parent, Tuple{typeof(x)})
         parent_x = parent(x)
         parent_x === x && return LuxCPUDevice()
@@ -335,8 +338,39 @@ function get_device(x::AbstractArray)
     return LuxCPUDevice()
 end
 
+"""
+    get_device(x) -> AbstractLuxDevice | Exception | Nothing
+
+If all arrays (on the leaves of the structure) are on the same device, we return that
+device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`.
+"""
+function get_device(x)
+    # A leaf that reached this fallback has no dedicated method, so it is device agnostic.
+    # Without this guard `fmap` would call back into this same method on the same leaf.
+    Functors.isleaf(x) && return nothing
+    dev = Ref{Union{AbstractLuxDevice, Nothing}}(nothing)
+    _get_device(x) = (dev[] = __combine_devices(dev[], get_device(x)))
+    fmap(_get_device, x)
+    return dev[]
+end
+for T in (Number, AbstractRNG, Val)
+    @eval get_device(::$(T)) = nothing
+end
+get_device(x::Tuple) = mapreduce(get_device, __combine_devices, x; init=nothing)
+get_device(x::NamedTuple) = get_device(values(x))
+
 CRC.@non_differentiable get_device(::Any...)
 
+__combine_devices(dev1) = dev1
+function __combine_devices(dev1, dev2)
+    dev1 === nothing && return dev2  # `nothing` means "no device seen on this side"
+    dev2 === nothing && return dev1
+    @argcheck dev1 == dev2  # two different devices cannot be combined
+    return dev1
+end
+function __combine_devices(devs...)  # zero or 3+ devices; `init` keeps empty input valid
+    return foldl(__combine_devices, devs; init=nothing)
+end
+
 # Set the device
 const SET_DEVICE_DOCS = """
 Set the device for the given type. This is a no-op for `LuxCPUDevice`. For `LuxCUDADevice`