diff --git a/Project.toml b/Project.toml index 347c4c51..3f04c82b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.5.7" +version = "0.5.8" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/check_result.jl b/src/check_result.jl index 683d7416..6be1af04 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -25,6 +25,10 @@ for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, Abs end end +check_equal(::Zero, x; kwargs...) = check_equal(zero(x), x; kwargs...) +check_equal(x, ::Zero; kwargs...) = check_equal(x, zero(x); kwargs...) +check_equal(x::Zero, y::Zero; kwargs...) = @test true + """ _can_pass_early(actual, expected; kwargs...) Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper; @@ -77,16 +81,29 @@ function check_equal( @test ActualPrimal === ExpectedPrimal end + +# Some structual differential and a natural differential +function check_equal(actual::Composite{P, T}, expected; kwargs...) where {T, P} + if _can_pass_early(actual, expected) + @test true + else + @assert (T <: NamedTuple) # it should be a structual differential if we hit this + + # We are only checking the properties that are in the Composite + # the natural differential is allowed to have other properties that we ignore + @testset "$P.$ii" for ii in propertynames(actual) + check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...) + end + end +end +check_equal(x, y::Composite; kwargs...) = check_equal(y, x; kwargs...) + # This catches comparisons of Composites and Tuples/NamedTuple # and gives a error messaage complaining about that const LegacyZygoteCompTypes = Union{Tuple,NamedTuple} check_equal(::C, expected::T) where {C<:Composite,T<:LegacyZygoteCompTypes} = @test C === T check_equal(::T, expected::C) where {C<:Composite,T<:LegacyZygoteCompTypes} = @test T === C -check_equal(::Zero, x; kwargs...) = check_equal(zero(x), x; kwargs...) -check_equal(x, ::Zero; kwargs...) = check_equal(x, zero(x); kwargs...) -check_equal(x::Zero, y::Zero; kwargs...) = @test true - # Generic fallback, probably a tuple or something function check_equal(actual::A, expected::E; kwargs...) where {A, E} if _can_pass_early(actual, expected) @@ -101,6 +118,7 @@ function check_equal(actual::A, expected::E; kwargs...) where {A, E} end end + """ _check_add!!_behavour(acc, val) diff --git a/test/check_result.jl b/test/check_result.jl index 1e98155e..5960447c 100644 --- a/test/check_result.jl +++ b/test/check_result.jl @@ -51,10 +51,14 @@ end Composite{Tuple{Float64, Float64}}(1.0, 2.0) ) - D = Diagonal(randn(5)) - check_equal( - Composite{typeof(D)}(diag=D.diag), - Composite{typeof(D)}(diag=D.diag) + diag_eg = Diagonal(randn(5)) + check_equal( # Structual == Structural + Composite{typeof(diag_eg)}(diag=diag_eg.diag), + Composite{typeof(diag_eg)}(diag=diag_eg.diag) + ) + check_equal( # Structural == Natural + Composite{typeof(diag_eg)}(diag=diag_eg.diag), + diag_eg ) T = (a=1.0, b=2.0)