diff --git a/Project.toml b/Project.toml index ad31dda..11719aa 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -65,6 +66,7 @@ SafeTestsets = "0.1" SparseArrays = "1.10" Test = "1.10" Tracker = "0.2.34" +UnrolledUtilities = "0.1.2" Zygote = "0.6.69" julia = "1.10" oneAPI = "1.5" diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index 9dc0083..2c3059b 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -6,6 +6,7 @@ using Functors: Functors, fmap, fleaves using LuxCore: LuxCore using Preferences: @delete_preferences!, @load_preference, @set_preferences! using Random: AbstractRNG, Random +using UnrolledUtilities: unrolled_mapreduce const CRC = ChainRulesCore @@ -394,7 +395,7 @@ for op in (:get_device, :get_device_type) function $(_op)(x::Union{Tuple, NamedTuple}) length(x) == 0 && return $(op == :get_device ? 
nothing : Nothing) - return mapreduce($(op), __combine_devices, values(x)) + return unrolled_mapreduce($(op), __combine_devices, values(x)) end end @@ -406,7 +407,7 @@ end __recursible_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number) __combine_devices(::Nothing, ::Nothing) = nothing -__combine_devices(::Type{Nothing}, ::Type{Nothing}) = nothing +__combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing __combine_devices(::Nothing, dev::AbstractLuxDevice) = dev __combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractLuxDevice} = T __combine_devices(dev::AbstractLuxDevice, ::Nothing) = dev diff --git a/test/amdgpu_tests.jl b/test/amdgpu_tests.jl index f2e6ebe..9edaa14 100644 --- a/test/amdgpu_tests.jl +++ b/test/amdgpu_tests.jl @@ -46,6 +46,7 @@ using FillArrays, Zygote # Extensions ps_xpu = ps |> device @test get_device(ps_xpu) isa LuxAMDGPUDevice + @test get_device_type(ps_xpu) <: LuxAMDGPUDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -69,6 +70,7 @@ using FillArrays, Zygote # Extensions ps_cpu = ps_xpu |> cpu_device() @test get_device(ps_cpu) isa LuxCPUDevice + @test get_device_type(ps_cpu) <: LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -99,11 +101,24 @@ using FillArrays, Zygote # Extensions x = rand(Float32, 10, 2) x_dev = x |> dev @test get_device(x_dev) isa parameterless_type(typeof(dev)) + @test get_device_type(x_dev) <: parameterless_type(typeof(dev)) if LuxDeviceUtils.functional(LuxAMDGPUDevice) dev2 = gpu_device(length(AMDGPU.devices())) x_dev2 = x_dev |> dev2 @test get_device(x_dev2) isa typeof(dev2) + @test get_device_type(x_dev2) <: parameterless_type(typeof(dev2)) + end + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> dev + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) 
isa Val{parameterless_type(typeof(dev))} + + return_val2(x) = Val(get_device(x)) + @test_throws ErrorException @inferred(return_val2(ps)) end end @@ -111,8 +126,10 @@ end if LuxDeviceUtils.functional(LuxAMDGPUDevice) x = rand(10, 10) |> LuxAMDGPUDevice() @test get_device(x) isa LuxAMDGPUDevice + @test get_device_type(x) <: LuxAMDGPUDevice x_view = view(x, 1:5, 1:5) @test get_device(x_view) isa LuxAMDGPUDevice + @test get_device_type(x_view) <: LuxAMDGPUDevice end end diff --git a/test/cuda_tests.jl b/test/cuda_tests.jl index d8e9217..fee6d2c 100644 --- a/test/cuda_tests.jl +++ b/test/cuda_tests.jl @@ -45,6 +45,7 @@ using FillArrays, Zygote # Extensions ps_xpu = ps |> device @test get_device(ps_xpu) isa LuxCUDADevice + @test get_device_type(ps_xpu) <: LuxCUDADevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -68,6 +69,7 @@ using FillArrays, Zygote # Extensions ps_cpu = ps_xpu |> cpu_device() @test get_device(ps_cpu) isa LuxCPUDevice + @test get_device_type(ps_cpu) <: LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -99,27 +101,46 @@ using FillArrays, Zygote # Extensions data = MyStruct(rand(10)) @test get_device(data) isa LuxCPUDevice + @test get_device_type(data) <: LuxCPUDevice data_dev = data |> device if LuxDeviceUtils.functional(LuxCUDADevice) @test get_device(data_dev) isa LuxCUDADevice + @test get_device_type(data_dev) <: LuxCUDADevice else @test get_device(data_dev) isa LuxCPUDevice + @test get_device_type(data_dev) <: LuxCPUDevice end ps_mixed = (; a=rand(2), c=(rand(2), 1), st=MyStruct(rand(2)), b=device(rand(2))) @test get_device(ps_mixed.st) isa LuxCPUDevice + @test get_device_type(ps_mixed.st) <: LuxCPUDevice @test get_device(ps_mixed.c) isa LuxCPUDevice + @test get_device_type(ps_mixed.c) <: LuxCPUDevice @test_throws ArgumentError get_device(ps_mixed) + @test_throws ArgumentError get_device_type(ps_mixed) dev = gpu_device() x = rand(Float32, 10, 2) x_dev = x |> dev 
@test get_device(x_dev) isa parameterless_type(typeof(dev)) + @test get_device_type(x_dev) <: parameterless_type(typeof(dev)) if LuxDeviceUtils.functional(LuxCUDADevice) dev2 = gpu_device(length(CUDA.devices())) x_dev2 = x_dev |> dev2 @test get_device(x_dev2) isa typeof(dev2) + @test get_device_type(x_dev2) <: parameterless_type(typeof(dev2)) + end + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> dev + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(dev))} + + return_val2(x) = Val(get_device(x)) + @test_throws ErrorException @inferred(return_val2(ps)) end end @@ -127,8 +148,10 @@ end if LuxDeviceUtils.functional(LuxCUDADevice) x = rand(10, 10) |> LuxCUDADevice() @test get_device(x) isa LuxCUDADevice + @test get_device_type(x) <: LuxCUDADevice x_view = view(x, 1:5, 1:5) @test get_device(x_view) isa LuxCUDADevice + @test get_device_type(x_view) <: LuxCUDADevice end end diff --git a/test/metal_tests.jl b/test/metal_tests.jl index 1e7ce23..9c15dc5 100644 --- a/test/metal_tests.jl +++ b/test/metal_tests.jl @@ -43,6 +43,7 @@ using FillArrays, Zygote # Extensions ps_xpu = ps |> device @test get_device(ps_xpu) isa LuxMetalDevice + @test get_device_type(ps_xpu) <: LuxMetalDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -66,6 +67,7 @@ using FillArrays, Zygote # Extensions ps_cpu = ps_xpu |> cpu_device() @test get_device(ps_cpu) isa LuxCPUDevice + @test get_device_type(ps_cpu) <: LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -91,14 +93,28 @@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) + @test_throws ArgumentError get_device_type(ps_mixed) + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> 
device + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + + return_val2(x) = Val(get_device(x)) + @test @inferred(return_val2(ps)) isa Val{get_device(x)} + end end @testset "Wrapper Arrays" begin if LuxDeviceUtils.functional(LuxMetalDevice) x = rand(Float32, 10, 10) |> LuxMetalDevice() @test get_device(x) isa LuxMetalDevice + @test get_device_type(x) <: LuxMetalDevice x_view = view(x, 1:5, 1:5) @test get_device(x_view) isa LuxMetalDevice + @test get_device_type(x_view) <: LuxMetalDevice end end diff --git a/test/misc_tests.jl b/test/misc_tests.jl index 681f890..dd0ef8e 100644 --- a/test/misc_tests.jl +++ b/test/misc_tests.jl @@ -152,3 +152,14 @@ end transfers. Apply this function on the parameters and states generated \ using `Lux.setup`.") dev(my_layer) end + +@testset "get_device_type compile constant" begin + x = rand(10, 10) + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{typeof(cpu_device())} + + return_val2(x) = Val(get_device(x)) + @test @inferred(return_val2(ps)) isa Val{cpu_device()} +end diff --git a/test/oneapi_tests.jl b/test/oneapi_tests.jl index 9cdd9ef..19568ec 100644 --- a/test/oneapi_tests.jl +++ b/test/oneapi_tests.jl @@ -43,6 +43,7 @@ using FillArrays, Zygote # Extensions ps_xpu = ps |> device @test get_device(ps_xpu) isa LuxoneAPIDevice + @test get_device_type(ps_xpu) <: LuxoneAPIDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d @@ -66,6 +67,7 @@ using FillArrays, Zygote # Extensions ps_cpu = ps_xpu |> cpu_device() @test get_device(ps_cpu) isa LuxCPUDevice + @test get_device_type(ps_cpu) <: LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @@ -91,14 +93,28
@@ using FillArrays, Zygote # Extensions ps_mixed = (; a=rand(2), b=device(rand(2))) @test_throws ArgumentError get_device(ps_mixed) + @test_throws ArgumentError get_device_type(ps_mixed) + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> device + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + + return_val2(x) = Val(get_device(x)) + @test @inferred(return_val2(ps)) isa Val{get_device(x)} + end end @testset "Wrapper Arrays" begin if LuxDeviceUtils.functional(LuxoneAPIDevice) x = rand(10, 10) |> LuxoneAPIDevice() @test get_device(x) isa LuxoneAPIDevice + @test get_device_type(x) <: LuxoneAPIDevice x_view = view(x, 1:5, 1:5) @test get_device(x_view) isa LuxoneAPIDevice + @test get_device_type(x_view) <: LuxoneAPIDevice end end