Skip to content

Commit

Permalink
Fix set_device for AMDGPU
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 7, 2024
1 parent 0a4e5e9 commit e79de33
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.19"
version = "0.1.20"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
6 changes: 1 addition & 5 deletions ext/LuxDeviceUtilsAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevi
return
end
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int)
if !AMDGPU.functional()
@warn "AMDGPU is not functional."
return
end
AMDGPU.device!(id)
LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id])
return
end
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Int)
Expand Down
2 changes: 1 addition & 1 deletion ext/LuxDeviceUtilsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function Adapt.adapt_storage(to::LuxCUDAAdaptor, x)
x_new = CUDA.cu(x)
CUDA.device!(old_dev)
return x_new
elseif CUDA.deviceid(x) == to.device
elseif CUDA.device(x) == to.device
return x
else
CUDA.device!(to.device)
Expand Down

0 comments on commit e79de33

Please sign in to comment.