diff --git a/Project.toml b/Project.toml index a800b28..9a91a41 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "1.9.0" +version = "1.9.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/check_result.jl b/src/check_result.jl index c95ef17..a0ed890 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -24,7 +24,16 @@ function test_approx( @test_msg msg isapprox(actual, expected; kwargs...) end -for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk)) +for (T1, T2) in + ( + (AbstractThunk, Any), + (AbstractThunk, AbstractThunk), + (Any, AbstractThunk), + (Tangent, AbstractThunk), + (AbstractThunk, Tangent), + (AbstractZero, AbstractThunk), + (AbstractThunk, AbstractZero), + ) @eval function test_approx(actual::$T1, expected::$T2, msg=""; kwargs...) return test_approx(unthunk(actual), unthunk(expected), msg; kwargs...) end @@ -123,9 +132,8 @@ function test_approx(actual::Tangent{P,T}, expected, msg=""; kwargs...) where {T end test_approx(x, y::Tangent, msg=""; kwargs...) = test_approx(y, x, msg; kwargs...) -function test_approx(actual::Tangent, expected::AbstractThunk, msg=""; kwargs...) - return test_approx(actual, unthunk(expected), msg; kwargs...) -end +test_approx(z::NoTangent, t::Tangent, msg=""; kwargs...) = all(==(NoTangent()), t) +test_approx(t::Tangent, z::NoTangent, msg=""; kwargs...) = all(==(NoTangent()), t) # This catches comparisons of Tangents and Tuples/NamedTuple # and gives an error message complaining about that. the `@test` will definitely fail diff --git a/test/check_result.jl b/test/check_result.jl index aad05c9..eabb1af 100644 --- a/test/check_result.jl +++ b/test/check_result.jl @@ -83,6 +83,20 @@ end Tangent{Tuple{Float64,Float64}}(1.0, 2.0), @thunk(Tangent{Tuple{Float64,Float64}}(1.0, 2.0)), ) + test_approx( + @thunk(Tangent{Tuple{Float64,Float64}}(1.0, 2.0)), + Tangent{Tuple{Float64,Float64}}(1.0, 2.0), + ) + test_approx(@thunk(ZeroTangent()), ZeroTangent()) + test_approx(ZeroTangent(), @thunk(ZeroTangent())) + test_approx( + Tangent{Tuple{Float64,Float64}}(NoTangent(), NoTangent()), + NoTangent(), + ) + test_approx( + NoTangent(), + Tangent{Tuple{Float64,Float64}}(NoTangent(), NoTangent()), + ) end @testset "negative case" begin @test fails(() -> test_approx(1.0, 2.0))