From 37429fbaf93ac3235be00724eaf07b23f7a3d98d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 18 Oct 2024 13:25:57 -0400 Subject: [PATCH] test: functions and closures --- test/amdgpu_tests.jl | 21 +++++++++++++++++++++ test/cuda_tests.jl | 21 +++++++++++++++++++++ test/metal_tests.jl | 21 +++++++++++++++++++++ test/oneapi_tests.jl | 21 +++++++++++++++++++++ test/xla_tests.jl | 21 +++++++++++++++++++++ 5 files changed, 105 insertions(+) diff --git a/test/amdgpu_tests.jl b/test/amdgpu_tests.jl index f29c279..41a8797 100644 --- a/test/amdgpu_tests.jl +++ b/test/amdgpu_tests.jl @@ -126,6 +126,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(AMDGPUDevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> AMDGPUDevice() + @test get_device(ff_xpu) isa AMDGPUDevice + @test get_device_type(ff_xpu) <: AMDGPUDevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapped Arrays" begin if MLDataDevices.functional(AMDGPUDevice) x = rand(10, 10) |> AMDGPUDevice() diff --git a/test/cuda_tests.jl b/test/cuda_tests.jl index bd8a234..1f95831 100644 --- a/test/cuda_tests.jl +++ b/test/cuda_tests.jl @@ -151,6 +151,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(CUDADevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> CUDADevice() + @test get_device(ff_xpu) isa CUDADevice + @test get_device_type(ff_xpu) <: CUDADevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapped Arrays" begin if MLDataDevices.functional(CUDADevice) x = rand(10, 10) |> CUDADevice() diff --git a/test/metal_tests.jl b/test/metal_tests.jl index a214ebd..aeb596a 100644 --- a/test/metal_tests.jl +++ b/test/metal_tests.jl @@ -115,6 +115,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(MetalDevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> MetalDevice() + @test get_device(ff_xpu) isa MetalDevice + @test get_device_type(ff_xpu) <: MetalDevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapper Arrays" begin if MLDataDevices.functional(MetalDevice) x = rand(Float32, 10, 10) |> MetalDevice() diff --git a/test/oneapi_tests.jl b/test/oneapi_tests.jl index d1720f0..8bb6026 100644 --- a/test/oneapi_tests.jl +++ b/test/oneapi_tests.jl @@ -115,6 +115,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(oneAPIDevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> oneAPIDevice() + @test get_device(ff_xpu) isa oneAPIDevice + @test get_device_type(ff_xpu) <: oneAPIDevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapper Arrays" begin if MLDataDevices.functional(oneAPIDevice) x = rand(10, 10) |> oneAPIDevice() diff --git a/test/xla_tests.jl b/test/xla_tests.jl index 138727f..21466bd 100644 --- a/test/xla_tests.jl +++ b/test/xla_tests.jl @@ -114,6 +114,27 @@ using FillArrays, Zygote # Extensions end end +@testset "Functions" begin + if MLDataDevices.functional(XLADevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> XLADevice() + @test get_device(ff_xpu) isa XLADevice + @test get_device_type(ff_xpu) <: XLADevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + @testset "Wrapped Arrays" begin if MLDataDevices.functional(XLADevice) x = rand(10, 10) |> XLADevice()