Skip to content

Commit

Permalink
Fix get_device for multi-gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 25, 2024
1 parent 5effa9c commit 8a9985c
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ext/LuxDeviceUtilsLuxAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.devic
LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng()

# Query Device from Array
LuxDeviceUtils.get_device(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice()
LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x))

Check warning on line 32 in ext/LuxDeviceUtilsLuxAMDGPUExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsLuxAMDGPUExt.jl#L32

Added line #L32 was not covered by tests

# Device Transfer
## To GPU
Expand Down
2 changes: 1 addition & 1 deletion ext/LuxDeviceUtilsLuxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) +
LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng()

# Query Device from Array
LuxDeviceUtils.get_device(::CUDA.AnyCuArray) = LuxCUDADevice()
LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x))

Check warning on line 32 in ext/LuxDeviceUtilsLuxCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxDeviceUtilsLuxCUDAExt.jl#L32

Added line #L32 was not covered by tests

# Device Transfer
## To GPU
Expand Down

0 comments on commit 8a9985c

Please sign in to comment.