From 3636a8205debe3415897be9a8e4ae1a233ec83f9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 7 Jun 2024 00:06:46 -0700 Subject: [PATCH] Add tests for gpu_backend! --- test/cuda.jl | 12 ++++++++++-- test/misc.jl | 23 +++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/test/cuda.jl b/test/cuda.jl index 5c4a7ee..f19baec 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -1,4 +1,4 @@ -using LuxDeviceUtils, Random +using LuxDeviceUtils, Random, Functors using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @@ -91,7 +91,15 @@ using FillArrays, Zygote # Extensions @test ps_cpu.farray isa Fill end - ps_mixed = (; a=rand(2), b=device(rand(2))) + struct MyStruct + x + end + + Functors.@functor MyStruct + + ps_mixed = (; a=rand(2), c=(rand(2), 1), st=MyStruct(rand(2)), b=device(rand(2))) + @test get_device(ps_mixed.st) isa LuxCPUDevice + @test get_device(ps_mixed.c) isa LuxCPUDevice @test_throws ArgumentError get_device(ps_mixed) dev = gpu_device() diff --git a/test/misc.jl b/test/misc.jl index 6d59372..b8a5590 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -106,3 +106,26 @@ end @test get_device(x) isa LuxCPUDevice @test get_device(x_view) isa LuxCPUDevice end + +@testset "loaded and functional" begin + @test LuxDeviceUtils.loaded(LuxCPUDevice) + @test LuxDeviceUtils.functional(LuxCPUDevice) +end + +@testset "writing to preferences" begin + @test_logs (:info, + "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend.") gpu_backend!() + + for backend in (:CUDA, :AMDGPU, :oneAPI, :Metal, LuxAMDGPUDevice(), + LuxCUDADevice(), LuxMetalDevice(), LuxoneAPIDevice()) + backend_name = backend isa Symbol ? string(backend) : + LuxDeviceUtils._get_device_name(backend) + @test_logs (:info, + "GPU backend has been set to $(backend_name). Restart Julia to use the new backend.") gpu_backend!(backend) + end + + gpu_backend!(:CUDA) + @test_logs (:info, "GPU backend is already set to CUDA. No action is required.") gpu_backend!(:CUDA) + + @test_throws ArgumentError gpu_backend!("my_backend") +end