Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoiding Type Piracy in Lux #246

Merged
merged 5 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
- '1.6'
- '1.8'
- '1.9'
- '1.10'
- '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia.
- '1.10.0-beta3'
os:
- ubuntu-latest
arch:
Expand Down
14 changes: 12 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ComponentArrays"
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
authors = ["Jonnie Diegelman <[email protected]>"]
version = "0.15.10"
version = "0.15.11"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand All @@ -17,20 +17,24 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
ComponentArraysAdaptExt = "Adapt"
ComponentArraysConstructionBaseExt = "ConstructionBase"
ComponentArraysGPUArraysExt = "GPUArrays"
ComponentArraysOptimisersExt = "Optimisers"
ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools"
ComponentArraysReverseDiffExt = "ReverseDiff"
ComponentArraysSciMLBaseExt = "SciMLBase"
ComponentArraysTrackerExt = "Tracker"
ComponentArraysTruncatedStacktracesExt = "TruncatedStacktraces"
ComponentArraysZygoteExt = "Zygote"

[compat]
Expand All @@ -41,13 +45,17 @@ ConstructionBase = "1"
ForwardDiff = "0.10"
Functors = "0.4.4"
GPUArrays = "8, 9, 10"
LinearAlgebra = "1"
Optimisers = "0.3"
PackageExtensionCompat = "1"
RecursiveArrayTools = "2, 3"
ReverseDiff = "1"
SciMLBase = "1, 2"
StaticArraysCore = "1"
StaticArrayInterface = "1"
StaticArraysCore = "1"
Test = "1"
Tracker = "0.2"
TruncatedStacktraces = "1.4"
Zygote = "0.6"
julia = "1.6"

Expand All @@ -56,9 +64,11 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
21 changes: 21 additions & 0 deletions ext/ComponentArraysOptimisersExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module ComponentArraysOptimisersExt

using ComponentArrays, Optimisers

# Optimisers can handle componentarrays by default, but we can vectorize the entire
# operation here instead of doing multiple smaller operations
Optimisers.setup(opt::AbstractRule, ps::ComponentArray) = Optimisers.setup(opt, getdata(ps))

function Optimisers.update(tree, ps::ComponentArray, gs::ComponentArray)
gs_flat = ComponentArrays.__value(getdata(gs)) # Safety against ReverseDiff
tree, ps_new = Optimisers.update(tree, getdata(ps), gs_flat)
return tree, ComponentArray(ps_new, getaxes(ps))
end

function Optimisers.update!(tree::Optimisers.Leaf, ps::ComponentArray, gs::ComponentArray)
gs_flat = ComponentArrays.__value(getdata(gs)) # Safety against ReverseDiff
tree, ps_new = Optimisers.update!(tree, getdata(ps), gs_flat)
return tree, ComponentArray(ps_new, getaxes(ps))
end

end
23 changes: 18 additions & 5 deletions ext/ComponentArraysReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module ComponentArraysReverseDiffExt

using ComponentArrays, ReverseDiff

const TrackedComponentArray{V, D, N, DA, A, Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA}
const TrackedComponentArray{V,D,N,DA,A,Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA}

maybe_tracked_array(val::AbstractArray, der, tape, inds, origin) = ReverseDiff.TrackedArray(val, der, tape)
function maybe_tracked_array(val::Real, der, tape, inds, origin::AbstractVector)
Expand All @@ -12,10 +12,10 @@ function maybe_tracked_array(val::Real, der, tape, inds, origin::AbstractVector)
end

for f in [:getindex, :view]
@eval function Base.$f(tca::TrackedComponentArray, inds::Union{Symbol, Val}...)
val = $f(ReverseDiff.value(tca), inds...)
der = Base.maybeview(ReverseDiff.deriv(tca), inds...)
t = ReverseDiff.tape(tca)
@eval function Base.$f(tca::TrackedComponentArray, inds::Union{Symbol,Val}...)
val = $f(ReverseDiff.value(tca), inds...)
der = Base.maybeview(ReverseDiff.deriv(tca), inds...)
t = ReverseDiff.tape(tca)
return maybe_tracked_array(val, der, t, inds, tca)
end
end
Expand All @@ -31,4 +31,17 @@ function Base.getproperty(tca::TrackedComponentArray, s::Symbol)
end
end

function Base.propertynames(::TrackedComponentArray{V,D,N,DA,A,Tuple{Ax}}) where {V,D,N,DA,A,Ax<:ComponentArrays.AbstractAxis}
return propertynames(ComponentArrays.indexmap(Ax))
end

function Base.NamedTuple(tca::TrackedComponentArray)
props = propertynames(tca)
return NamedTuple{props}(getproperty(tca, p) for p in props)
end

@inline ComponentArrays.__value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x)
@inline ComponentArrays.__value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x)
@inline ComponentArrays.__value(x::TrackedComponentArray) = ComponentArray(ComponentArrays.__value(getdata(x)), getaxes(x))

end
8 changes: 8 additions & 0 deletions ext/ComponentArraysTruncatedStacktracesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module ComponentArraysTruncatedStacktracesExt

using ComponentArrays
import TruncatedStacktraces: @truncate_stacktrace

@truncate_stacktrace ComponentArray 1

end
2 changes: 2 additions & 0 deletions src/ComponentArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ include("compat/chainrulescore.jl")
include("compat/static_arrays.jl")
export @static_unpack

include("compat/functors.jl")

import PackageExtensionCompat: @require_extensions
function __init__()
@require_extensions
Expand Down
21 changes: 20 additions & 1 deletion src/compat/chainrulescore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,26 @@ end
# Prevent double projection
(p::ChainRulesCore.ProjectTo{ComponentArray})(dx::ComponentArray) = dx

function (p::ChainRulesCore.ProjectTo{ComponentArray})(t::ChainRulesCore.Tangent{A, <:NamedTuple}) where {A}
function (p::ChainRulesCore.ProjectTo{ComponentArray})(t::ChainRulesCore.Tangent{A,<:NamedTuple}) where {A}
nt = Functors.fmap(ChainRulesCore.backing, ChainRulesCore.backing(t))
return ComponentArray(nt)
end

function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA<:ComponentArray}
y = CA(nt)

function ∇NamedTupleToComponentArray(Δ::AbstractArray)
if length(Δ) == length(y)
return ∇NamedTupleToComponentArray(ComponentArray(vec(Δ), getaxes(y)))
end
error("Got pullback input of shape $(size(Δ)) & type $(typeof(Δ)) for output " *
"of shape $(size(y)) & type $(typeof(y))")
return nothing
end

function ∇NamedTupleToComponentArray(Δ::ComponentArray)
return ChainRulesCore.NoTangent(), NamedTuple(Δ)
end

return y, ∇NamedTupleToComponentArray
end
1 change: 1 addition & 0 deletions src/compat/functors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Functors.functor(::Type{<:ComponentVector}, c) = NamedTuple(c), ComponentVector
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,5 @@ recursive_eltype(x::AbstractArray{<:Any}) = isempty(x) ? Base.Bottom : mapreduce
recursive_eltype(x::Dict) = isempty(x) ? Base.Bottom : mapreduce(recursive_eltype, promote_type, values(x))
recursive_eltype(::AbstractArray{T,N}) where {T<:Number, N} = T
recursive_eltype(x) = typeof(x)

@inline __value(x) = x
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
52 changes: 39 additions & 13 deletions test/autodiff_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import FiniteDiff, ForwardDiff, ReverseDiff, Tracker, Zygote

using Optimisers
using Test

F(a, x) = sum(abs2, a) * x^3
Expand Down Expand Up @@ -38,6 +38,22 @@ truth = ComponentArray(a = [32, 48], x = 156)
@test out isa Vector{<:ForwardDiff.Dual}
end

@testset "Optimisers Update" begin
ca_ = deepcopy(ca)
opt_st = Optimisers.setup(Adam(0.01), ca_)
gs_zyg = only(Zygote.gradient(F_idx_val, ca_))
@test !(last(Optimisers.update(opt_st, ca_, gs_zyg)) ≈ ca)
Optimisers.update!(opt_st, ca_, gs_zyg)
@test !(ca_ ≈ ca)

ca_ = deepcopy(ca)
opt_st = Optimisers.setup(Adam(0.01), ca_)
gs_rdiff = ReverseDiff.gradient(F_idx_val, ca_)
@test !(last(Optimisers.update(opt_st, ca_, gs_rdiff)) ≈ ca)
Optimisers.update!(opt_st, ca_, gs_rdiff)
@test !(ca_ ≈ ca)
end

@testset "Projection" begin
gs_ca = Zygote.gradient(sum, ca)[1]

Expand Down Expand Up @@ -76,18 +92,28 @@ end
@test ∂r ≈ ∂r_ca
end

function F_prop(x)
@assert propertynames(x) == (:x, :y)
return sum(abs2, x.x .- x.y)
end

@testset "Preserve Properties" begin
x = ComponentArray(; x = [1.0, 5.0], y = [3.0, 4.0])

# # This is commented out because the gradient operation itself is broken due to Zygote's inability
# # to support mutation and ComponentArray's use of mutation for construction from a NamedTuple.
# # It would be nice to support this eventually, so I'll just leave this commented (because @test_broken
# # wouldn't work here because the error happens before the test)
# @testset "Issues" begin
# function mysum(x::AbstractVector)
# y = ComponentVector(x=x)
# return sum(y)
# end
gs_z = only(Zygote.gradient(F_prop, x))
gs_rdiff = ReverseDiff.gradient(F_prop, x)

# Δ = Zygote.gradient(mysum, rand(10))
@test gs_z ≈ gs_rdiff
end

@testset "Issues" begin
function mysum(x::AbstractVector)
y = ComponentVector(x=x)
z = ComponentVector(; z = x .^ 2)
return sum(y) + sum(abs2, z)
end

# @test Δ isa Vector{Float64}
# end
Δ = only(Zygote.gradient(mysum, rand(10)))

@test Δ isa AbstractVector{Float64}
end
10 changes: 10 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ using StaticArrays
using OffsetArrays
using Test
using Unitful
using Functors
import TruncatedStacktraces # This is loaded just to trigger the extension package


## Test setup
Expand Down Expand Up @@ -690,6 +692,14 @@ end
@test_throws ArgumentError axpby!(2, x, 3, y)
end

@testset "Functors" begin
for carray in (ca, ca_Float32, ca_MVector, ca_SVector, ca_composed, ca2, caa)
θ, re = Functors.functor(carray)
@test θ isa NamedTuple
@test re(θ) == carray
end
end

@testset "Autodiff" begin
include("autodiff_tests.jl")
end
Expand Down
Loading