Skip to content

Commit

Permalink
test: functions and closures
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 18, 2024
1 parent b3bef22 commit 37429fb
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 0 deletions.
21 changes: 21 additions & 0 deletions test/amdgpu_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(AMDGPUDevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> AMDGPUDevice()
@test get_device(ff_xpu) isa AMDGPUDevice
@test get_device_type(ff_xpu) <: AMDGPUDevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapped Arrays" begin
if MLDataDevices.functional(AMDGPUDevice)
x = rand(10, 10) |> AMDGPUDevice()
Expand Down
21 changes: 21 additions & 0 deletions test/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(CUDADevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> CUDADevice()
@test get_device(ff_xpu) isa CUDADevice
@test get_device_type(ff_xpu) <: CUDADevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapped Arrays" begin
if MLDataDevices.functional(CUDADevice)
x = rand(10, 10) |> CUDADevice()
Expand Down
21 changes: 21 additions & 0 deletions test/metal_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(MetalDevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> MetalDevice()
@test get_device(ff_xpu) isa MetalDevice
@test get_device_type(ff_xpu) <: MetalDevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapper Arrays" begin
if MLDataDevices.functional(MetalDevice)
x = rand(Float32, 10, 10) |> MetalDevice()
Expand Down
21 changes: 21 additions & 0 deletions test/oneapi_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(oneAPIDevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> oneAPIDevice()
@test get_device(ff_xpu) isa oneAPIDevice
@test get_device_type(ff_xpu) <: oneAPIDevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapper Arrays" begin
if MLDataDevices.functional(oneAPIDevice)
x = rand(10, 10) |> oneAPIDevice()
Expand Down
21 changes: 21 additions & 0 deletions test/xla_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,27 @@ using FillArrays, Zygote # Extensions
end
end

@testset "Functions" begin
if MLDataDevices.functional(XLADevice)
@test get_device(tanh) isa MLDataDevices.UnknownDevice
@test get_device_type(tanh) <: MLDataDevices.UnknownDevice

f(x, y) = () -> (x, x .^ 2, y)

ff = f([1, 2, 3], 1)
@test get_device(ff) isa CPUDevice
@test get_device_type(ff) <: CPUDevice

ff_xpu = ff |> XLADevice()
@test get_device(ff_xpu) isa XLADevice
@test get_device_type(ff_xpu) <: XLADevice

ff_cpu = ff_xpu |> cpu_device()
@test get_device(ff_cpu) isa CPUDevice
@test get_device_type(ff_cpu) <: CPUDevice
end
end

@testset "Wrapped Arrays" begin
if MLDataDevices.functional(XLADevice)
x = rand(10, 10) |> XLADevice()
Expand Down

0 comments on commit 37429fb

Please sign in to comment.