Skip to content

Commit

Permalink
Fix ambiguity problems
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 24, 2024
1 parent d33b2e3 commit 5effa9c
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 2 deletions.
2 changes: 2 additions & 0 deletions ext/LuxDeviceUtilsLuxAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ function adapt_storage(to::LuxAMDGPUAdaptor, x)
return x_new
end
end
adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng
adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng
adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng()
adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng()

adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng()
Expand Down
2 changes: 2 additions & 0 deletions ext/LuxDeviceUtilsLuxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ function adapt_storage(to::LuxCUDAAdaptor, x)
return x_new
end
end
adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::AbstractRNG) = rng
adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng
adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::Random.TaskLocalRNG) = CUDA.default_rng()
adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng()

adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng()
Expand Down
2 changes: 1 addition & 1 deletion test/amdgpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ if LuxAMDGPU.functional()
@test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice
@test AMDGPU.device_id(amdgpu_device.device) == idx

ps = ps |> amdgpu_device
global ps = ps |> amdgpu_device
@test ps.weight isa ROCArray
@test ps.bias isa ROCArray
@test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx
Expand Down
2 changes: 1 addition & 1 deletion test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ if LuxCUDA.functional()
@test typeof(cuda_device.device) <: CUDA.CuDevice
@test cuda_device.device.handle == (idx - 1)

ps = ps |> cuda_device
global ps = ps |> cuda_device
@test ps.weight isa CuArray
@test ps.bias isa CuArray
@test CUDA.device(ps.weight).handle == idx - 1
Expand Down

0 comments on commit 5effa9c

Please sign in to comment.