diff --git a/Project.toml b/Project.toml index 9046fcf..3bee1a5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.19" +version = "0.1.20" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/LuxDeviceUtilsAMDGPUExt.jl b/ext/LuxDeviceUtilsAMDGPUExt.jl index dab9f84..c88619a 100644 --- a/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -36,11 +36,7 @@ function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevi return end function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int) - if !AMDGPU.functional() - @warn "AMDGPU is not functional." - return - end - AMDGPU.device!(id) + LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id]) return end function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Int) diff --git a/ext/LuxDeviceUtilsCUDAExt.jl b/ext/LuxDeviceUtilsCUDAExt.jl index a18ce10..ae6a45f 100644 --- a/ext/LuxDeviceUtilsCUDAExt.jl +++ b/ext/LuxDeviceUtilsCUDAExt.jl @@ -59,7 +59,7 @@ function Adapt.adapt_storage(to::LuxCUDAAdaptor, x) x_new = CUDA.cu(x) CUDA.device!(old_dev) return x_new - elseif CUDA.deviceid(x) == to.device + elseif CUDA.device(x) == to.device return x else CUDA.device!(to.device)