Skip to content

Commit

Permalink
Remove _isleaf and _isbitstype
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 7, 2024
1 parent 1333b8b commit 4fb34e0
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 11 deletions.
15 changes: 4 additions & 11 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,15 +318,15 @@ default_device_rng(::LuxCPUDevice) = Random.default_rng()
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
ldev = Symbol("Lux$(dev)Device")
@eval begin
function (D::$(ldev))(x::AbstractArray)
function (D::$(ldev))(x::AbstractArray{T}) where {T}
fn = Base.Fix1(Adapt.adapt, D)
return _isbitsarray(x) ? fn(x) : map(D, x)
return isbitstype(T) ? fn(x) : map(D, x)
end
(D::$(ldev))(x::Tuple) = map(D, x)
(D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x)))
function (D::$(ldev))(x)
_isleaf(x) && return Adapt.adapt(D, x)
return fmap(Base.Fix1(Adapt.adapt, D), x; exclude=_isleaf)
Functors.isleaf(x) && return Adapt.adapt(D, x)
return fmap(Base.Fix1(Adapt.adapt, D), x)
end
function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer)
@warn "Lux layers are stateless and hence don't participate in device \
Expand Down Expand Up @@ -476,13 +476,6 @@ for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice)
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
end

@inline _isbitsarray(::AbstractArray{<:Number}) = true
@inline _isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T)
@inline _isbitsarray(x) = false

@inline _isleaf(::AbstractRNG) = true
@inline _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)

# Chain Rules Core
function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractLuxDevice, x::AbstractArray)
∇adapt_storage = let x = x
Expand Down
9 changes: 9 additions & 0 deletions test/amdgpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using FillArrays, Zygote # Extensions

@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string",
mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

Expand All @@ -43,6 +44,10 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
@test ps_xpu.mixed isa Vector
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
Expand All @@ -63,6 +68,10 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.a.c == ps.a.c
@test ps_cpu.b == ps.b
@test ps_cpu.a.d == ps.a.d
@test ps_cpu.mixed isa Vector
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
Expand Down
9 changes: 9 additions & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using FillArrays, Zygote # Extensions

@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string",
mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

Expand All @@ -42,6 +43,10 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
@test ps_xpu.mixed isa Vector
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
Expand All @@ -62,6 +67,10 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.a.c == ps.a.c
@test ps_cpu.b == ps.b
@test ps_cpu.a.d == ps.a.d
@test ps_cpu.mixed isa Vector
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
Expand Down
9 changes: 9 additions & 0 deletions test/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using FillArrays, Zygote # Extensions

@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string",
mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

Expand All @@ -43,6 +44,10 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
@test ps_xpu.mixed isa Vector
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
Expand All @@ -63,6 +68,10 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.a.c == ps.a.c
@test ps_cpu.b == ps.b
@test ps_cpu.a.d == ps.a.d
@test ps_cpu.mixed isa Vector
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
Expand Down
9 changes: 9 additions & 0 deletions test/oneapi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using FillArrays, Zygote # Extensions

@testset "Data Transfer" begin
ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string",
mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types
rng_default=Random.default_rng(), rng=MersenneTwister(),
one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3)))

Expand All @@ -43,6 +44,10 @@ using FillArrays, Zygote # Extensions
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
@test ps_xpu.mixed isa Vector
@test ps_xpu.mixed[1] isa Float32
@test ps_xpu.mixed[2] isa Float64
@test ps_xpu.mixed[3] isa aType
@test ps_xpu.e == ps.e
@test ps_xpu.d == ps.d
@test ps_xpu.rng_default isa rngType
Expand All @@ -63,6 +68,10 @@ using FillArrays, Zygote # Extensions
@test ps_cpu.a.c == ps.a.c
@test ps_cpu.b == ps.b
@test ps_cpu.a.d == ps.a.d
@test ps_cpu.mixed isa Vector
@test ps_cpu.mixed[1] isa Float32
@test ps_cpu.mixed[2] isa Float64
@test ps_cpu.mixed[3] isa Array
@test ps_cpu.e == ps.e
@test ps_cpu.d == ps.d
@test ps_cpu.rng_default isa Random.TaskLocalRNG
Expand Down

0 comments on commit 4fb34e0

Please sign in to comment.