diff --git a/Project.toml b/Project.toml
index da0cab4..f78a118 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxDeviceUtils"
 uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
 authors = ["Avik Pal and contributors"]
-version = "0.1.15"
+version = "0.1.16"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -11,7 +11,6 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 
 [weakdeps]
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -20,6 +19,7 @@ LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [extensions]
@@ -29,6 +29,7 @@ LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU"
 LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
 LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"]
 LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools"
+LuxDeviceUtilsSparseArraysExt = "SparseArrays"
 LuxDeviceUtilsZygoteExt = "Zygote"
 
 [compat]
diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl
index ac951f1..c13e3df 100644
--- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl
+++ b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl
@@ -5,19 +5,54 @@ import Adapt: adapt_storage, adapt
 
 __init__() = reset_gpu_device!()
 
-LuxDeviceUtils.__is_loaded(::LuxAMDGPUDevice) = true
-LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional()
+LuxDeviceUtils.__is_loaded(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = true
+function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}})
+    return LuxAMDGPU.functional()
+end
+
+function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing)
+    return LuxAMDGPUDevice(nothing)
+end
+function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Int)
+    id > length(AMDGPU.devices()) &&
+        throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))"))
+    old_dev = AMDGPU.device()
+    AMDGPU.device!(AMDGPU.devices()[id])
+    device = LuxAMDGPUDevice(AMDGPU.device())
+    AMDGPU.device!(old_dev)
+    return device
+end
+
+LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.device)
 
 # Default RNG
 LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng()
 
 # Query Device from Array
-LuxDeviceUtils.get_device(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice()
+LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x))
 
 # Device Transfer
 ## To GPU
-adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x)
+adapt_storage(::LuxAMDGPUAdaptor{Nothing}, x) = roc(x)
+function adapt_storage(to::LuxAMDGPUAdaptor, x)
+    old_dev = AMDGPU.device()  # remember the current device
+    if !(x isa AMDGPU.AnyROCArray)
+        AMDGPU.device!(to.device)
+        x_new = roc(x)
+        AMDGPU.device!(old_dev)
+        return x_new
+    elseif AMDGPU.device_id(AMDGPU.device(x)) == AMDGPU.device_id(to.device)
+        return x
+    else
+        AMDGPU.device!(to.device)
+        x_new = copy(x)
+        AMDGPU.device!(old_dev)
+        return x_new
+    end
+end
+adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng
 adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng
+adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng()
 adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng()
 
 adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng()
diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl
index 4edf554..56cb1eb 100644
--- a/ext/LuxDeviceUtilsLuxCUDAExt.jl
+++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl
@@ -5,19 +5,54 @@ import Adapt: adapt_storage, adapt
 
 __init__() = reset_gpu_device!()
 
-LuxDeviceUtils.__is_loaded(::LuxCUDADevice) = true
-LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional()
+LuxDeviceUtils.__is_loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true
+function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}})
+    return LuxCUDA.functional()
+end
+
+function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, ::Nothing)
+    return LuxCUDADevice(nothing)
+end
+function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Int)
+    id > length(CUDA.devices()) &&
+        throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))"))
+    old_dev = CUDA.device()
+    CUDA.device!(id - 1)
+    device = LuxCUDADevice(CUDA.device())
+    CUDA.device!(old_dev)
+    return device
+end
+
+LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + 1
 
 # Default RNG
 LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng()
 
 # Query Device from Array
-LuxDeviceUtils.get_device(::CUDA.AnyCuArray) = LuxCUDADevice()
+LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x))
 
 # Device Transfer
 ## To GPU
-adapt_storage(::LuxCUDAAdaptor, x) = cu(x)
+adapt_storage(::LuxCUDAAdaptor{Nothing}, x) = cu(x)
+function adapt_storage(to::LuxCUDAAdaptor, x)
+    old_dev = CUDA.device()  # remember the current device
+    if !(x isa CUDA.AnyCuArray)
+        CUDA.device!(to.device)
+        x_new = cu(x)
+        CUDA.device!(old_dev)
+        return x_new
+    elseif CUDA.device(x).handle == to.device.handle
+        return x
+    else
+        CUDA.device!(to.device)
+        x_new = copy(x)
+        CUDA.device!(old_dev)
+        return x_new
+    end
+end
+adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::AbstractRNG) = rng
 adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng
+adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::Random.TaskLocalRNG) = CUDA.default_rng()
 adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng()
 
 adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng()
diff --git a/ext/LuxDeviceUtilsMetalGPUArraysExt.jl b/ext/LuxDeviceUtilsMetalGPUArraysExt.jl
index 836ab07..8272d6c 100644
--- a/ext/LuxDeviceUtilsMetalGPUArraysExt.jl
+++ b/ext/LuxDeviceUtilsMetalGPUArraysExt.jl
@@ -5,8 +5,10 @@ import Adapt: adapt_storage, adapt
 
 __init__() = reset_gpu_device!()
 
-LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true
-LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional()
+LuxDeviceUtils.__is_loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true
+function LuxDeviceUtils.__is_functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}})
+    return Metal.functional()
+end
 
 # Default RNG
 LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray)
diff --git a/ext/LuxDeviceUtilsSparseArraysExt.jl b/ext/LuxDeviceUtilsSparseArraysExt.jl
new file mode 100644
index 0000000..80f5e35
--- /dev/null
+++ b/ext/LuxDeviceUtilsSparseArraysExt.jl
@@ -0,0 +1,9 @@
+module LuxDeviceUtilsSparseArraysExt
+
+import Adapt: adapt_storage
+import LuxDeviceUtils: LuxCPUAdaptor
+import SparseArrays: AbstractSparseArray
+
+adapt_storage(::LuxCPUAdaptor, x::AbstractSparseArray) = x
+
+end
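Both GPU extensions above implement the same three-way rule in `adapt_storage`: host data is copied onto the adaptor's device, data already resident on that device is returned untouched, and data on another device is copied across, with the previously active device restored afterwards. A minimal usage sketch of the resulting behavior (hedged: it assumes a machine with at least two functional CUDA devices; on fewer devices `gpu_device(2)` would error):

```julia
using LuxDeviceUtils, LuxCUDA

dev2 = gpu_device(2)        # 1-indexed, so this wraps CUDA.CuDevice(1)
x = rand(Float32, 10)       # host array
x2 = dev2(x)                # copied onto device 2
get_device(x2)              # LuxCUDADevice wrapping CuDevice(1)
dev2(x2) === x2             # already on device 2, so returned as-is
```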
diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl
index 04347dc..07397b7 100644
--- a/src/LuxDeviceUtils.jl
+++ b/src/LuxDeviceUtils.jl
@@ -3,7 +3,7 @@ module LuxDeviceUtils
 import PrecompileTools: @recompile_invalidations
 
 @recompile_invalidations begin
-    using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays
+    using ChainRulesCore, Functors, LuxCore, Preferences, Random
     import Adapt: adapt, adapt_storage
     import ChainRulesCore as CRC
 end
@@ -17,37 +17,63 @@ export get_device
 abstract type AbstractLuxDevice <: Function end
 abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end
 
-__is_functional(::AbstractLuxDevice) = false
-__is_loaded(::AbstractLuxDevice) = false
+__is_functional(x) = false
+__is_loaded(x) = false
 
 struct LuxCPUDevice <: AbstractLuxDevice end
-struct LuxCUDADevice <: AbstractLuxGPUDevice end
-struct LuxAMDGPUDevice <: AbstractLuxGPUDevice end
+@kwdef struct LuxCUDADevice{D} <: AbstractLuxGPUDevice
+    device::D = nothing
+end
+@kwdef struct LuxAMDGPUDevice{D} <: AbstractLuxGPUDevice
+    device::D = nothing
+end
 struct LuxMetalDevice <: AbstractLuxGPUDevice end
 
-__is_functional(::LuxCPUDevice) = true
-__is_loaded(::LuxCPUDevice) = true
+_with_device(::Type{LuxCPUDevice}, ::Nothing) = LuxCPUDevice()
+function _with_device(::Type{LuxCPUDevice}, device_id)
+    @warn "`device_id` is not applicable for `LuxCPUDevice`." maxlog=1
+    return LuxCPUDevice()
+end
+
+_with_device(::Type{LuxMetalDevice}, ::Nothing) = LuxMetalDevice()
+function _with_device(::Type{LuxMetalDevice}, device_id)
+    @warn "`device_id` is not applicable for `LuxMetalDevice`." maxlog=1
+    return LuxMetalDevice()
+end
+
+__is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true
+__is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true
 
-_get_device_name(::LuxCPUDevice) = "CPU"
-_get_device_name(::LuxCUDADevice) = "CUDA"
-_get_device_name(::LuxAMDGPUDevice) = "AMDGPU"
-_get_device_name(::LuxMetalDevice) = "Metal"
+_get_device_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "CPU"
+_get_device_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "CUDA"
+_get_device_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "AMDGPU"
+_get_device_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal"
 
-_get_triggerpkg_name(::LuxCPUDevice) = ""
-_get_triggerpkg_name(::LuxCUDADevice) = "LuxCUDA"
-_get_triggerpkg_name(::LuxAMDGPUDevice) = "LuxAMDGPU"
-_get_triggerpkg_name(::LuxMetalDevice) = "Metal"
+_get_triggerpkg_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = ""
+_get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA"
+_get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU"
+_get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal"
+
+_get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor()
+_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device)
+_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device)
+_get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor()
+
+_get_device_id(::LuxCPUDevice) = nothing
+_get_device_id(::LuxCUDADevice{Nothing}) = nothing
+_get_device_id(::LuxAMDGPUDevice{Nothing}) = nothing
+_get_device_id(::LuxMetalDevice) = nothing
 
 Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev))
 
 struct LuxDeviceSelectionException <: Exception end
 
-function Base.showerror(io::IO, e::LuxDeviceSelectionException)
+function Base.showerror(io::IO, ::LuxDeviceSelectionException)
     return print(io, "LuxDeviceSelectionException(No functional GPU device found!!)")
 end
 
 # Order is important here
-const GPU_DEVICES = (LuxCUDADevice(), LuxAMDGPUDevice(), LuxMetalDevice())
+const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice)
 
 const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing)
@@ -57,32 +83,28 @@ const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing)
 
 """
     reset_gpu_device!()
 
 Resets the selected GPU device. This is useful when automatic GPU selection needs to be
 run again.
 """
-function reset_gpu_device!()
-    return GPU_DEVICE[] = nothing
-end
+reset_gpu_device!() = (GPU_DEVICE[] = nothing)
 
 """
     supported_gpu_backends() -> Tuple{String, ...}
 
 Return a tuple of supported GPU backends.
 
-::: warning
+!!! warning
 
-This is not the list of functional backends on the system, but rather backends which
-`Lux.jl` supports.
+    This is not the list of functional backends on the system, but rather backends which
+    `Lux.jl` supports.
 
-:::
+!!! danger
 
-::: danger
-
-`Metal.jl` support is **extremely** experimental and most things are not expected to work.
-
-:::
+    `Metal.jl` support is **extremely** experimental and most things are not expected to
+    work.
 """
 supported_gpu_backends() = map(_get_device_name, GPU_DEVICES)
 
 """
-    gpu_device(; force_gpu_usage::Bool=false) -> AbstractLuxDevice()
+    gpu_device(device_id::Union{Nothing, Int}=nothing;
+        force_gpu_usage::Bool=false) -> AbstractLuxDevice()
 
 Selects GPU device based on the following criteria:
 
@@ -94,15 +116,44 @@ Selects GPU device based on the following criteria:
 3. If no GPU device is functional and `force_gpu_usage` is `false`, then `cpu_device()` is
    invoked.
 4. If nothing works, an error is thrown.
+
+## Arguments
+
+  - `device_id::Union{Nothing, Int}`: The device id to select. If `nothing`, return the
+    last selected device; if none has been selected yet, run the automatic selection and
+    choose the current device via `CUDA.device()`, `AMDGPU.device()`, or similar. If an
+    `Int`, select the device with that id. Note that this is `1`-indexed, in contrast to
+    the `0`-indexed `CUDA.jl`. For example, `id = 4` corresponds to `CUDA.device!(3)`.
+
+!!! warning
+
+    `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal` and `CPU`
+    backends, `device_id` is ignored and a warning is printed.
+
+## Keyword Arguments
+
+  - `force_gpu_usage::Bool`: If `true`, an error is thrown when no functional GPU device
+    is found.
 """
-function gpu_device(; force_gpu_usage::Bool=false)::AbstractLuxDevice
+function gpu_device(device_id::Union{Nothing, Int}=nothing;
+        force_gpu_usage::Bool=false)::AbstractLuxDevice
+    device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed."))
+
     if GPU_DEVICE[] !== nothing
-        force_gpu_usage && !(GPU_DEVICE[] isa AbstractLuxGPUDevice) &&
-            throw(LuxDeviceSelectionException())
-        return GPU_DEVICE[]
+        dev = GPU_DEVICE[]
+        if device_id === nothing
+            force_gpu_usage && !(dev isa AbstractLuxGPUDevice) &&
+                throw(LuxDeviceSelectionException())
+            return dev
+        else
+            selected_device_id = _get_device_id(dev)
+            selected_device_id !== nothing && selected_device_id == device_id && return dev
+        end
     end
 
-    device = _get_gpu_device(; force_gpu_usage)
+    device_type = _get_gpu_device(; force_gpu_usage)
+    device = _with_device(device_type, device_id)
     GPU_DEVICE[] = device
 
     return device
@@ -116,25 +167,25 @@ function _get_gpu_device(; force_gpu_usage::Bool)
         allowed_backends = supported_gpu_backends()
         idx = findfirst(isequal(backend), allowed_backends)
         if backend ∉ allowed_backends
-            @warn """
-            `gpu_backend` preference is set to $backend, which is not a valid backend.
-            Valid backends are $allowed_backends.
-            Defaulting to automatic GPU Backend selection.
-            """ maxlog=1
+            @warn "`gpu_backend` preference is set to $backend, which is not a valid \
+                   backend. Valid backends are $allowed_backends. Defaulting to automatic \
+                   GPU Backend selection." maxlog=1
        else
            @debug "Using GPU backend set in preferences: $backend."
            device = GPU_DEVICES[idx]
            if !__is_loaded(device)
-                @warn """Trying to use backend: $(_get_device_name(device)) but the trigger package $(device.pkgid) is not loaded.
-                Ignoring the Preferences backend!!!
-                Please load the package and call this function again to respect the Preferences backend.""" maxlog=1
+                @warn "Trying to use backend: $(_get_device_name(device)) but the trigger \
+                       package $(_get_triggerpkg_name(device)) is not loaded. Ignoring \
+                       the Preferences backend!!! Please load the package and call this \
+                       function again to respect the Preferences backend." maxlog=1
            else
                if __is_functional(device)
                    @debug "Using GPU backend: $(_get_device_name(device))."
                    return device
                else
-                    @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional.
-                    Defaulting to automatic GPU Backend selection." maxlog=1
+                    @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl \
+                           is not functional. Defaulting to automatic GPU Backend \
+                           selection." maxlog=1
                end
            end
        end
@@ -150,7 +201,8 @@ function _get_gpu_device(; force_gpu_usage::Bool)
         end
         @debug "GPU backend: $(_get_device_name(device)) is not functional."
     else
-        @debug "Trigger package for backend ($(_get_device_name(device))): $(_get_trigger_pkgname(device)) not loaded."
+        @debug "Trigger package for backend ($(_get_device_name(device))): \
+                $(_get_triggerpkg_name(device)) not loaded."
     end
 end
@@ -164,7 +216,7 @@ function _get_gpu_device(; force_gpu_usage::Bool)
            a. LuxCUDA.jl for NVIDIA CUDA Support.
            b. LuxAMDGPU.jl for AMD GPU ROCM Support.
            c. Metal.jl for Apple Metal GPU Support.""" maxlog=1
-        return cpu_device()
+        return LuxCPUDevice
     end
 end
@@ -188,7 +240,8 @@ gpu_backend!() = gpu_backend!("")
 function gpu_backend!(backend::String)
     if backend == ""
         @delete_preferences!("gpu_backend")
-        @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend."
+        @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the \
+               new backend."
         return
     end
@@ -236,22 +289,23 @@ default_device_rng(::LuxCPUDevice) = Random.default_rng()
 # For Lux, typically models only has these 3 datastructures so we should be mostly fine.
 for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal)
     ldev = Symbol("Lux$(dev)Device")
-    ladaptor = Symbol("Lux$(dev)Adaptor")
     @eval begin
         function (D::$(ldev))(x::AbstractArray)
-            fn = Base.Fix1(adapt, $(ladaptor)())
+            ladaptor = _get_adaptor(D)
+            fn = Base.Fix1(adapt, ladaptor)
             return _isbitsarray(x) ? fn(x) : map(D, x)
         end
         (D::$(ldev))(x::Tuple) = map(D, x)
         (D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x)))
-        function (::$(ldev))(x)
-            _isleaf(x) && return adapt($(ladaptor)(), x)
-            return fmap(Base.Fix1(adapt, $(ladaptor)()), x; exclude=_isleaf)
+        function (D::$(ldev))(x)
+            ladaptor = _get_adaptor(D)
+            _isleaf(x) && return adapt(ladaptor, x)
+            return fmap(Base.Fix1(adapt, ladaptor), x; exclude=_isleaf)
         end
         function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer)
             @warn "Lux layers are stateless and hence don't participate in device \
-                transfers. Apply this function on the parameters and states generated \
-                using `Lux.setup`." maxlog=1
+                   transfers. Apply this function on the parameters and states generated \
+                   using `Lux.setup`." maxlog=1
             return NN
         end
     end
@@ -264,20 +318,21 @@ end
 
 Returns the device of the array `x`. Trigger Packages must be loaded for this to return the
 correct device.
 """
-get_device(x::AbstractArray) = LuxCPUDevice()
+get_device(::AbstractArray) = LuxCPUDevice()
 
 # Adapt Interface
 abstract type AbstractLuxDeviceAdaptor end
 
 struct LuxCPUAdaptor <: AbstractLuxDeviceAdaptor end
-struct LuxCUDAAdaptor <: AbstractLuxDeviceAdaptor end
-struct LuxAMDGPUAdaptor <: AbstractLuxDeviceAdaptor end
+struct LuxCUDAAdaptor{D} <: AbstractLuxDeviceAdaptor
+    device::D
+end
+struct LuxAMDGPUAdaptor{D} <: AbstractLuxDeviceAdaptor
+    device::D
+end
 struct LuxMetalAdaptor <: AbstractLuxDeviceAdaptor end
 
-function adapt_storage(::LuxCPUAdaptor,
-        x::Union{AbstractRange, SparseArrays.AbstractSparseArray})
-    return x
-end
+adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x
 adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x)
 adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng
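Note that `GPU_DEVICE` caches the last selection, so repeated `gpu_device()` calls are cheap; passing an explicit `device_id` or calling `reset_gpu_device!()` is how a different device gets picked later. A small sketch of that flow (hedged: it assumes a functional CUDA setup, and CPU fallback otherwise):

```julia
using LuxDeviceUtils, LuxCUDA

dev = gpu_device()       # autoselects a backend and caches it in GPU_DEVICE
dev === gpu_device()     # the cached device is returned on subsequent calls
dev1 = gpu_device(1)     # reuses the cache only if the cached device id is also 1
reset_gpu_device!()      # clear the cache so autoselection runs again
```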
diff --git a/test/amdgpu.jl b/test/amdgpu.jl
index 68e8db0..9247fdb 100644
--- a/test/amdgpu.jl
+++ b/test/amdgpu.jl
@@ -72,3 +72,26 @@ using FillArrays, Zygote  # Extensions
         @test ps_cpu.farray isa Fill
     end
 end
+
+if LuxAMDGPU.functional()
+    ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10))
+    ps_cpu = deepcopy(ps)
+    cdev = cpu_device()
+    for idx in 1:length(AMDGPU.devices())
+        amdgpu_device = gpu_device(idx)
+        @test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice
+        @test AMDGPU.device_id(amdgpu_device.device) == idx
+
+        global ps = ps |> amdgpu_device
+        @test ps.weight isa ROCArray
+        @test ps.bias isa ROCArray
+        @test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx
+        @test AMDGPU.device_id(AMDGPU.device(ps.bias)) == idx
+        @test isequal(cdev(ps.weight), ps_cpu.weight)
+        @test isequal(cdev(ps.bias), ps_cpu.bias)
+    end
+
+    ps = ps |> cdev
+    @test ps.weight isa Array
+    @test ps.bias isa Array
+end
diff --git a/test/cuda.jl b/test/cuda.jl
index 613f132..e0dc343 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -72,3 +72,26 @@ using FillArrays, Zygote  # Extensions
         @test ps_cpu.farray isa Fill
     end
 end
+
+if LuxCUDA.functional()
+    ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10))
+    ps_cpu = deepcopy(ps)
+    cdev = cpu_device()
+    for idx in 1:length(CUDA.devices())
+        cuda_device = gpu_device(idx)
+        @test typeof(cuda_device.device) <: CUDA.CuDevice
+        @test cuda_device.device.handle == (idx - 1)
+
+        global ps = ps |> cuda_device
+        @test ps.weight isa CuArray
+        @test ps.bias isa CuArray
+        @test CUDA.device(ps.weight).handle == idx - 1
+        @test CUDA.device(ps.bias).handle == idx - 1
+        @test isequal(cdev(ps.weight), ps_cpu.weight)
+        @test isequal(cdev(ps.bias), ps_cpu.bias)
+    end
+
+    ps = ps |> cdev
+    @test ps.weight isa Array
+    @test ps.bias isa Array
+end
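The new tests exercise the full host/device round trip on every visible GPU. The same pattern doubles as a device-agnostic idiom in downstream code; here is a hedged sketch that also runs on a GPU-less machine, since `gpu_device()` falls back to `cpu_device()` when no backend is functional and `force_gpu_usage` is unset:

```julia
using LuxDeviceUtils

gdev = gpu_device()     # falls back to cpu_device() when no GPU is functional
cdev = cpu_device()

ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10))
ps = ps |> gdev         # transfer (a no-op on the CPU fallback)
ps = ps |> cdev         # always lands back in plain Arrays
@assert get_device(ps.weight) isa LuxCPUDevice
```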