Skip to content

Commit

Permalink
Merge pull request #84 from JuliaDiff/ox/struct
Browse files Browse the repository at this point in the history
check_equality on structural x natural
  • Loading branch information
oxinabox authored Dec 16, 2020
2 parents a630c9a + 35fefc0 commit 559f2c1
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.5.7"
version = "0.5.8"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
26 changes: 22 additions & 4 deletions src/check_result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, Abs
end
end

check_equal(::Zero, x; kwargs...) = check_equal(zero(x), x; kwargs...)
check_equal(x, ::Zero; kwargs...) = check_equal(x, zero(x); kwargs...)
check_equal(x::Zero, y::Zero; kwargs...) = @test true

"""
_can_pass_early(actual, expected; kwargs...)
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper;
Expand Down Expand Up @@ -77,16 +81,29 @@ function check_equal(
@test ActualPrimal === ExpectedPrimal
end


# Some structual differential and a natural differential
function check_equal(actual::Composite{P, T}, expected; kwargs...) where {T, P}
if _can_pass_early(actual, expected)
@test true
else
@assert (T <: NamedTuple) # it should be a structual differential if we hit this

# We are only checking the properties that are in the Composite
# the natural differential is allowed to have other properties that we ignore
@testset "$P.$ii" for ii in propertynames(actual)
check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...)
end
end
end
check_equal(x, y::Composite; kwargs...) = check_equal(y, x; kwargs...)

# This catches comparisons of Composites and Tuples/NamedTuple
# and gives a error messaage complaining about that
const LegacyZygoteCompTypes = Union{Tuple,NamedTuple}
check_equal(::C, expected::T) where {C<:Composite,T<:LegacyZygoteCompTypes} = @test C === T
check_equal(::T, expected::C) where {C<:Composite,T<:LegacyZygoteCompTypes} = @test T === C

check_equal(::Zero, x; kwargs...) = check_equal(zero(x), x; kwargs...)
check_equal(x, ::Zero; kwargs...) = check_equal(x, zero(x); kwargs...)
check_equal(x::Zero, y::Zero; kwargs...) = @test true

# Generic fallback, probably a tuple or something
function check_equal(actual::A, expected::E; kwargs...) where {A, E}
if _can_pass_early(actual, expected)
Expand All @@ -101,6 +118,7 @@ function check_equal(actual::A, expected::E; kwargs...) where {A, E}
end
end


"""
_check_add!!_behavour(acc, val)
Expand Down
12 changes: 8 additions & 4 deletions test/check_result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,14 @@ end
Composite{Tuple{Float64, Float64}}(1.0, 2.0)
)

D = Diagonal(randn(5))
check_equal(
Composite{typeof(D)}(diag=D.diag),
Composite{typeof(D)}(diag=D.diag)
diag_eg = Diagonal(randn(5))
check_equal( # Structual == Structural
Composite{typeof(diag_eg)}(diag=diag_eg.diag),
Composite{typeof(diag_eg)}(diag=diag_eg.diag)
)
check_equal( # Structural == Natural
Composite{typeof(diag_eg)}(diag=diag_eg.diag),
diag_eg
)

T = (a=1.0, b=2.0)
Expand Down

2 comments on commit 559f2c1

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/26500

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.8 -m "<description of version>" 559f2c15e8bbe6640a7cc2e60945b88fc49b664d
git push origin v0.5.8

Please sign in to comment.