Skip to content

Commit

Permalink
fix: make gpu(x) return unmodified x when GPU backends aren't loa…
Browse files Browse the repository at this point in the history
…ded (#2295)

* fix: make gpu return unmodified input when gpu isn't available

* add tests

* fix
  • Loading branch information
IanButterworth authored Jul 19, 2023
1 parent fb507aa commit 2fe82a8
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ function gpu(::FluxCUDAAdaptor, x)
`CUDA.jl` must be loaded to access it.
Add `using CUDA` or `import CUDA` to your code.
""" maxlog=1
return x
end
end

Expand All @@ -361,6 +362,7 @@ function gpu(::FluxAMDAdaptor, x)
`AMDGPU.jl` must be loaded to access it.
Add `using AMDGPU` or `import AMDGPU` to your code.
""" maxlog=1
return x
end
end

Expand All @@ -380,6 +382,7 @@ function gpu(::FluxMetalAdaptor, x)
The Metal functionality is being called but
`Metal.jl` must be loaded to access it.
""" maxlog=1
return x
end
end

Expand Down
4 changes: 4 additions & 0 deletions test/functors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
x = rand(Float32, 10, 10)
if !(Flux.CUDA_LOADED[] || Flux.AMD_LOADED[] || Flux.METAL_LOADED[])
@test x === gpu(x)
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ Random.seed!(0)
include("outputsize.jl")
end

@testset "functors" begin
include("functors.jl")
end

if get(ENV, "FLUX_TEST_CUDA", "false") == "true"
using CUDA
Expand Down

0 comments on commit 2fe82a8

Please sign in to comment.