diff --git a/ext/LuxDeviceUtilsAMDGPUExt.jl b/ext/LuxDeviceUtilsAMDGPUExt.jl index d39c8f9..93a8c84 100644 --- a/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -30,7 +30,7 @@ end function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) return LuxAMDGPUDevice(nothing) end -function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Int) +function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Integer) id > length(AMDGPU.devices()) && throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) old_dev = AMDGPU.device() @@ -56,10 +56,10 @@ end function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice) return AMDGPU.device!(dev) end -function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int) +function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Integer) return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id]) end -function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Int) +function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Integer) id = mod1(rank + 1, length(AMDGPU.devices())) return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, id) end diff --git a/ext/LuxDeviceUtilsCUDAExt.jl b/ext/LuxDeviceUtilsCUDAExt.jl index 19cc144..29ff65c 100644 --- a/ext/LuxDeviceUtilsCUDAExt.jl +++ b/ext/LuxDeviceUtilsCUDAExt.jl @@ -6,7 +6,7 @@ using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, LuxCPUDevice using Random: Random -function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Int) +function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Integer) id > length(CUDA.devices()) && throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) old_dev = CUDA.device() @@ -39,10 +39,10 @@ end function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) return CUDA.device!(dev) end -function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int) +function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Integer) return LuxDeviceUtils.set_device!(LuxCUDADevice, collect(CUDA.devices())[id]) end -function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int) +function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Integer) id = mod1(rank + 1, length(CUDA.devices())) return LuxDeviceUtils.set_device!(LuxCUDADevice, id) end diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index bd43c51..b1c9eb5 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -125,7 +125,7 @@ Return a tuple of supported GPU backends. @inline supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) """ - gpu_device(device_id::Union{Nothing, Int}=nothing; + gpu_device(device_id::Union{Nothing, Integer}=nothing; force_gpu_usage::Bool=false) -> AbstractLuxDevice() Selects GPU device based on the following criteria: @@ -141,10 +141,10 @@ Selects GPU device based on the following criteria: ## Arguments - - `device_id::Union{Nothing, Int}`: The device id to select. If `nothing`, then we return + - `device_id::Union{Nothing, Integer}`: The device id to select. If `nothing`, then we return the last selected device or if none was selected then we run the autoselection and choose the current device using `CUDA.device()` or `AMDGPU.device()` or similar. If - `Int`, then we select the device with the given id. Note that this is `1`-indexed, in + `Integer`, then we select the device with the given id. Note that this is `1`-indexed, in contrast to the `0`-indexed `CUDA.jl`. For example, `id = 4` corresponds to `CUDA.device!(3)`. @@ -158,7 +158,7 @@ Selects GPU device based on the following criteria: - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU device is found. """ -function gpu_device(device_id::Union{Nothing, Int}=nothing; +function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; force_gpu_usage::Bool=false)::AbstractLuxDevice device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) @@ -426,19 +426,19 @@ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice} end """ - set_device!(T::Type{<:AbstractLuxDevice}, ::Nothing, rank::Int) + set_device!(T::Type{<:AbstractLuxDevice}, ::Nothing, rank::Integer) $SET_DEVICE_DOCS ## Arguments - `T::Type{<:AbstractLuxDevice}`: The device type to set. - - `rank::Int`: Local Rank of the process. This is applicable for distributed training and + - `rank::Integer`: Local Rank of the process. This is applicable for distributed training and must be `0`-indexed. $SET_DEVICE_DANGER """ -function set_device!(::Type{T}, ::Nothing, rank::Int) where {T <: AbstractLuxDevice} +function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractLuxDevice} return set_device!(T, rank) end diff --git a/test/amdgpu.jl b/test/amdgpu.jl index 4840b98..159b241 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -1,4 +1,5 @@ using LuxDeviceUtils, Random +using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @test !LuxDeviceUtils.functional(LuxAMDGPUDevice) @@ -93,6 +94,17 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) + + dev = gpu_device() + x = rand(Float32, 10, 2) + x_dev = x |> dev + @test get_device(x_dev) isa parameterless_type(typeof(dev)) + + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + dev2 = gpu_device(length(AMDGPU.devices())) + x_dev2 = x_dev |> dev2 + @test get_device(x_dev2) isa typeof(dev2) + end end @testset "Wrapped Arrays" begin diff --git a/test/cuda.jl b/test/cuda.jl index 3b1983b..5c4a7ee 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -1,4 +1,5 @@ using LuxDeviceUtils, Random +using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @test !LuxDeviceUtils.functional(LuxCUDADevice) @@ -92,6 +93,17 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) + + dev = gpu_device() + x = rand(Float32, 10, 2) + x_dev = x |> dev + @test get_device(x_dev) isa parameterless_type(typeof(dev)) + + if LuxDeviceUtils.functional(LuxCUDADevice) + dev2 = gpu_device(length(CUDA.devices())) + x_dev2 = x_dev |> dev2 + @test get_device(x_dev2) isa typeof(dev2) + end end @testset "Wrapped Arrays" begin