Always use fmap and bump functors to 0.5 #1039

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion Project.toml
@@ -82,7 +82,7 @@ FastClosures = "0.3.2"
 Flux = "0.14.25"
 ForwardDiff = "0.10.36"
 FunctionWrappers = "1.1.3"
-Functors = "0.4.12"
+Functors = "0.4.12, 0.5"
 GPUArraysCore = "0.1.6, 0.2"
 LinearAlgebra = "1.10"
 LossFunctions = "0.11.1"
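Every compat change below follows the same pattern: the entry "0.4.12, 0.5" is a comma-separated union of caret specifiers, so the package now accepts any Functors release in [0.4.12, 0.5.0) or [0.5.0, 0.6.0). One way to sanity-check such an entry is Pkg's semver_spec helper (a sketch; Pkg.Types.semver_spec is an internal, not officially public, API):

    using Pkg

    # Parse the compat string the way Pkg does internally.
    spec = Pkg.Types.semver_spec("0.4.12, 0.5")

    @assert v"0.4.12" in spec     # lower bound of the 0.4 range
    @assert !(v"0.4.11" in spec)  # below the declared minimum
    @assert v"0.5.3" in spec      # any 0.5.x release is accepted
    @assert !(v"0.6.0" in spec)   # 0.6 would require another bump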
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -39,7 +39,7 @@ DocumenterVitepress = "0.1.3"
 Enzyme = "0.13.13"
 FiniteDiff = "2.23.1"
 ForwardDiff = "0.10.36"
-Functors = "0.4.12"
+Functors = "0.4.12, 0.5"
 GPUArraysCore = "0.1, 0.2"
 KernelAbstractions = "0.9"
 LinearAlgebra = "1.10"
2 changes: 1 addition & 1 deletion examples/BayesianNN/Project.toml
@@ -10,7 +10,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 CairoMakie = "0.12"
-Functors = "0.4"
+Functors = "0.4, 0.5"
 LinearAlgebra = "1"
 Lux = "1"
 Random = "1"
2 changes: 1 addition & 1 deletion lib/LuxCore/Project.toml
@@ -35,7 +35,7 @@ ChainRulesCore = "1.24"
 Compat = "4.15.0"
 DispatchDoctor = "0.4.10"
 EnzymeCore = "0.8.5"
-Functors = "0.4.12"
+Functors = "0.4.12, 0.5"
 MLDataDevices = "1"
 Random = "1.10"
 Reactant = "0.2.4"
2 changes: 1 addition & 1 deletion lib/LuxCore/test/Project.toml
@@ -13,7 +13,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Aqua = "0.8.7"
 EnzymeCore = "0.8.5"
 ExplicitImports = "1.9.0"
-Functors = "0.4.12"
+Functors = "0.4.12, 0.5"
 MLDataDevices = "1.0.0"
 Optimisers = "0.3.3"
 Random = "1.10"
2 changes: 1 addition & 1 deletion lib/LuxTestUtils/Project.toml
@@ -29,7 +29,7 @@ DispatchDoctor = "0.4.12"
 Enzyme = "0.13.13"
 FiniteDiff = "2.23.1"
 ForwardDiff = "0.10.36"
-Functors = "0.4.11"
+Functors = "0.4.11, 0.5"
 JET = "0.9.6"
 MLDataDevices = "1.0.0"
 ReverseDiff = "1.15.3"
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
@@ -51,7 +51,7 @@ CUDA = "5.2"
 ChainRulesCore = "1.23"
 Compat = "4.15"
 FillArrays = "1"
-Functors = "0.4.8"
+Functors = "0.4.8, 0.5"
 GPUArrays = "10, 11"
 MLUtils = "0.4.4"
 Metal = "1"
22 changes: 5 additions & 17 deletions lib/MLDataDevices/src/public.jl
@@ -362,24 +362,12 @@ function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDe
     return set_device!(T, rank)
 end
 
-# Dispatches for Different Data Structures
-# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability
-# For all other types we rely on fmap which means we lose type stability.
-# For Lux, typically models only has these 3 datastructures so we should be mostly fine.
-for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
+for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :Reactant)
     ldev = Symbol(dev, :Device)
-    @eval begin
-        function (D::$(ldev))(x::AbstractArray{T}) where {T}
-            if isbitstype(T) || Internal.special_aos(x) || x isa Adapt.WrappedArray
-                return Adapt.adapt(D, x)
-            end
-            return map(D, x)
-        end
-        (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
-        function (D::$(ldev))(x)
-            isleaf(x) && return Adapt.adapt(D, x)
-            return Functors.fmap(D, x; exclude=isleaf)
-        end
-    end
+    @eval function (D::$(ldev))(x)
+        isleaf(x) && return Adapt.adapt(D, x)
+        return Functors.fmap(D, x; exclude=isleaf)
+    end
 end
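This hunk removes the per-container fast paths (AbstractArray, Tuple, NamedTuple) in favor of one definition per device type: adapt leaves directly, and recurse into everything else with Functors.fmap. The deleted comments describe the old trade-off: the hand-written methods existed to keep the common cases type stable. Below is a minimal self-contained sketch of the new pattern; ToDevice is a made-up stand-in for device structs like CPUDevice, and it uses Functors.isleaf where the real code uses MLDataDevices' own isleaf:

    using Adapt, Functors

    # Hypothetical device type standing in for CPUDevice etc.
    struct ToDevice end

    # The device object is callable, mirroring the PR: adapt leaves
    # directly, and let fmap walk every other container.
    function (D::ToDevice)(x)
        Functors.isleaf(x) && return Adapt.adapt(D, x)
        return Functors.fmap(D, x; exclude=Functors.isleaf)
    end

    # In this sketch, "moving" an array to the device just copies it.
    Adapt.adapt_storage(::ToDevice, x::AbstractArray) = copy(x)

    dev = ToDevice()
    nt = (a=[1.0, 2.0], b=(c=[3.0],))
    moved = dev(nt)  # same nesting, leaves adapted

Whether the fmap path stays type stable for those common cases is exactly what the new test below checks with @inferred.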
2 changes: 1 addition & 1 deletion lib/MLDataDevices/test/Project.toml
@@ -28,7 +28,7 @@ ComponentArrays = "0.15.8"
 ExplicitImports = "1.9.0"
 FillArrays = "1"
 ForwardDiff = "0.10.36"
-Functors = "0.4.8"
+Functors = "0.4.8, 0.5"
 MLUtils = "0.4"
 Pkg = "1.10"
 Random = "1.10"
14 changes: 14 additions & 0 deletions lib/MLDataDevices/test/misc_tests.jl
@@ -219,3 +219,17 @@ end
 
     @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}
 end
+
+@testset "data movement is type stable" begin
+    cpu = cpu_device()
+    gpu = gpu_device()
+
+    r = [1, 2]
+    x = (a = r, b = 3, c = (4, (d = 5, e = r)))
+    y = @inferred(gpu(x))
+    x2 = @inferred(cpu(y))
+
+    # identity is preserved
+    @test y.a === y.c[2].e
+    @test x2.a === x2.c[2].e
+end
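Both === assertions hold because Functors.fmap memoizes results in an IdDict, so the shared array r is moved once and both positions in the output reference the same object; @inferred additionally errors if the inferred return type of the device call differs from the runtime type. The aliasing behavior is easy to reproduce with Functors alone (a standalone sketch, independent of MLDataDevices):

    using Functors

    r = [1, 2]
    x = (a=r, b=(c=r,))

    # fmap caches by object identity, so the shared leaf is
    # transformed once and reused at both positions.
    y = fmap(v -> 10 .* v, x)

    @assert y.a === y.b.c   # aliasing survives the traversal
    @assert y.a == [10, 20]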
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -51,7 +51,7 @@ Documenter = "1.4"
 Enzyme = "0.13.13"
 ExplicitImports = "1.9.0"
 ForwardDiff = "0.10.36"
-Functors = "0.4.12"
+Functors = "0.4.12, 0.5"
 Hwloc = "3.2.0"
 InteractiveUtils = "<0.0.1, 1"
 LinearAlgebra = "1.10"