Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 26, 2024
1 parent ab718c6 commit 00d08b5
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 17 deletions.
9 changes: 9 additions & 0 deletions src/compat/chainrulescore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ end
function ChainRulesCore.rrule(::Type{CA}, nt::NamedTuple) where {CA<:ComponentArray}
y = CA(nt)

function ∇NamedTupleToComponentArray::AbstractArray)
if length(Δ) == length(y)
return ∇NamedTupleToComponentArray(ComponentArray(vec(Δ), getaxes(y)))
end
error("Got pullback input of shape $(size(Δ)) & type $(typeof(Δ)) for output " *
"of shape $(size(y)) & type $(typeof(y))")
return nothing
end

function ∇NamedTupleToComponentArray::ComponentArray)
return ChainRulesCore.NoTangent(), NamedTuple(Δ)
end
Expand Down
5 changes: 1 addition & 4 deletions src/compat/functors.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1 @@
function Functors.functor(::Type{<:ComponentArray}, c)
return (
NamedTuple{propertynames(c)}(getproperty.((c,), propertynames(c))), ComponentArray)
end
Functors.functor(::Type{<:ComponentVector}, c) = NamedTuple(c), ComponentVector
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
52 changes: 39 additions & 13 deletions test/autodiff_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import FiniteDiff, ForwardDiff, ReverseDiff, Tracker, Zygote

using Optimisers
using Test

F(a, x) = sum(abs2, a) * x^3
Expand Down Expand Up @@ -38,6 +38,22 @@ truth = ComponentArray(a = [32, 48], x = 156)
@test out isa Vector{<:ForwardDiff.Dual}
end

@testset "Optimisers Update" begin
ca_ = deepcopy(ca)
opt_st = Optimisers.setup(Adam(0.01), ca_)
gs_zyg = only(Zygote.gradient(F_idx_val, ca_))
@test !(last(Optimisers.update(opt_st, ca_, gs_zyg)) ca)
Optimisers.update!(opt_st, ca_, gs_zyg)
@test !(ca_ ca)

ca_ = deepcopy(ca)
opt_st = Optimisers.setup(Adam(0.01), ca_)
gs_rdiff = ReverseDiff.gradient(F_idx_val, ca_)
@test !(last(Optimisers.update(opt_st, ca_, gs_rdiff)) ca)
Optimisers.update!(opt_st, ca_, gs_rdiff)
@test !(ca_ ca)
end

@testset "Projection" begin
gs_ca = Zygote.gradient(sum, ca)[1]

Expand Down Expand Up @@ -76,18 +92,28 @@ end
@test ∂r ∂r_ca
end

function F_prop(x)
@assert propertynames(x) == (:x, :y)
return sum(abs2, x.x .- x.y)
end

@testset "Preserve Properties" begin
x = ComponentArray(; x = [1.0, 5.0], y = [3.0, 4.0])

# # This is commented out because the gradient operation itself is broken due to Zygote's inability
# # to support mutation and ComponentArray's use of mutation for construction from a NamedTuple.
# # It would be nice to support this eventually, so I'll just leave this commented (because @test_broken
# # wouldn't work here because the error happens before the test)
# @testset "Issues" begin
# function mysum(x::AbstractVector)
# y = ComponentVector(x=x)
# return sum(y)
# end
gs_z = only(Zygote.gradient(F_prop, x))
gs_rdiff = ReverseDiff.gradient(F_prop, x)

# Δ = Zygote.gradient(mysum, rand(10))
@test gs_z gs_rdiff
end

@testset "Issues" begin
function mysum(x::AbstractVector)
y = ComponentVector(x=x)
z = ComponentVector(; z = x .^ 2)
return sum(y) + sum(abs2, z)
end

# @test Δ isa Vector{Float64}
# end
Δ = only(Zygote.gradient(mysum, rand(10)))

@test Δ isa AbstractVector{Float64}
end
10 changes: 10 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ using StaticArrays
using OffsetArrays
using Test
using Unitful
using Functors
import TruncatedStacktraces # This is loaded just to trigger the extension package


## Test setup
Expand Down Expand Up @@ -690,6 +692,14 @@ end
@test_throws ArgumentError axpby!(2, x, 3, y)
end

@testset "Functors" begin
for carray in (ca, ca_Float32, ca_MVector, ca_SVector, ca_composed, ca2, caa)
θ, re = Functors.functor(carray)
@test θ isa NamedTuple
@test re(θ) == carray
end
end

@testset "Autodiff" begin
include("autodiff_tests.jl")
end
Expand Down

0 comments on commit 00d08b5

Please sign in to comment.