Skip to content

Commit

Permalink
test: test for compile time constant
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 13, 2024
1 parent 435c3dc commit 5db42f2
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Expand Down Expand Up @@ -65,6 +66,7 @@ SafeTestsets = "0.1"
SparseArrays = "1.10"
Test = "1.10"
Tracker = "0.2.34"
UnrolledUtilities = "0.1.2"
Zygote = "0.6.69"
julia = "1.10"
oneAPI = "1.5"
Expand Down
5 changes: 3 additions & 2 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Functors: Functors, fmap, fleaves
using LuxCore: LuxCore
using Preferences: @delete_preferences!, @load_preference, @set_preferences!
using Random: AbstractRNG, Random
using UnrolledUtilities: unrolled_mapreduce

const CRC = ChainRulesCore

Expand Down Expand Up @@ -394,7 +395,7 @@ for op in (:get_device, :get_device_type)

function $(_op)(x::Union{Tuple, NamedTuple})
length(x) == 0 && return $(op == :get_device ? nothing : Nothing)
return mapreduce($(op), __combine_devices, values(x))
return unrolled_mapreduce($(op), __combine_devices, values(x))
end
end

Expand All @@ -406,7 +407,7 @@ end
__recursible_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number)

__combine_devices(::Nothing, ::Nothing) = nothing
__combine_devices(::Type{Nothing}, ::Type{Nothing}) = nothing
__combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing
__combine_devices(::Nothing, dev::AbstractLuxDevice) = dev
__combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractLuxDevice} = T
__combine_devices(dev::AbstractLuxDevice, ::Nothing) = dev
Expand Down
17 changes: 17 additions & 0 deletions test/amdgpu_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ using FillArrays, Zygote # Extensions

ps_xpu = ps |> device
@test get_device(ps_xpu) isa LuxAMDGPUDevice
@test get_device_type(ps_xpu) <: LuxAMDGPUDevice
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
Expand All @@ -69,6 +70,7 @@ using FillArrays, Zygote # Extensions

ps_cpu = ps_xpu |> cpu_device()
@test get_device(ps_cpu) isa LuxCPUDevice
@test get_device_type(ps_cpu) <: LuxCPUDevice
@test ps_cpu.a.c isa Array
@test ps_cpu.b isa Array
@test ps_cpu.a.c == ps.a.c
Expand Down Expand Up @@ -99,20 +101,35 @@ using FillArrays, Zygote # Extensions
x = rand(Float32, 10, 2)
x_dev = x |> dev
@test get_device(x_dev) isa parameterless_type(typeof(dev))
@test get_device_type(x_dev) <: parameterless_type(typeof(dev))

if LuxDeviceUtils.functional(LuxAMDGPUDevice)
dev2 = gpu_device(length(AMDGPU.devices()))
x_dev2 = x_dev |> dev2
@test get_device(x_dev2) isa typeof(dev2)
@test get_device_type(x_dev2) <: parameterless_type(typeof(dev2))
end

@testset "get_device_type compile constant" begin
x = rand(10, 10) |> dev
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(dev))}

return_val2(x) = Val(get_device(x))
@test_throws ErrorException @inferred(return_val2(ps))
end
end

@testset "Wrapped Arrays" begin
if LuxDeviceUtils.functional(LuxAMDGPUDevice)
x = rand(10, 10) |> LuxAMDGPUDevice()
@test get_device(x) isa LuxAMDGPUDevice
@test get_device_type(x) <: LuxAMDGPUDevice
x_view = view(x, 1:5, 1:5)
@test get_device(x_view) isa LuxAMDGPUDevice
@test get_device_type(x_view) <: LuxAMDGPUDevice
end
end

Expand Down
23 changes: 23 additions & 0 deletions test/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ using FillArrays, Zygote # Extensions

ps_xpu = ps |> device
@test get_device(ps_xpu) isa LuxCUDADevice
@test get_device_type(ps_xpu) <: LuxCUDADevice
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
Expand All @@ -68,6 +69,7 @@ using FillArrays, Zygote # Extensions

ps_cpu = ps_xpu |> cpu_device()
@test get_device(ps_cpu) isa LuxCPUDevice
@test get_device_type(ps_cpu) <: LuxCPUDevice
@test ps_cpu.a.c isa Array
@test ps_cpu.b isa Array
@test ps_cpu.a.c == ps.a.c
Expand Down Expand Up @@ -99,36 +101,57 @@ using FillArrays, Zygote # Extensions

data = MyStruct(rand(10))
@test get_device(data) isa LuxCPUDevice
@test get_device_type(data) <: LuxCPUDevice
data_dev = data |> device
if LuxDeviceUtils.functional(LuxCUDADevice)
@test get_device(data_dev) isa LuxCUDADevice
@test get_device_type(data_dev) <: LuxCUDADevice
else
@test get_device(data_dev) isa LuxCPUDevice
@test get_device_type(data_dev) <: LuxCPUDevice
end

ps_mixed = (; a=rand(2), c=(rand(2), 1), st=MyStruct(rand(2)), b=device(rand(2)))
@test get_device(ps_mixed.st) isa LuxCPUDevice
@test get_device_type(ps_mixed.st) <: LuxCPUDevice
@test get_device(ps_mixed.c) isa LuxCPUDevice
@test get_device_type(ps_mixed.c) <: LuxCPUDevice
@test_throws ArgumentError get_device(ps_mixed)
@test_throws ArgumentError get_device_type(ps_mixed)

dev = gpu_device()
x = rand(Float32, 10, 2)
x_dev = x |> dev
@test get_device(x_dev) isa parameterless_type(typeof(dev))
@test get_device_type(x_dev) <: parameterless_type(typeof(dev))

if LuxDeviceUtils.functional(LuxCUDADevice)
dev2 = gpu_device(length(CUDA.devices()))
x_dev2 = x_dev |> dev2
@test get_device(x_dev2) isa typeof(dev2)
@test get_device_type(x_dev2) <: parameterless_type(typeof(dev2))
end

@testset "get_device_type compile constant" begin
x = rand(10, 10) |> dev
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(dev))}

return_val2(x) = Val(get_device(x))
@test_throws ErrorException @inferred(return_val2(ps))
end
end

@testset "Wrapped Arrays" begin
if LuxDeviceUtils.functional(LuxCUDADevice)
x = rand(10, 10) |> LuxCUDADevice()
@test get_device(x) isa LuxCUDADevice
@test get_device_type(x) <: LuxCUDADevice
x_view = view(x, 1:5, 1:5)
@test get_device(x_view) isa LuxCUDADevice
@test get_device_type(x_view) <: LuxCUDADevice
end
end

Expand Down
16 changes: 16 additions & 0 deletions test/metal_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ using FillArrays, Zygote # Extensions

ps_xpu = ps |> device
@test get_device(ps_xpu) isa LuxMetalDevice
@test get_device_type(ps_xpu) <: LuxMetalDevice
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
Expand All @@ -66,6 +67,7 @@ using FillArrays, Zygote # Extensions

ps_cpu = ps_xpu |> cpu_device()
@test get_device(ps_cpu) isa LuxCPUDevice
@test get_device_type(ps_cpu) <: LuxCPUDevice
@test ps_cpu.a.c isa Array
@test ps_cpu.b isa Array
@test ps_cpu.a.c == ps.a.c
Expand All @@ -91,14 +93,28 @@ using FillArrays, Zygote # Extensions

ps_mixed = (; a=rand(2), b=device(rand(2)))
@test_throws ArgumentError get_device(ps_mixed)
@test_throws ArgumentError get_device_type(ps_mixed)

@testset "get_device_type compile constant" begin
x = rand(10, 10) |> dev
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(dev))}

return_val2(x) = Val(get_device(x))
@test @inferred(return_val2(ps)) isa Val{typeof(get_device(x))}
end
end

@testset "Wrapper Arrays" begin
if LuxDeviceUtils.functional(LuxMetalDevice)
x = rand(Float32, 10, 10) |> LuxMetalDevice()
@test get_device(x) isa LuxMetalDevice
@test get_device_type(x) <: LuxMetalDevice
x_view = view(x, 1:5, 1:5)
@test get_device(x_view) isa LuxMetalDevice
@test get_device_type(x_view) <: LuxMetalDevice
end
end

Expand Down
11 changes: 11 additions & 0 deletions test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,14 @@ end
transfers. Apply this function on the parameters and states generated \
using `Lux.setup`.") dev(my_layer)
end

@testset "get_device_type compile constant" begin
x = rand(10, 10)
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{typeof(cpu_device())}

return_val2(x) = Val(get_device(x))
@test @inferred(return_val2(ps)) isa Val{cpu_device()}
end
16 changes: 16 additions & 0 deletions test/oneapi_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ using FillArrays, Zygote # Extensions

ps_xpu = ps |> device
@test get_device(ps_xpu) isa LuxoneAPIDevice
@test get_device_type(ps_xpu) <: LuxoneAPIDevice
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
Expand All @@ -66,6 +67,7 @@ using FillArrays, Zygote # Extensions

ps_cpu = ps_xpu |> cpu_device()
@test get_device(ps_cpu) isa LuxCPUDevice
@test get_device_type(ps_cpu) <: LuxCPUDevice
@test ps_cpu.a.c isa Array
@test ps_cpu.b isa Array
@test ps_cpu.a.c == ps.a.c
Expand All @@ -91,14 +93,28 @@ using FillArrays, Zygote # Extensions

ps_mixed = (; a=rand(2), b=device(rand(2)))
@test_throws ArgumentError get_device(ps_mixed)
@test_throws ArgumentError get_device_type(ps_mixed)

@testset "get_device_type compile constant" begin
x = rand(10, 10) |> dev
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(dev))}

return_val2(x) = Val(get_device(x))
@test @inferred(return_val2(ps)) isa Val{typeof(get_device(x))}
end
end

@testset "Wrapper Arrays" begin
if LuxDeviceUtils.functional(LuxoneAPIDevice)
x = rand(10, 10) |> LuxoneAPIDevice()
@test get_device(x) isa LuxoneAPIDevice
@test get_device_type(x) <: LuxoneAPIDevice
x_view = view(x, 1:5, 1:5)
@test get_device(x_view) isa LuxoneAPIDevice
@test get_device_type(x_view) <: LuxoneAPIDevice
end
end

Expand Down

0 comments on commit 5db42f2

Please sign in to comment.