diff --git a/test/cuda.jl b/test/cuda.jl index 5c4a7ee..f19baec 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -1,4 +1,4 @@ -using LuxDeviceUtils, Random +using LuxDeviceUtils, Random, Functors using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @@ -91,7 +91,15 @@ using FillArrays, Zygote # Extensions @test ps_cpu.farray isa Fill end - ps_mixed = (; a=rand(2), b=device(rand(2))) + struct MyStruct + x + end + + Functors.@functor MyStruct + + ps_mixed = (; a=rand(2), c=(rand(2), 1), st=MyStruct(rand(2)), b=device(rand(2))) + @test get_device(ps_mixed.st) isa LuxCPUDevice + @test get_device(ps_mixed.c) isa LuxCPUDevice @test_throws ArgumentError get_device(ps_mixed) dev = gpu_device() diff --git a/test/misc.jl b/test/misc.jl index 6d59372..b8a5590 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -106,3 +106,26 @@ end @test get_device(x) isa LuxCPUDevice @test get_device(x_view) isa LuxCPUDevice end + +@testset "loaded and functional" begin + @test LuxDeviceUtils.loaded(LuxCPUDevice) + @test LuxDeviceUtils.functional(LuxCPUDevice) +end + +@testset "writing to preferences" begin + @test_logs (:info, + "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend.") gpu_backend!() + + for backend in (:CUDA, :AMDGPU, :oneAPI, :Metal, LuxAMDGPUDevice(), + LuxCUDADevice(), LuxMetalDevice(), LuxoneAPIDevice()) + backend_name = backend isa Symbol ? string(backend) : + LuxDeviceUtils._get_device_name(backend) + @test_logs (:info, + "GPU backend has been set to $(backend_name). Restart Julia to use the new backend.") gpu_backend!(backend) + end + + gpu_backend!(:CUDA) + @test_logs (:info, "GPU backend is already set to CUDA. No action is required.") gpu_backend!(:CUDA) + + @test_throws ArgumentError gpu_backend!("my_backend") +end