Skip to content

Commit

Permalink
Test for potential multi-device
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 7, 2024
1 parent e00224b commit 3106ab7
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 13 deletions.
6 changes: 3 additions & 3 deletions ext/LuxDeviceUtilsAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ end
# Called with `nothing` in place of a device id: return a `LuxAMDGPUDevice`
# constructed from `nothing`.
LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) = LuxAMDGPUDevice(nothing)
function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Int)
function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Integer)

Check warning on line 33 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L33

Added line #L33 was not covered by tests
id > length(AMDGPU.devices()) &&
throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))"))
old_dev = AMDGPU.device()
Expand All @@ -56,10 +56,10 @@ end
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice)
return AMDGPU.device!(dev)

Check warning on line 57 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L57

Added line #L57 was not covered by tests
end
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int)
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Integer)
return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id])

Check warning on line 60 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L59-L60

Added lines #L59 - L60 were not covered by tests
end
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Int)
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Integer)

Check warning on line 62 in ext/LuxDeviceUtilsAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsAMDGPUExt.jl#L62

Added line #L62 was not covered by tests
id = mod1(rank + 1, length(AMDGPU.devices()))
return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, id)
end
Expand Down
6 changes: 3 additions & 3 deletions ext/LuxDeviceUtilsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector
using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, LuxCPUDevice
using Random: Random

function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Int)
function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Integer)
id > length(CUDA.devices()) &&
throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))"))
old_dev = CUDA.device()
Expand Down Expand Up @@ -39,10 +39,10 @@ end
# Device-handle method: delegates to `CUDA.device!(dev)` and returns whatever that
# call returns.
LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev)
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int)
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Integer)
    # `id` is used directly as an index into the collected device list; no range
    # check is performed here, so an invalid `id` fails at the indexing step.
    all_devices = collect(CUDA.devices())
    selected = all_devices[id]
    # Hand the resolved device over to the device-handle method.
    return LuxDeviceUtils.set_device!(LuxCUDADevice, selected)
end
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int)
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Integer)
    # Turn the rank into a device id: add one, then wrap with `mod1` against the
    # number of CUDA devices so the result always lies in `1:num_devices`.
    num_devices = length(CUDA.devices())
    wrapped_id = mod1(rank + 1, num_devices)
    # Delegate to the integer-id method with the wrapped id.
    return LuxDeviceUtils.set_device!(LuxCUDADevice, wrapped_id)
end
Expand Down
14 changes: 7 additions & 7 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ Return a tuple of supported GPU backends.
@inline function supported_gpu_backends()
    # Apply `_get_device_name` to every entry of `GPU_DEVICES`, keeping their order.
    return map(_get_device_name, GPU_DEVICES)
end

Check warning on line 125 in src/LuxDeviceUtils.jl

View check run for this annotation

Codecov / codecov/patch

src/LuxDeviceUtils.jl#L125

Added line #L125 was not covered by tests

"""
gpu_device(device_id::Union{Nothing, Int}=nothing;
gpu_device(device_id::Union{Nothing, Integer}=nothing;
force_gpu_usage::Bool=false) -> AbstractLuxDevice()
Selects GPU device based on the following criteria:
Expand All @@ -141,10 +141,10 @@ Selects GPU device based on the following criteria:
## Arguments
- `device_id::Union{Nothing, Int}`: The device id to select. If `nothing`, then we return
- `device_id::Union{Nothing, Integer}`: The device id to select. If `nothing`, then we return
the last selected device or if none was selected then we run the autoselection and
choose the current device using `CUDA.device()` or `AMDGPU.device()` or similar. If
`Int`, then we select the device with the given id. Note that this is `1`-indexed, in
`Integer`, then we select the device with the given id. Note that this is `1`-indexed, in
contrast to the `0`-indexed `CUDA.jl`. For example, `id = 4` corresponds to
`CUDA.device!(3)`.
Expand All @@ -158,7 +158,7 @@ Selects GPU device based on the following criteria:
- `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU
device is found.
"""
function gpu_device(device_id::Union{Nothing, Int}=nothing;
function gpu_device(device_id::Union{Nothing, <:Integer}=nothing;
force_gpu_usage::Bool=false)::AbstractLuxDevice
device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed."))

Expand Down Expand Up @@ -426,19 +426,19 @@ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice}
end

"""
set_device!(T::Type{<:AbstractLuxDevice}, ::Nothing, rank::Int)
set_device!(T::Type{<:AbstractLuxDevice}, ::Nothing, rank::Integer)
$SET_DEVICE_DOCS
## Arguments
- `T::Type{<:AbstractLuxDevice}`: The device type to set.
- `rank::Int`: Local Rank of the process. This is applicable for distributed training and
- `rank::Integer`: Local Rank of the process. This is applicable for distributed training and
must be `0`-indexed.
$SET_DEVICE_DANGER
"""
function set_device!(::Type{T}, ::Nothing, rank::Int) where {T <: AbstractLuxDevice}
function set_device!(
        ::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractLuxDevice}
    # Generic fallback for the `(nothing, rank)` form: `rank` is forwarded unchanged
    # to the two-argument `set_device!(T, rank)` method.
    return set_device!(T, rank)
end

Expand Down
12 changes: 12 additions & 0 deletions test/amdgpu.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LuxDeviceUtils, Random
using ArrayInterface: parameterless_type

@testset "CPU Fallback" begin
@test !LuxDeviceUtils.functional(LuxAMDGPUDevice)
Expand Down Expand Up @@ -93,6 +94,17 @@ using FillArrays, Zygote # Extensions

ps_mixed = (; a=rand(2), b=device(rand(2)))
@test_throws ArgumentError get_device(ps_mixed)

dev = gpu_device()
x = rand(Float32, 10, 2)
x_dev = x |> dev
@test get_device(x_dev) isa parameterless_type(typeof(dev))

if LuxDeviceUtils.functional(LuxAMDGPUDevice)
dev2 = gpu_device(length(AMDGPU.devices()))
x_dev2 = x_dev |> dev2
@test get_device(x_dev2) isa typeof(dev2)
end
end

@testset "Wrapped Arrays" begin
Expand Down
12 changes: 12 additions & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LuxDeviceUtils, Random
using ArrayInterface: parameterless_type

@testset "CPU Fallback" begin
@test !LuxDeviceUtils.functional(LuxCUDADevice)
Expand Down Expand Up @@ -92,6 +93,17 @@ using FillArrays, Zygote # Extensions

ps_mixed = (; a=rand(2), b=device(rand(2)))
@test_throws ArgumentError get_device(ps_mixed)

dev = gpu_device()
x = rand(Float32, 10, 2)
x_dev = x |> dev
@test get_device(x_dev) isa parameterless_type(typeof(dev))

if LuxDeviceUtils.functional(LuxCUDADevice)
dev2 = gpu_device(length(CUDA.devices()))
x_dev2 = x_dev |> dev2
@test get_device(x_dev2) isa typeof(dev2)
end
end

@testset "Wrapped Arrays" begin
Expand Down

0 comments on commit 3106ab7

Please sign in to comment.