diff --git a/src/compat/chainrulescore.jl b/src/compat/chainrulescore.jl index 0e6b9d7..9f9211b 100644 --- a/src/compat/chainrulescore.jl +++ b/src/compat/chainrulescore.jl @@ -49,6 +49,15 @@ end function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA<:ComponentArray} y = CA(nt) + function ∇NamedTupleToComponentArray(Δ::AbstractArray) + if length(Δ) == length(y) + return ∇NamedTupleToComponentArray(ComponentArray(vec(Δ), getaxes(y))) + end + error("Got pullback input of shape $(size(Δ)) & type $(typeof(Δ)) for output " * + "of shape $(size(y)) & type $(typeof(y))") + return nothing + end + function ∇NamedTupleToComponentArray(Δ::ComponentArray) return ChainRulesCore.NoTangent(), NamedTuple(Δ) end diff --git a/src/compat/functors.jl b/src/compat/functors.jl index b007409..3978ca1 100644 --- a/src/compat/functors.jl +++ b/src/compat/functors.jl @@ -1,4 +1 @@ -function Functors.functor(::Type{<:ComponentArray}, c) - return ( - NamedTuple{propertynames(c)}(getproperty.((c,), propertynames(c))), ComponentArray) -end +Functors.functor(::Type{<:ComponentVector}, c) = NamedTuple(c), ComponentVector diff --git a/test/Project.toml b/test/Project.toml index daa203a..83210c5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,9 +9,11 @@ JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index 091784a..b088642 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -1,5 +1,5 @@ import FiniteDiff, ForwardDiff, ReverseDiff, Tracker, Zygote - +using Optimisers using Test F(a, x) = sum(abs2, a) * x^3 @@ -38,6 +38,22 @@ truth = ComponentArray(a = [32, 48], x = 156) @test out isa Vector{<:ForwardDiff.Dual} end +@testset "Optimisers Update" begin + ca_ = deepcopy(ca) + opt_st = Optimisers.setup(Adam(0.01), ca_) + gs_zyg = only(Zygote.gradient(F_idx_val, ca_)) + @test !(last(Optimisers.update(opt_st, ca_, gs_zyg)) ≈ ca) + Optimisers.update!(opt_st, ca_, gs_zyg) + @test !(ca_ ≈ ca) + + ca_ = deepcopy(ca) + opt_st = Optimisers.setup(Adam(0.01), ca_) + gs_rdiff = ReverseDiff.gradient(F_idx_val, ca_) + @test !(last(Optimisers.update(opt_st, ca_, gs_rdiff)) ≈ ca) + Optimisers.update!(opt_st, ca_, gs_rdiff) + @test !(ca_ ≈ ca) +end + @testset "Projection" begin gs_ca = Zygote.gradient(sum, ca)[1] @@ -76,18 +92,28 @@ end @test ∂r ≈ ∂r_ca end +function F_prop(x) + @assert propertynames(x) == (:x, :y) + return sum(abs2, x.x .- x.y) +end + +@testset "Preserve Properties" begin + x = ComponentArray(; x = [1.0, 5.0], y = [3.0, 4.0]) -# # This is commented out because the gradient operation itself is broken due to Zygote's inability -# # to support mutation and ComponentArray's use of mutation for construction from a NamedTuple. -# # It would be nice to support this eventually, so I'll just leave this commented (because @test_broken -# # wouldn't work here because the error happens before the test) -# @testset "Issues" begin -# function mysum(x::AbstractVector) -# y = ComponentVector(x=x) -# return sum(y) -# end + gs_z = only(Zygote.gradient(F_prop, x)) + gs_rdiff = ReverseDiff.gradient(F_prop, x) -# Δ = Zygote.gradient(mysum, rand(10)) + @test gs_z ≈ gs_rdiff +end + +@testset "Issues" begin + function mysum(x::AbstractVector) + y = ComponentVector(x=x) + z = ComponentVector(; z = x .^ 2) + return sum(y) + sum(abs2, z) + end -# @test Δ isa Vector{Float64} -# end + Δ = only(Zygote.gradient(mysum, rand(10))) + + @test Δ isa AbstractVector{Float64} +end diff --git a/test/runtests.jl b/test/runtests.jl index 4f4d416..3e7b54e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,8 @@ using StaticArrays using OffsetArrays using Test using Unitful +using Functors +import TruncatedStacktraces # This is loaded just to trigger the extension package ## Test setup @@ -690,6 +692,14 @@ end @test_throws ArgumentError axpby!(2, x, 3, y) end +@testset "Functors" begin + for carray in (ca, ca_Float32, ca_MVector, ca_SVector, ca_composed, ca2, caa) + θ, re = Functors.functor(carray) + @test θ isa NamedTuple + @test re(θ) == carray + end +end + @testset "Autodiff" begin include("autodiff_tests.jl") end