From 2825852a68d0e1989a20cc3b74c8bc6e631fadad Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 23 Feb 2024 23:43:00 -0500 Subject: [PATCH 1/7] Add setup for multiGPU setups --- Project.toml | 3 +- ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 15 +++- ext/LuxDeviceUtilsLuxCUDAExt.jl | 15 +++- ext/LuxDeviceUtilsMetalGPUArraysExt.jl | 6 +- ext/LuxDeviceUtilsSparseArraysExt.jl | 9 ++ src/LuxDeviceUtils.jl | 113 ++++++++++++++----------- 6 files changed, 103 insertions(+), 58 deletions(-) create mode 100644 ext/LuxDeviceUtilsSparseArraysExt.jl diff --git a/Project.toml b/Project.toml index da0cab4..8e83cce 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -20,6 +19,7 @@ LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -29,6 +29,7 @@ LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"] LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" +LuxDeviceUtilsSparseArraysExt = "SparseArrays" LuxDeviceUtilsZygoteExt = "Zygote" [compat] diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index ac951f1..f061fcb 100644 --- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -5,8 +5,19 @@ import Adapt: adapt_storage, adapt __init__() = reset_gpu_device!() -LuxDeviceUtils.__is_loaded(::LuxAMDGPUDevice) = true -LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional() +LuxDeviceUtils.__is_loaded(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) = true +function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) + return LuxAMDGPU.functional() +end + +function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, device_id) + id = ifelse(device_id === nothing, 0, device_id) + old_id = AMDGPU.device_id(AMDGPU.device()) - 1 + AMDGPU.device!(AMDGPU.devices()[id + 1]) + device = LuxAMDGPUDevice(AMDGPU.device()) + AMDGPU.device!(AMDGPU.devices()[old_id + 1]) + return device +end # Default RNG LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl index 4edf554..d57fc97 100644 --- a/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -5,8 +5,19 @@ import Adapt: adapt_storage, adapt __init__() = reset_gpu_device!() -LuxDeviceUtils.__is_loaded(::LuxCUDADevice) = true -LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional() +LuxDeviceUtils.__is_loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true +function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) + return LuxCUDA.functional() +end + +function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, device_id) + id = ifelse(device_id === nothing, 0, device_id) + old_id = CUDA.device().handle + CUDA.device!(id) + device = LuxCUDADevice(CUDA.device()) + CUDA.device!(old_id) + return device +end # Default RNG LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() diff --git a/ext/LuxDeviceUtilsMetalGPUArraysExt.jl b/ext/LuxDeviceUtilsMetalGPUArraysExt.jl index 836ab07..8272d6c 100644 --- a/ext/LuxDeviceUtilsMetalGPUArraysExt.jl +++ b/ext/LuxDeviceUtilsMetalGPUArraysExt.jl @@ -5,8 +5,10 @@ import Adapt: adapt_storage, adapt __init__() = reset_gpu_device!() -LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true -LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional() +LuxDeviceUtils.__is_loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true +function LuxDeviceUtils.__is_functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) + return Metal.functional() +end # Default RNG LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) diff --git a/ext/LuxDeviceUtilsSparseArraysExt.jl b/ext/LuxDeviceUtilsSparseArraysExt.jl new file mode 100644 index 0000000..80f5e35 --- /dev/null +++ b/ext/LuxDeviceUtilsSparseArraysExt.jl @@ -0,0 +1,9 @@ +module LuxDeviceUtilsSparseArraysExt + +import Adapt: adapt_storage +import LuxDeviceUtils: LuxCPUAdaptor +import SparseArrays: AbstractSparseArray + +adapt_storage(::LuxCPUAdaptor, x::AbstractSparseArray) = x + +end diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index 04347dc..3cf70bb 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -3,7 +3,7 @@ module LuxDeviceUtils import PrecompileTools: @recompile_invalidations @recompile_invalidations begin - using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays + using ChainRulesCore, Functors, LuxCore, Preferences, Random import Adapt: adapt, adapt_storage import ChainRulesCore as CRC end @@ -17,37 +17,53 @@ export get_device abstract type AbstractLuxDevice <: Function end abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end -__is_functional(::AbstractLuxDevice) = false -__is_loaded(::AbstractLuxDevice) = false +__is_functional(x) = false +__is_loaded(x) = false struct LuxCPUDevice <: AbstractLuxDevice end -struct LuxCUDADevice <: AbstractLuxGPUDevice end -struct LuxAMDGPUDevice <: AbstractLuxGPUDevice end +@kwdef struct LuxCUDADevice{ID} <: AbstractLuxGPUDevice + device_id::ID = nothing +end +@kwdef struct LuxAMDGPUDevice{ID} <: AbstractLuxGPUDevice + device_id::ID = nothing +end struct LuxMetalDevice <: AbstractLuxGPUDevice end -__is_functional(::LuxCPUDevice) = true -__is_loaded(::LuxCPUDevice) = true +_with_device_id(::Type{LuxCPUDevice}, ::Nothing) = LuxCPUDevice() +function _with_device_id(::Type{LuxCPUDevice}, device_id) + @warn "`device_id` is not applicable for `LuxCPUDevice`." maxlog=1 + return LuxCPUDevice() +end + +_with_device_id(::Type{LuxMetalDevice}, ::Nothing) = LuxMetalDevice() +function _with_device_id(::Type{LuxMetalDevice}, device_id) + @warn "`device_id` is not applicable for `LuxMetalDevice`." maxlog=1 + return LuxMetalDevice() +end + +__is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true +__is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true -_get_device_name(::LuxCPUDevice) = "CPU" -_get_device_name(::LuxCUDADevice) = "CUDA" -_get_device_name(::LuxAMDGPUDevice) = "AMDGPU" -_get_device_name(::LuxMetalDevice) = "Metal" +_get_device_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "CPU" +_get_device_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "CUDA" +_get_device_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "AMDGPU" +_get_device_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" -_get_triggerpkg_name(::LuxCPUDevice) = "" -_get_triggerpkg_name(::LuxCUDADevice) = "LuxCUDA" -_get_triggerpkg_name(::LuxAMDGPUDevice) = "LuxAMDGPU" -_get_triggerpkg_name(::LuxMetalDevice) = "Metal" +_get_triggerpkg_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "" +_get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA" +_get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU" +_get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev)) struct LuxDeviceSelectionException <: Exception end -function Base.showerror(io::IO, e::LuxDeviceSelectionException) +function Base.showerror(io::IO, ::LuxDeviceSelectionException) return print(io, "LuxDeviceSelectionException(No functional GPU device found!!)") end # Order is important here -const GPU_DEVICES = (LuxCUDADevice(), LuxAMDGPUDevice(), LuxMetalDevice()) +const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice) const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) @@ -57,27 +73,22 @@ const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) Resets the selected GPU device. This is useful when automatic GPU selection needs to be run again. """ -function reset_gpu_device!() - return GPU_DEVICE[] = nothing -end +reset_gpu_device!() = (GPU_DEVICE[] = nothing) """ supported_gpu_backends() -> Tuple{String, ...} Return a tuple of supported GPU backends. -::: warning - -This is not the list of functional backends on the system, but rather backends which -`Lux.jl` supports. +!!! warning -::: + This is not the list of functional backends on the system, but rather backends which + `Lux.jl` supports. -::: danger +!!! danger -`Metal.jl` support is **extremely** experimental and most things are not expected to work. - -::: + `Metal.jl` support is **extremely** experimental and most things are not expected to + work. """ supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) @@ -95,14 +106,15 @@ Selects GPU device based on the following criteria: invoked. 4. If nothing works, an error is thrown. """ -function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice +function gpu_device(device_id=nothing; force_gpu_usage::Bool=false)::AbstractLuxDevice if GPU_DEVICE[] !== nothing force_gpu_usage && !(GPU_DEVICE[] isa AbstractLuxGPUDevice) && throw(LuxDeviceSelectionException()) return GPU_DEVICE[] end - device = _get_gpu_device(; force_gpu_usage) + device_type = _get_gpu_device(; force_gpu_usage) + device = _with_device_id(device_type, device_id) GPU_DEVICE[] = device return device @@ -116,25 +128,25 @@ function _get_gpu_device(; force_gpu_usage::Bool) allowed_backends = supported_gpu_backends() idx = findfirst(isequal(backend), allowed_backends) if backend ∉ allowed_backends - @warn """ - `gpu_backend` preference is set to $backend, which is not a valid backend. - Valid backends are $allowed_backends. - Defaulting to automatic GPU Backend selection. - """ maxlog=1 + @warn "`gpu_backend` preference is set to $backend, which is not a valid \ + backend. Valid backends are $allowed_backends. Defaulting to automatic \ + GPU Backend selection." maxlog=1 else @debug "Using GPU backend set in preferences: $backend." device = GPU_DEVICES[idx] if !__is_loaded(device) - @warn """Trying to use backend: $(_get_device_name(device)) but the trigger package $(device.pkgid) is not loaded. - Ignoring the Preferences backend!!! - Please load the package and call this function again to respect the Preferences backend.""" maxlog=1 + @warn "Trying to use backend: $(_get_device_name(device)) but the trigger \ + package $(device.pkgid) is not loaded. Ignoring the Preferences \ + backend!!! Please load the package and call this function again to \ + respect the Preferences backend." maxlog=1 else if __is_functional(device) @debug "Using GPU backend: $(_get_device_name(device))." return device else - @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional. - Defaulting to automatic GPU Backend selection." maxlog=1 + @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl \ + is not functional. Defaulting to automatic GPU Backend \ + selection." maxlog=1 end end end @@ -150,7 +162,8 @@ function _get_gpu_device(; force_gpu_usage::Bool) end @debug "GPU backend: $(_get_device_name(device)) is not functional." else - @debug "Trigger package for backend ($(_get_device_name(device))): $(_get_trigger_pkgname(device)) not loaded." + @debug "Trigger package for backend ($(_get_device_name(device))): \ + $(_get_trigger_pkgname(device)) not loaded." end end @@ -164,7 +177,7 @@ function _get_gpu_device(; force_gpu_usage::Bool) a. LuxCUDA.jl for NVIDIA CUDA Support. b. LuxAMDGPU.jl for AMD GPU ROCM Support. c. Metal.jl for Apple Metal GPU Support.""" maxlog=1 - return cpu_device() + return LuxCPUDevice end end @@ -188,7 +201,8 @@ gpu_backend!() = gpu_backend!("") function gpu_backend!(backend::String) if backend == "" @delete_preferences!("gpu_backend") - @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend." + @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the \ + new backend." return end @@ -250,8 +264,8 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) end function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) @warn "Lux layers are stateless and hence don't participate in device \ - transfers. Apply this function on the parameters and states generated \ - using `Lux.setup`." maxlog=1 + transfers. Apply this function on the parameters and states generated \ + using `Lux.setup`." maxlog=1 return NN end end @@ -264,7 +278,7 @@ end Returns the device of the array `x`. Trigger Packages must be loaded for this to return the correct device. """ -get_device(x::AbstractArray) = LuxCPUDevice() +get_device(::AbstractArray) = LuxCPUDevice() # Adapt Interface abstract type AbstractLuxDeviceAdaptor end @@ -274,10 +288,7 @@ struct LuxCUDAAdaptor <: AbstractLuxDeviceAdaptor end struct LuxAMDGPUAdaptor <: AbstractLuxDeviceAdaptor end struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end -function adapt_storage(::LuxCPUAdaptor, - x::Union{AbstractRange, SparseArrays.AbstractSparseArray}) - return x -end +adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x) adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng From 54214977a02015777d7cf993b488312898b84df5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 00:15:17 -0500 Subject: [PATCH 2/7] Map device to adaptor --- ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 8 ++++++-- ext/LuxDeviceUtilsLuxCUDAExt.jl | 8 ++++++-- src/LuxDeviceUtils.jl | 13 +++++++++++-- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index f061fcb..764700d 100644 --- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -10,8 +10,12 @@ function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGP return LuxAMDGPU.functional() end -function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, device_id) - id = ifelse(device_id === nothing, 0, device_id) +LuxDeviceUtils._get_adaptor(::LuxAMDGPUDevice{Nothing}) = LuxAMDGPUAdaptor(AMDGPU.device()) + +function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, ::Nothing) + return LuxAMDGPUDevice(AMDGPU.device()) +end +function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, id) old_id = AMDGPU.device_id(AMDGPU.device()) - 1 AMDGPU.device!(AMDGPU.devices()[id + 1]) device = LuxAMDGPUDevice(AMDGPU.device()) diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl index d57fc97..228fa4e 100644 --- a/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -10,8 +10,12 @@ function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADev return LuxCUDA.functional() end -function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, device_id) - id = ifelse(device_id === nothing, 0, device_id) +LuxDeviceUtils._get_adaptor(::LuxCUDADevice{Nothing}) = LuxCUDAAdaptor(CUDA.device()) + +function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, ::Nothing) + return LuxCUDADevice(CUDA.device()) +end +function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, id) old_id = CUDA.device().handle CUDA.device!(id) device = LuxCUDADevice(CUDA.device()) diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index 3cf70bb..5c6b7a6 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -41,6 +41,11 @@ function _with_device_id(::Type{LuxMetalDevice}, device_id) return LuxMetalDevice() end +_get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor() +_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device_id) +_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device_id) +_get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor() + __is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true __is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true @@ -284,8 +289,12 @@ get_device(::AbstractArray) = LuxCPUDevice() abstract type AbstractLuxDeviceAdaptor end struct LuxCPUAdaptor <: AbstractLuxDeviceAdaptor end -struct LuxCUDAAdaptor <: AbstractLuxDeviceAdaptor end -struct LuxAMDGPUAdaptor <: AbstractLuxDeviceAdaptor end +struct LuxCUDAAdaptor{ID} <: AbstractLuxDeviceAdaptor + device_id::ID +end +struct LuxAMDGPUAdaptor{ID} <: AbstractLuxDeviceAdaptor + device_id::ID +end struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x From 8e4f924df3394121cfa498be4e94c0aef903ffbf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 00:36:20 -0500 Subject: [PATCH 3/7] write the adaptor code --- ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 30 ++++++++++++++++------ ext/LuxDeviceUtilsLuxCUDAExt.jl | 30 ++++++++++++++++------ src/LuxDeviceUtils.jl | 41 ++++++++++++++++--------------- 3 files changed, 65 insertions(+), 36 deletions(-) diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 764700d..1a4a8fc 100644 --- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -10,16 +10,14 @@ function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGP return LuxAMDGPU.functional() end -LuxDeviceUtils._get_adaptor(::LuxAMDGPUDevice{Nothing}) = LuxAMDGPUAdaptor(AMDGPU.device()) - -function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, ::Nothing) - return LuxAMDGPUDevice(AMDGPU.device()) +function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) + return LuxAMDGPUDevice(nothing) end -function LuxDeviceUtils._with_device_id(::Type{LuxAMDGPUDevice}, id) - old_id = AMDGPU.device_id(AMDGPU.device()) - 1 +function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id) + old_dev = AMDGPU.device() AMDGPU.device!(AMDGPU.devices()[id + 1]) device = LuxAMDGPUDevice(AMDGPU.device()) - AMDGPU.device!(AMDGPU.devices()[old_id + 1]) + AMDGPU.device!(old_dev) return device end @@ -31,7 +29,23 @@ LuxDeviceUtils.get_device(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice() # Device Transfer ## To GPU -adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x) +adapt_storage(::LuxAMDGPUAdaptor{Nothing}, x) = roc(x) +function adapt_storage(to::LuxAMDGPUAdaptor, x) + old_dev = AMDGPU.device() # remember the current device + if !(x isa AMDGPU.AnyROCArray) + AMDGPU.device!(to.device) + x_new = roc(x) + AMDGPU.device!(old_dev) + return x_new + elseif AMDGPU.device_id(AMDGPU.device(x)) == AMDGPU.device_id(to.device) + return x + else + AMDGPU.device!(to.device) + x_new = copy(x) + AMDGPU.device!(old_dev) + return x_new + end +end adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl index 228fa4e..737bdf1 100644 --- a/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -10,16 +10,14 @@ function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADev return LuxCUDA.functional() end -LuxDeviceUtils._get_adaptor(::LuxCUDADevice{Nothing}) = LuxCUDAAdaptor(CUDA.device()) - -function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, ::Nothing) - return LuxCUDADevice(CUDA.device()) +function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, ::Nothing) + return LuxCUDADevice(nothing) end -function LuxDeviceUtils._with_device_id(::Type{LuxCUDADevice}, id) - old_id = CUDA.device().handle +function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id) + old_dev = CUDA.device() CUDA.device!(id) device = LuxCUDADevice(CUDA.device()) - CUDA.device!(old_id) + CUDA.device!(old_dev) return device end @@ -31,7 +29,23 @@ LuxDeviceUtils.get_device(::CUDA.AnyCuArray) = LuxCUDADevice() # Device Transfer ## To GPU -adapt_storage(::LuxCUDAAdaptor, x) = cu(x) +adapt_storage(::LuxCUDAAdaptor{Nothing}, x) = cu(x) +function adapt_storage(to::LuxCUDAAdaptor, x) + old_dev = CUDA.device() # remember the current device + if !(x isa CUDA.AnyCuArray) + CUDA.device!(to.device) + x_new = cu(x) + CUDA.device!(old_dev) + return x_new + elseif CUDA.device(x).handle == to.device.handle + return x + else + CUDA.device!(to.device) + x_new = copy(x) + CUDA.device!(old_dev) + return x_new + end +end adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index 5c6b7a6..12ab7f5 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -21,29 +21,29 @@ __is_functional(x) = false __is_loaded(x) = false struct LuxCPUDevice <: AbstractLuxDevice end -@kwdef struct LuxCUDADevice{ID} <: AbstractLuxGPUDevice - device_id::ID = nothing +@kwdef struct LuxCUDADevice{D} <: AbstractLuxGPUDevice + device::D = nothing end -@kwdef struct LuxAMDGPUDevice{ID} <: AbstractLuxGPUDevice - device_id::ID = nothing +@kwdef struct LuxAMDGPUDevice{D} <: AbstractLuxGPUDevice + device::D = nothing end struct LuxMetalDevice <: AbstractLuxGPUDevice end -_with_device_id(::Type{LuxCPUDevice}, ::Nothing) = LuxCPUDevice() -function _with_device_id(::Type{LuxCPUDevice}, device_id) +_with_device(::Type{LuxCPUDevice}, ::Nothing) = LuxCPUDevice() +function _with_device(::Type{LuxCPUDevice}, device_id) @warn "`device_id` is not applicable for `LuxCPUDevice`." maxlog=1 return LuxCPUDevice() end -_with_device_id(::Type{LuxMetalDevice}, ::Nothing) = LuxMetalDevice() -function _with_device_id(::Type{LuxMetalDevice}, device_id) +_with_device(::Type{LuxMetalDevice}, ::Nothing) = LuxMetalDevice() +function _with_device(::Type{LuxMetalDevice}, device_id) @warn "`device_id` is not applicable for `LuxMetalDevice`." maxlog=1 return LuxMetalDevice() end _get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor() -_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device_id) -_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device_id) +_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device) +_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device) _get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor() __is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true @@ -119,7 +119,7 @@ function gpu_device(device_id=nothing; force_gpu_usage::Bool=false)::AbstractLux end device_type = _get_gpu_device(; force_gpu_usage) - device = _with_device_id(device_type, device_id) + device = _with_device(device_type, device_id) GPU_DEVICE[] = device return device @@ -255,17 +255,18 @@ default_device_rng(::LuxCPUDevice) = Random.default_rng() # For Lux, typically models only has these 3 datastructures so we should be mostly fine. for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) ldev = Symbol("Lux$(dev)Device") - ladaptor = Symbol("Lux$(dev)Adaptor") @eval begin function (D::$(ldev))(x::AbstractArray) - fn = Base.Fix1(adapt, $(ladaptor)()) + ladaptor = _get_adaptor(D) + fn = Base.Fix1(adapt, ladaptor) return _isbitsarray(x) ? fn(x) : map(D, x) end (D::$(ldev))(x::Tuple) = map(D, x) (D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x))) - function (::$(ldev))(x) - _isleaf(x) && return adapt($(ladaptor)(), x) - return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf) + function (D::$(ldev))(x) + ladaptor = _get_adaptor(D) + _isleaf(x) && return adapt(ladaptor, x) + return fmap(Base.Fix1(adapt, ladaptor), x; exclude=_isleaf) end function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) @warn "Lux layers are stateless and hence don't participate in device \ @@ -289,11 +290,11 @@ get_device(::AbstractArray) = LuxCPUDevice() abstract type AbstractLuxDeviceAdaptor end struct LuxCPUAdaptor <: AbstractLuxDeviceAdaptor end -struct LuxCUDAAdaptor{ID} <: AbstractLuxDeviceAdaptor - device_id::ID +struct LuxCUDAAdaptor{D} <: AbstractLuxDeviceAdaptor + device::D end -struct LuxAMDGPUAdaptor{ID} <: AbstractLuxDeviceAdaptor - device_id::ID +struct LuxAMDGPUAdaptor{D} <: AbstractLuxDeviceAdaptor + device::D end struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end From 2aef0903e3b988eeaf39345f30bbff046b8942cc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 12:25:40 -0500 Subject: [PATCH 4/7] reselect gpu if id changed --- ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 8 +++-- ext/LuxDeviceUtilsLuxCUDAExt.jl | 8 +++-- src/LuxDeviceUtils.jl | 54 +++++++++++++++++++++++++------ 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 1a4a8fc..0a8ea7d 100644 --- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -13,14 +13,18 @@ end function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) return LuxAMDGPUDevice(nothing) end -function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id) +function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Int) + id > length(AMDGPU.devices()) && + throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) old_dev = AMDGPU.device() - AMDGPU.device!(AMDGPU.devices()[id + 1]) + AMDGPU.device!(AMDGPU.devices()[id]) device = LuxAMDGPUDevice(AMDGPU.device()) AMDGPU.device!(old_dev) return device end +LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.device) + # Default RNG LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl index 737bdf1..49a1e0b 100644 --- a/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -13,14 +13,18 @@ end function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, ::Nothing) return LuxCUDADevice(nothing) end -function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id) +function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Int) + id > length(CUDA.devices()) && + throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) old_dev = CUDA.device() - CUDA.device!(id) + CUDA.device!(id - 1) device = LuxCUDADevice(CUDA.device()) CUDA.device!(old_dev) return device end +LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + 1 + # Default RNG LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index 12ab7f5..07397b7 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -41,11 +41,6 @@ function _with_device(::Type{LuxMetalDevice}, device_id) return LuxMetalDevice() end -_get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor() -_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device) -_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device) -_get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor() - __is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true __is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true @@ -59,6 +54,16 @@ _get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA" _get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU" _get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" +_get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor() +_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device) +_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device) +_get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor() + +_get_device_id(::LuxCPUDevice) = nothing +_get_device_id(::LuxCUDADevice{Nothing}) = nothing +_get_device_id(::LuxAMDGPUDevice{Nothing}) = nothing +_get_device_id(::LuxMetalDevice) = nothing + Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev)) struct LuxDeviceSelectionException <: Exception end @@ -98,7 +103,8 @@ Return a tuple of supported GPU backends. supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) """ - gpu_device(; force_gpu_usage::Bool=false) -> AbstractLuxDevice() + gpu_device(device_id::Union{Nothing, Int}=nothing; + force_gpu_usage::Bool=false) -> AbstractLuxDevice() Selects GPU device based on the following criteria: @@ -110,12 +116,40 @@ Selects GPU device based on the following criteria: 3. If no GPU device is functional and `force_gpu_usage` is `false`, then `cpu_device()` is invoked. 4. If nothing works, an error is thrown. + +## Arguments + + - `device_id::Union{Nothing, Int}`: The device id to select. If `nothing`, then we return + the last selected device or if none was selected then we run the autoselection and + choose the current device using `CUDA.device()` or `AMDGPU.device()` or similar. If + `Int`, then we select the device with the given id. Note that this is `1`-indexed, in + contrast to the `0`-indexed `CUDA.jl`. For example, `id = 4` corresponds to + `CUDA.device!(3)`. + +!!! warning + + `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal` and `CPU` + backends, `device_id` is ignored and a warning is printed. + +## Keyword Arguments + + - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU + device is found. """ -function gpu_device(device_id=nothing; force_gpu_usage::Bool=false)::AbstractLuxDevice +function gpu_device(device_id::Union{Nothing, Int}=nothing; + force_gpu_usage::Bool=false)::AbstractLuxDevice + device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) + if GPU_DEVICE[] !== nothing - force_gpu_usage && !(GPU_DEVICE[] isa AbstractLuxGPUDevice) && - throw(LuxDeviceSelectionException()) - return GPU_DEVICE[] + dev = GPU_DEVICE[] + if device_id === nothing + force_gpu_usage && !(dev isa AbstractLuxGPUDevice) && + throw(LuxDeviceSelectionException()) + return dev + else + selected_device_id = _get_device_id(dev) + selected_device_id !== nothing && selected_device_id == device_id && return dev + end end device_type = _get_gpu_device(; force_gpu_usage) From d33b2e30c9e7cb8f62833468b8320fcf281173a9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 12:47:27 -0500 Subject: [PATCH 5/7] Add tests --- Project.toml | 2 +- test/amdgpu.jl | 23 +++++++++++++++++++++++ test/cuda.jl | 23 +++++++++++++++++++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8e83cce..f78a118 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.15" +version = "0.1.16" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/test/amdgpu.jl b/test/amdgpu.jl index 68e8db0..3675a0e 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -72,3 +72,26 @@ using FillArrays, Zygote # Extensions @test ps_cpu.farray isa Fill end end + +if LuxAMDGPU.functional() + ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(AMDGPU.devices()) + amdgpu_device = gpu_device(idx) + @test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice + @test AMDGPU.device_id(amdgpu_device.device) == idx + + ps = ps |> amdgpu_device + @test ps.weight isa ROCArray + @test ps.bias isa ROCArray + @test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx + @test AMDGPU.device_id(AMDGPU.device(ps.bias)) == idx + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end + + ps = ps |> cdev + @test ps.weight isa Array + @test ps.bias isa Array +end diff --git a/test/cuda.jl b/test/cuda.jl index 613f132..9a7c2c3 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -72,3 +72,26 @@ using FillArrays, Zygote # Extensions @test ps_cpu.farray isa Fill end end + +if LuxCUDA.functional() + ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(CUDA.devices()) + cuda_device = gpu_device(idx) + @test typeof(cuda_device.device) <: CUDA.CuDevice + @test cuda_device.device.handle == (idx - 1) + + ps = ps |> cuda_device + @test ps.weight isa CuArray + @test ps.bias isa CuArray + @test CUDA.device(ps.weight).handle == idx - 1 + @test CUDA.device(ps.bias).handle == idx - 1 + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end + + ps = ps |> cdev + @test ps.weight isa Array + @test ps.bias isa Array +end From 5effa9ca836f13521748e172f238521e854fe454 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 14:25:02 -0500 Subject: [PATCH 6/7] Fix ambiguity problems --- ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 2 ++ ext/LuxDeviceUtilsLuxCUDAExt.jl | 2 ++ test/amdgpu.jl | 2 +- test/cuda.jl | 2 +- 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 0a8ea7d..be83184 100644 --- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -50,7 +50,9 @@ function adapt_storage(to::LuxAMDGPUAdaptor, x) return x_new end end +adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng +adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl index 49a1e0b..09cfaac 100644 --- a/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -50,7 +50,9 @@ function adapt_storage(to::LuxCUDAAdaptor, x) return x_new end end +adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::AbstractRNG) = rng adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng +adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::Random.TaskLocalRNG) = CUDA.default_rng() adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() diff --git a/test/amdgpu.jl b/test/amdgpu.jl index 3675a0e..9247fdb 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -82,7 +82,7 @@ if LuxAMDGPU.functional() @test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice @test AMDGPU.device_id(amdgpu_device.device) == idx - ps = ps |> amdgpu_device + global ps = ps |> amdgpu_device @test ps.weight isa ROCArray @test ps.bias isa ROCArray @test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx diff --git a/test/cuda.jl b/test/cuda.jl index 9a7c2c3..e0dc343 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -82,7 +82,7 @@ if LuxCUDA.functional() @test typeof(cuda_device.device) <: CUDA.CuDevice @test cuda_device.device.handle == (idx - 1) - ps = ps |> cuda_device + global ps = ps |> cuda_device @test ps.weight isa CuArray @test ps.bias isa CuArray @test CUDA.device(ps.weight).handle == idx - 1 From 8a9985c7fb3d8bddd5c3ce952984b5a43f2dd8bc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 25 Feb 2024 15:15:03 -0500 Subject: [PATCH 7/7] Fix get_device for multi-gpu --- ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 2 +- ext/LuxDeviceUtilsLuxCUDAExt.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index be83184..c13e3df 100644 --- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -29,7 +29,7 @@ LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.devic LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() # Query Device from Array -LuxDeviceUtils.get_device(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice() +LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x)) # Device Transfer ## To GPU diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl index 09cfaac..56cb1eb 100644 --- a/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -29,7 +29,7 @@ LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() # Query Device from Array -LuxDeviceUtils.get_device(::CUDA.AnyCuArray) = LuxCUDADevice() +LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x)) # Device Transfer ## To GPU