From 5effa9ca836f13521748e172f238521e854fe454 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 24 Feb 2024 14:25:02 -0500 Subject: [PATCH] Fix ambiguity problems --- ext/LuxDeviceUtilsLuxAMDGPUExt.jl | 2 ++ ext/LuxDeviceUtilsLuxCUDAExt.jl | 2 ++ test/amdgpu.jl | 2 +- test/cuda.jl | 2 +- 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index 0a8ea7d..be83184 100644 --- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -50,7 +50,9 @@ function adapt_storage(to::LuxAMDGPUAdaptor, x) return x_new end end +adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng +adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl index 49a1e0b..09cfaac 100644 --- a/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -50,7 +50,9 @@ function adapt_storage(to::LuxCUDAAdaptor, x) return x_new end end +adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::AbstractRNG) = rng adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng +adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::Random.TaskLocalRNG) = CUDA.default_rng() adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() diff --git a/test/amdgpu.jl b/test/amdgpu.jl index 3675a0e..9247fdb 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -82,7 +82,7 @@ if LuxAMDGPU.functional() @test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice @test AMDGPU.device_id(amdgpu_device.device) == idx - ps = ps |> amdgpu_device + global ps = ps |> amdgpu_device @test ps.weight isa ROCArray @test ps.bias isa ROCArray @test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx diff --git a/test/cuda.jl b/test/cuda.jl index 9a7c2c3..e0dc343 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -82,7 +82,7 @@ if LuxCUDA.functional() @test typeof(cuda_device.device) <: CUDA.CuDevice @test cuda_device.device.handle == (idx - 1) - ps = ps |> cuda_device + global ps = ps |> cuda_device @test ps.weight isa CuArray @test ps.bias isa CuArray @test CUDA.device(ps.weight).handle == idx - 1