From ef5f72841ceaf38b99d2528ab0cacc09e2a38fa7 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 5 Oct 2020 19:28:15 +0100 Subject: [PATCH 01/13] overhall equality checks and add accumulation checks --- Project.toml | 2 +- src/ChainRulesTestUtils.jl | 1 + src/check_result.jl | 37 ++++++++++++++++++++++++++++++++++++- src/iterator.jl | 4 ++++ src/testers.jl | 32 ++++++++++++++++++-------------- 5 files changed, 60 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index 9168a6d5..d918fd79 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -ChainRulesCore = "0.9.1" +ChainRulesCore = "0.9.13" Compat = "3" FiniteDifferences = "0.11.2" julia = "1" diff --git a/src/ChainRulesTestUtils.jl b/src/ChainRulesTestUtils.jl index 2acaf050..880b3282 100644 --- a/src/ChainRulesTestUtils.jl +++ b/src/ChainRulesTestUtils.jl @@ -15,6 +15,7 @@ export TestIterator export test_scalar, frule_test, rrule_test, generate_well_conditioned_matrix include("generate_tangent.jl") +include("check_result.jl") include("isapprox.jl") include("data_generation.jl") include("iterator.jl") diff --git a/src/check_result.jl b/src/check_result.jl index abbe7d5d..5bec4805 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -1,3 +1,38 @@ +# For once you have the sensitivity by two methods (e.g both finite-differencing and AD) +# the code here checks it is correct. +# Goal is to only call `@isapprox` on things that render well + +""" + check_equal(actual, expected; kwargs...) + +`@test`'s that `actual ≈ expected`, but breaks up data such that human readable results +are shown on failures. +All keyword arguments are passed to `isapprox`. +""" +function check_equal( + actual::Union{AbstractArray{<:Number}, Number}, + expected::Union{AbstractArray{<:Number}, Number}; + kwargs... +) + @test isapprox(actual, expected; kwargs...) +end + +function check_equal(actual::AbstractThunk, expected; kwargs...) + check_equal(unthunk(actual), expected; kwargs...) +end + + +function check_equal( + actual::Union{Composite, AbstractArray}, + expected; + kwargs... +) + @test length(actual) == length(expected) + @testset "$ii" for ii in keys(actual) # keys works on all Composites + check_equal(actual[ii], expected[ii]; kwargs...) + end +end + """ _check_add!!_behavour(acc, val) @@ -15,5 +50,5 @@ function _check_add!!_behavour(acc, val; kwargs...) # e.g. if it is immutable. We do test the `add!!` return value. # That is what people should rely on. The mutation is just to save allocations. acc_mutated = deepcopy(acc) # prevent this test changing others - @test isapprox(add!!(acc_mutated, val), acc + val; kwargs...) + @test check_equal(add!!(acc_mutated, val), acc + val; kwargs...) end diff --git a/src/iterator.jl b/src/iterator.jl index 7a256778..b2e4648e 100644 --- a/src/iterator.jl +++ b/src/iterator.jl @@ -60,6 +60,10 @@ function Base.isapprox( return isapprox(iter1.data, iter2.data; kwargs...) end +function check_equal(expected::TestIterator, actual::TestIterator; kwargs...) + return isapprox(expected, actual; kwargs...) +end + function rand_tangent(rng::AbstractRNG, x::TestIterator{<:Any,IS,IE}) where {IS,IE} ∂data = rand_tangent(rng, x.data) return TestIterator{typeof(∂data),IS,IE}(∂data) diff --git a/src/testers.jl b/src/testers.jl index 166bc11d..004d50f6 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -126,9 +126,9 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), frule_test(f, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) if z isa Complex # check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im - @test isapprox( - frule((Zero(), real(Δx)), f, z; fkwargs...)[2], - frule((Zero(), Δx), f, z; fkwargs...)[2], + check_equal( + frule((Zero(), real(Δx)), f, z; fkwargs...)[2]::Number, + frule((Zero(), Δx), f, z; fkwargs...)[2]::Number, rtol=rtol, atol=atol, kwargs..., @@ -150,10 +150,10 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), rrule_test(f, Δu, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) if Ω isa Complex # check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im - back = rrule(f, z)[2] - @test isapprox( - extern(back(real(Δu))[2]), - extern(back(Δu)[2]), + _, back = rrule(f, z) + check_equal( + back(real(Δu))[2], + back(Δu)[2], rtol=rtol, atol=atol, kwargs..., @@ -249,16 +249,20 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm @test x̄_ad isa DoesNotExist # we said it wasn't differentiable. else # The main test of the actual deriviative being correct: - @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) - + check_result(accumulated_x̄, x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) _check_add!!_behavour(x̄_acc, x̄_ad; rtol=rtol, atol=atol, kwargs...) end end - if count(!, x̄s_is_dne) == 1 - # for functions with pullbacks that only produce a single non-DNE adjoint, that - # single adjoint should not be `Thunk`ed. InplaceableThunk is fine. - i = findfirst(!, x̄s_is_dne) - @test !(isa(x̄s_ad[i], Thunk)) + check_thunking_is_appropriate(x̄s_ad) +end + +function check_thunking_is_appropriate(x̄s) + @testset "Don't thunk only non_zero argument" begin + num_zeros = count(x->x isa AbstractZero, x̄s) + num_thunks = count(x->x isa Thunk, x̄s) + if num_zeros + num_thunks == length(x̄s) + @test num_thunks !== 1 + end end end From f4b1b9896bc920b8eb751724fa4eee2c758e1aaf Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Dec 2020 17:20:18 +0000 Subject: [PATCH 02/13] remove all uses of isapprox and use _checkequal --- src/ChainRulesTestUtils.jl | 1 - src/check_result.jl | 47 ++++++++++++++++-------- src/iterator.jl | 2 +- src/testers.jl | 74 +++++++++++++++++--------------------- 4 files changed, 67 insertions(+), 57 deletions(-) diff --git a/src/ChainRulesTestUtils.jl b/src/ChainRulesTestUtils.jl index 880b3282..2acaf050 100644 --- a/src/ChainRulesTestUtils.jl +++ b/src/ChainRulesTestUtils.jl @@ -15,7 +15,6 @@ export TestIterator export test_scalar, frule_test, rrule_test, generate_well_conditioned_matrix include("generate_tangent.jl") -include("check_result.jl") include("isapprox.jl") include("data_generation.jl") include("iterator.jl") diff --git a/src/check_result.jl b/src/check_result.jl index 5bec4805..45b95965 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -1,15 +1,16 @@ # For once you have the sensitivity by two methods (e.g both finite-differencing and AD) # the code here checks it is correct. # Goal is to only call `@isapprox` on things that render well +# Note that this must work well both on Differnetial types and Primal types """ - check_equal(actual, expected; kwargs...) + _check_equal(actual, expected; kwargs...) `@test`'s that `actual ≈ expected`, but breaks up data such that human readable results are shown on failures. All keyword arguments are passed to `isapprox`. """ -function check_equal( +function _check_equal( actual::Union{AbstractArray{<:Number}, Number}, expected::Union{AbstractArray{<:Number}, Number}; kwargs... @@ -17,22 +18,40 @@ function check_equal( @test isapprox(actual, expected; kwargs...) end -function check_equal(actual::AbstractThunk, expected; kwargs...) - check_equal(unthunk(actual), expected; kwargs...) +for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk)) + @eval function _check_equal(actual::$T1, expected::$T2; kwargs...) + _check_equal(unthunk(actual), unthunk(expected); kwargs...) + end end - -function check_equal( - actual::Union{Composite, AbstractArray}, - expected; - kwargs... -) - @test length(actual) == length(expected) - @testset "$ii" for ii in keys(actual) # keys works on all Composites - check_equal(actual[ii], expected[ii]; kwargs...) +function _check_equal(actual::Union{Composite, AbstractArray}, expected; kwargs...) + if actual == expected # if equal then we don't need to be smarter + @test true + else + @test length(actual) == length(expected) + @testset "$ii" for ii in keys(actual) # keys works on all Composites + _check_equal(actual[ii], expected[ii]; kwargs...) + end end end +_check_equal(::AbstractZero, x; kwargs...) = _check_equal(zero(x), x; kwargs...) +_check_equal(x, ::AbstractZero; kwargs...) = _check_equal(x, zero(x); kwargs...) +_check_equal(x::AbstractZero, y::AbstractZero; kwargs...) = @test x === y + +# Generic fallback, probably a tuple or something +function _check_equal(actual::A, expected::E; kwargs...) where {A, E} + if actual == expected # if equal then we don't need to be smarter + @test true + else + c_actual = collect(actual) + c_expected = collect(expected) + if (c_actual isa A) && (c_expected isa E) # prevent stack-overflow + throw(MethodError, _check_equal, (actual, expected)) + end + _check_equal(c_actual, c_expected; kwargs...) + end +end """ _check_add!!_behavour(acc, val) @@ -50,5 +69,5 @@ function _check_add!!_behavour(acc, val; kwargs...) # e.g. if it is immutable. We do test the `add!!` return value. # That is what people should rely on. The mutation is just to save allocations. acc_mutated = deepcopy(acc) # prevent this test changing others - @test check_equal(add!!(acc_mutated, val), acc + val; kwargs...) + _check_equal(add!!(acc_mutated, val), acc + val; kwargs...) end diff --git a/src/iterator.jl b/src/iterator.jl index b2e4648e..a3384008 100644 --- a/src/iterator.jl +++ b/src/iterator.jl @@ -60,7 +60,7 @@ function Base.isapprox( return isapprox(iter1.data, iter2.data; kwargs...) end -function check_equal(expected::TestIterator, actual::TestIterator; kwargs...) +function _check_equal(expected::TestIterator, actual::TestIterator; kwargs...) return isapprox(expected, actual; kwargs...) end diff --git a/src/testers.jl b/src/testers.jl index 004d50f6..d6afbb7c 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -114,6 +114,10 @@ at input point `z` to confirm that there are correct `frule` and `rrule`s provid All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`. """ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) + # To simplify some of the calls we make later lets group the kwargs for reuse + rule_test_kwargs = (; rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) + isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...) + _ensure_not_running_on_functor(f, "test_scalar") # z = x + im * y # Ω = u(x, y) + im * v(x, y) @@ -123,23 +127,19 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), Δx = one(z) @testset "$f at $z, with tangent $Δx" begin # check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode - frule_test(f, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) + frule_test(f, (z, Δx); rule_test_kwargs...) if z isa Complex # check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im - check_equal( - frule((Zero(), real(Δx)), f, z; fkwargs...)[2]::Number, - frule((Zero(), Δx), f, z; fkwargs...)[2]::Number, - rtol=rtol, - atol=atol, - kwargs..., - ) + _, real_tangent = frule((Zero(), real(Δx)), f, z; fkwargs...) + _, embedded_tangent = frule((Zero(), Δx), f, z; fkwargs...) + _check_equal(real_tangent, embedded_tangent; isapprox_kwargs...) end end if z isa Complex Δy = one(z) * im @testset "$f at $z, with tangent $Δy" begin # check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode - frule_test(f, (z, Δy); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) + frule_test(f, (z, Δy); rule_test_kwargs...) end end @@ -147,24 +147,20 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), Δu = one(Ω) @testset "$f at $z, with cotangent $Δu" begin # check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode - rrule_test(f, Δu, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) + rrule_test(f, Δu, (z, Δx); rule_test_kwargs...) if Ω isa Complex # check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im _, back = rrule(f, z) - check_equal( - back(real(Δu))[2], - back(Δu)[2], - rtol=rtol, - atol=atol, - kwargs..., - ) + _, real_cotangent = back(real(Δu)) + _, embedded_cotangent = back(Δu) + _check_equal(real_cotangent, embedded_cotangent; isapprox_kwargs...) end end if Ω isa Complex Δv = one(Ω) * im @testset "$f at $z, with cotangent $Δv" begin # check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode - rrule_test(f, Δv, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...) + rrule_test(f, Δv, (z, Δx); rule_test_kwargs...) end end end @@ -181,25 +177,22 @@ end All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`. """ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) + # To simplify some of the calls we make later lets group the kwargs for reuse + isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...) + _ensure_not_running_on_functor(f, "frule_test") - xs, ẋs = first.(xẋs), last.(xẋs) + + xs = first.(xẋs) + ẋs = last.(xẋs) Ω_ad, dΩ_ad = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...) - # if equality check fails, check approximate equality - # use collect so can do vector equality - # TODO: add isapprox replacement that works for more types - @test Ω_ad == Ω || isapprox(collect(Ω_ad), collect(Ω); rtol=rtol, atol=atol) + _check_equal(Ω_ad, Ω; isapprox_kwargs...) ẋs_is_ignored = ẋs .== nothing # Correctness testing via finite differencing. dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), xs, ẋs, ẋs_is_ignored) - @test isapprox( - collect(extern.(dΩ_ad)), # Use collect so can use vector equality - collect(dΩ_fd); - rtol=rtol, - atol=atol, - kwargs... - ) + _check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...) + # No tangent is passed in to test accumlation, so generate one # See: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/66 @@ -222,35 +215,34 @@ end All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`. """ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) + # To simplify some of the calls we make later lets group the kwargs for reuse + isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...) + _ensure_not_running_on_functor(f, "rrule_test") # Check correctness of evaluation. xs = first.(xx̄s) - x̄s_acc = last.(xx̄s) + accumulated_x̄ = last.(xx̄s) y_ad, pullback = rrule(f, xs...; fkwargs...) y = f(xs...; fkwargs...) - # if equality check fails, check approximate equality - # use collect so can do vector equality - # TODO: add isapprox replacement that works for more types - @test y_ad == y || isapprox(collect(y_ad), collect(y); rtol=rtol, atol=atol) - @assert !(isa(ȳ, Thunk)) + _check_equal(y_ad, y; isapprox_kwargs...) # make sure primal is correct ∂s = pullback(ȳ) ∂self = ∂s[1] x̄s_ad = ∂s[2:end] @test ∂self === NO_FIELDS # No internal fields - x̄s_is_dne = x̄s_acc .== nothing # Correctness testing via finite differencing. + x̄s_is_dne = accumulated_x̄ .== nothing x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne) - for (x̄_acc, x̄_ad, x̄_fd) in zip(x̄s_acc, x̄s_ad, x̄s_fd) - if x̄_acc === nothing + for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd) + if accumulated_x̄ === nothing # then we marked this argument as not differentiable @assert x̄_fd === nothing # this is how `_make_j′vp_call` works @test x̄_ad isa DoesNotExist # we said it wasn't differentiable. else # The main test of the actual deriviative being correct: - check_result(accumulated_x̄, x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) - _check_add!!_behavour(x̄_acc, x̄_ad; rtol=rtol, atol=atol, kwargs...) + _check_equal(x̄_ad, x̄_fd; isapprox_kwargs...) + _check_add!!_behavour(accumulated_x̄, x̄_ad; isapprox_kwargs...) end end From becabe1dee908ea47b3d42f1e23c4649e94bee15 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Dec 2020 20:05:50 +0000 Subject: [PATCH 03/13] Test _check_equals --- src/check_result.jl | 24 +++++++++++++++++++- test/check_result.jl | 53 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/src/check_result.jl b/src/check_result.jl index 45b95965..230fc970 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -24,7 +24,8 @@ for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, Abs end end -function _check_equal(actual::Union{Composite, AbstractArray}, expected; kwargs...) + +function _elementwise_check_equal(actual, expected; kwargs...) if actual == expected # if equal then we don't need to be smarter @test true else @@ -35,6 +36,27 @@ function _check_equal(actual::Union{Composite, AbstractArray}, expected; kwargs. end end +function _check_equal(actual::Composite{P}, expected::Composite{P}; kwargs...) where P + return _elementwise_check_equal(actual, expected; kwargs...) +end +function _check_equal(actual::AbstractArray, expected::AbstractArray; kwargs...) + return _elementwise_check_equal(actual, expected; kwargs...) +end + +function _check_equal( + ::Composite{ActualPrimal}, expected::Composite{ExpectedPrimal} + ) where {ActualPrimal, ExpectedPrimal} + # this will certainly fail as we have another dispatch for that, but this will give as + # good error message + @test ActualPrimal === ExpectedPrimal +end + +# This catches comparisons of Composites and Tuples/NamedTuple +# and gives a error messaage complaining about that +_check_equal(::C, expected::T) where {C<:Composite, T} = @test C === T +_check_equal(::T, expected::C) where {C<:Composite, T} = @test C === T + + _check_equal(::AbstractZero, x; kwargs...) = _check_equal(zero(x), x; kwargs...) _check_equal(x, ::AbstractZero; kwargs...) = _check_equal(x, zero(x); kwargs...) _check_equal(x::AbstractZero, y::AbstractZero; kwargs...) = @test x === y diff --git a/test/check_result.jl b/test/check_result.jl index 99d80076..74564a25 100644 --- a/test/check_result.jl +++ b/test/check_result.jl @@ -17,4 +17,57 @@ X̄ -> (X̄[1] += 3.0; X̄), ))) end + + + @testset "_check_equal" begin + check = ChainRulesTestUtils._check_equal + + @testset "possive cases" begin + check(1.0, 1.0) + check(1.0 + im, 1.0 + im) + check(1.0, 1.0+1e-100) # isapprox _behavour + check((1.5, 2.5, 3.5), (1.5, 2.5, 3.5 + 1e-100)) + + check(Zero(), 0.0) + + check([1.0, 2.0], [1.0, 2.0]) + check([[1.0], [2.0]], [[1.0], [2.0]]) + + check(@thunk(10*0.1*[[1.0], [2.0]]), [[1.0], [2.0]]) + + check( + Composite{Tuple{Float64, Float64}}(1.0, 2.0), + Composite{Tuple{Float64, Float64}}(1.0, 2.0) + ) + + D = Diagonal(randn(5)) + check( + Composite{typeof(D)}(diag=D.diag), + Composite{typeof(D)}(diag=D.diag) + ) + end + @testset "negative case" begin + @test fails(()->check(1.0, 2.0)) + @test fails(()->check(1.0 + im, 1.0 - im)) + @test fails(()->check((1.5, 2.5, 3.5), (1.5, 2.5, 4.5))) + + @test fails(()->check(Zero(), 20.0)) + @test fails(()->check(10.0, Zero())) + @test fails(()->check(DoesNotExist(), Zero())) + + @test fails(()->check([1.0, 2.0], [1.0, 3.9])) + @test fails(()->check([[1.0], [2.0]], [[1.1], [2.0]])) + + @test fails(()->check(@thunk(10*[[1.0], [2.0]]), [[1.0], [2.0]])) + end + @testset "type negative" begin + @test fails() do # these have different primals so should not be equal + check( + Composite{Tuple{Float32, Float32}}(1f0, 2f0), + Composite{Tuple{Float64, Float64}}(1.0, 2.0) + ) + end + @test fails(()->check((1.0, 2.0), Composite{Tuple{Float64, Float64}}(1.0, 2.0))) + end + end end From 305022eed443b2b0efa63068d2e063fd6d65af41 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Dec 2020 20:31:30 +0000 Subject: [PATCH 04/13] Handle composites with some zeros --- src/check_result.jl | 18 +++++++++++------- test/check_result.jl | 14 ++++++++++++-- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/check_result.jl b/src/check_result.jl index 230fc970..b98136a1 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -25,22 +25,26 @@ for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, Abs end -function _elementwise_check_equal(actual, expected; kwargs...) +function _check_equal(actual::AbstractArray, expected::AbstractArray; kwargs...) if actual == expected # if equal then we don't need to be smarter @test true else - @test length(actual) == length(expected) - @testset "$ii" for ii in keys(actual) # keys works on all Composites + @test eachindex(actual) == eachindex(expected) + @testset "$(typeof(actual))[$ii]" for ii in eachindex(actual) _check_equal(actual[ii], expected[ii]; kwargs...) end end end function _check_equal(actual::Composite{P}, expected::Composite{P}; kwargs...) where P - return _elementwise_check_equal(actual, expected; kwargs...) -end -function _check_equal(actual::AbstractArray, expected::AbstractArray; kwargs...) - return _elementwise_check_equal(actual, expected; kwargs...) + if actual == expected # if equal then we don't need to be smarter + @test true + else + all_keys = union(keys(actual), keys(expected)) + @testset "$P.$ii" for ii in all_keys + _check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...) + end + end end function _check_equal( diff --git a/test/check_result.jl b/test/check_result.jl index 74564a25..cc47f74a 100644 --- a/test/check_result.jl +++ b/test/check_result.jl @@ -25,8 +25,8 @@ @testset "possive cases" begin check(1.0, 1.0) check(1.0 + im, 1.0 + im) - check(1.0, 1.0+1e-100) # isapprox _behavour - check((1.5, 2.5, 3.5), (1.5, 2.5, 3.5 + 1e-100)) + check(1.0, 1.0+1e-10) # isapprox _behavour + check((1.5, 2.5, 3.5), (1.5, 2.5, 3.5 + 1e-10)) check(Zero(), 0.0) @@ -45,6 +45,16 @@ Composite{typeof(D)}(diag=D.diag), Composite{typeof(D)}(diag=D.diag) ) + + T = (a=1.0, b=2.0) + check( + Composite{typeof(T)}(a=1.0), + Composite{typeof(T)}(a=1.0, b=Zero()) + ) + check( + Composite{typeof(T)}(a=1.0), + Composite{typeof(T)}(a=1.0+1e-10, b=Zero()) + ) end @testset "negative case" begin @test fails(()->check(1.0, 2.0)) From 33176f78f2c72a984a501fb411dec8389f435f6f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Dec 2020 20:44:59 +0000 Subject: [PATCH 05/13] move isapprox into deprecated file --- src/ChainRulesTestUtils.jl | 3 ++- src/{isapprox.jl => deprecated.jl} | 4 +++- test/{isapprox.jl => deprecated.jl} | 0 test/runtests.jl | 3 ++- 4 files changed, 7 insertions(+), 3 deletions(-) rename src/{isapprox.jl => deprecated.jl} (91%) rename test/{isapprox.jl => deprecated.jl} (100%) diff --git a/src/ChainRulesTestUtils.jl b/src/ChainRulesTestUtils.jl index 2acaf050..1d4c7781 100644 --- a/src/ChainRulesTestUtils.jl +++ b/src/ChainRulesTestUtils.jl @@ -15,9 +15,10 @@ export TestIterator export test_scalar, frule_test, rrule_test, generate_well_conditioned_matrix include("generate_tangent.jl") -include("isapprox.jl") include("data_generation.jl") include("iterator.jl") include("check_result.jl") include("testers.jl") + +include("deprecated.jl") end # module diff --git a/src/isapprox.jl b/src/deprecated.jl similarity index 91% rename from src/isapprox.jl rename to src/deprecated.jl index fee59e8f..bba6570c 100644 --- a/src/isapprox.jl +++ b/src/deprecated.jl @@ -1,4 +1,6 @@ -# TODO: reconsider these https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/7 +# TODO remove these in version 0.6 +# We are silently deprecating them as there is no alternative we are providing + Base.isapprox(a, b::Union{AbstractZero, AbstractThunk}; kwargs...) = isapprox(b, a; kwargs...) Base.isapprox(d_ad::AbstractThunk, d_fd; kwargs...) = isapprox(extern(d_ad), d_fd; kwargs...) Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) = error("Tried to differentiate w.r.t. a `DoesNotExist`") diff --git a/test/isapprox.jl b/test/deprecated.jl similarity index 100% rename from test/isapprox.jl rename to test/deprecated.jl diff --git a/test/runtests.jl b/test/runtests.jl index 79ba9496..26f88478 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,9 +7,10 @@ using Test @testset "ChainRulesTestUtils.jl" begin include("meta_testing_tools.jl") include("generate_tangent.jl") - include("isapprox.jl") include("iterator.jl") include("check_result.jl") include("testers.jl") include("data_generation.jl") + + include("deprecated.jl") end From aa362e6179d0a597d7a893ce0652d886633cd513 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Dec 2020 20:51:44 +0000 Subject: [PATCH 06/13] remove isapprox from TestIterator --- src/iterator.jl | 13 ------------- test/check_result.jl | 13 +++++++++++++ test/iterator.jl | 10 ---------- 3 files changed, 13 insertions(+), 23 deletions(-) diff --git a/src/iterator.jl b/src/iterator.jl index a3384008..85ec08b8 100644 --- a/src/iterator.jl +++ b/src/iterator.jl @@ -51,19 +51,6 @@ end # For testing purposes: -Base.isapprox(iter1::TestIterator, iter2::TestIterator) = false -function Base.isapprox( - iter1::TestIterator{T1,IS,IE}, - iter2::TestIterator{T2,IS,IE}; - kwargs..., -) where {T1,T2,IS,IE} - return isapprox(iter1.data, iter2.data; kwargs...) -end - -function _check_equal(expected::TestIterator, actual::TestIterator; kwargs...) - return isapprox(expected, actual; kwargs...) -end - function rand_tangent(rng::AbstractRNG, x::TestIterator{<:Any,IS,IE}) where {IS,IE} ∂data = rand_tangent(rng, x.data) return TestIterator{typeof(∂data),IS,IE}(∂data) diff --git a/test/check_result.jl b/test/check_result.jl index cc47f74a..0c11d9bd 100644 --- a/test/check_result.jl +++ b/test/check_result.jl @@ -79,5 +79,18 @@ end @test fails(()->check((1.0, 2.0), Composite{Tuple{Float64, Float64}}(1.0, 2.0))) end + + @testset "TestIterator" begin + data = randn(3) + iter1 = TestIterator(data, Base.HasLength(), Base.HasEltype()) + iter2 = TestIterator(data, Base.HasLength(), Base.EltypeUnknown()) + check(iter2, iter1) + + iter3 = TestIterator(data .+ 1e-10, Base.HasLength(), Base.HasEltype()) + check(iter3, iter1) + + iter_bad = TestIterator(data .+ 010, Base.HasLength(), Base.HasEltype()) + @test fails(()->check(iter_bad, iter1)) + end end end diff --git a/test/iterator.jl b/test/iterator.jl index a181e963..d47a7c43 100644 --- a/test/iterator.jl +++ b/test/iterator.jl @@ -79,16 +79,6 @@ @test hash(iter3) == hash(iter1) end - @testset "isapprox" begin - data = randn(3) - iter1 = TestIterator(data, Base.HasLength(), Base.HasEltype()) - iter2 = TestIterator(data, Base.HasLength(), Base.EltypeUnknown()) - @test !isapprox(iter2, iter1) - - iter3 = TestIterator(data .+ eps() .* rand.(), Base.HasLength(), Base.HasEltype()) - @test isapprox(iter3, iter1) - end - @testset "to_vec" begin data = randn(2, 3, 4) iter = TestIterator(data, Base.SizeUnknown(), Base.EltypeUnknown()) From 175c001a165098370ed7ac5ec9e8d4c1aa1d5d37 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Dec 2020 20:54:34 +0000 Subject: [PATCH 07/13] spacing Co-authored-by: willtebbutt --- src/check_result.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/check_result.jl b/src/check_result.jl index b98136a1..9e0215c8 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -49,7 +49,7 @@ end function _check_equal( ::Composite{ActualPrimal}, expected::Composite{ExpectedPrimal} - ) where {ActualPrimal, ExpectedPrimal} +) where {ActualPrimal, ExpectedPrimal} # this will certainly fail as we have another dispatch for that, but this will give as # good error message @test ActualPrimal === ExpectedPrimal From fb857af6d5770516c1752f35f4c25bacbcc2da0b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Dec 2020 20:57:10 +0000 Subject: [PATCH 08/13] Update test/check_result.jl Co-authored-by: willtebbutt --- test/check_result.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/check_result.jl b/test/check_result.jl index 0c11d9bd..bc44b1e2 100644 --- a/test/check_result.jl +++ b/test/check_result.jl @@ -25,7 +25,7 @@ @testset "possive cases" begin check(1.0, 1.0) check(1.0 + im, 1.0 + im) - check(1.0, 1.0+1e-10) # isapprox _behavour + check(1.0, 1.0+1e-10) # isapprox _behaviour check((1.5, 2.5, 3.5), (1.5, 2.5, 3.5 + 1e-10)) check(Zero(), 0.0) From bdc621b45d4f8a43457cee65c2fbdc9b9360b146 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 14 Dec 2020 16:53:25 +0000 Subject: [PATCH 09/13] typo Co-authored-by: willtebbutt --- src/check_result.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/check_result.jl b/src/check_result.jl index 9e0215c8..119333e1 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -1,7 +1,7 @@ # For once you have the sensitivity by two methods (e.g both finite-differencing and AD) # the code here checks it is correct. # Goal is to only call `@isapprox` on things that render well -# Note that this must work well both on Differnetial types and Primal types +# Note that this must work well both on Differential types and Primal types """ _check_equal(actual, expected; kwargs...) From a973aac520052727b6462d693e1eb1ebcc582d3f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 14 Dec 2020 16:55:17 +0000 Subject: [PATCH 10/13] Only alll _check_equal on Zero not AbstractZero --- src/check_result.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/check_result.jl b/src/check_result.jl index 119333e1..36d7dfe6 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -61,9 +61,9 @@ _check_equal(::C, expected::T) where {C<:Composite, T} = @test C === T _check_equal(::T, expected::C) where {C<:Composite, T} = @test C === T -_check_equal(::AbstractZero, x; kwargs...) = _check_equal(zero(x), x; kwargs...) -_check_equal(x, ::AbstractZero; kwargs...) = _check_equal(x, zero(x); kwargs...) -_check_equal(x::AbstractZero, y::AbstractZero; kwargs...) = @test x === y +_check_equal(::Zero, x; kwargs...) = _check_equal(zero(x), x; kwargs...) +_check_equal(x, ::Zero; kwargs...) = _check_equal(x, zero(x); kwargs...) +_check_equal(x::Zero, y::Zero; kwargs...) = @test true # Generic fallback, probably a tuple or something function _check_equal(actual::A, expected::E; kwargs...) where {A, E} From 75bc572516503a96ad073c324473b7a608969c1c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 14 Dec 2020 18:40:25 +0000 Subject: [PATCH 11/13] remove test of DoesNotExist checkequals Zero --- test/check_result.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/check_result.jl b/test/check_result.jl index bc44b1e2..b84e9c9d 100644 --- a/test/check_result.jl +++ b/test/check_result.jl @@ -63,7 +63,6 @@ @test fails(()->check(Zero(), 20.0)) @test fails(()->check(10.0, Zero())) - @test fails(()->check(DoesNotExist(), Zero())) @test fails(()->check([1.0, 2.0], [1.0, 3.9])) @test fails(()->check([[1.0], [2.0]], [[1.1], [2.0]])) From 65e466711619da1bc26e18ecd3a668d0d9683f04 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 14 Dec 2020 18:40:45 +0000 Subject: [PATCH 12/13] rename and export check_equal --- src/ChainRulesTestUtils.jl | 3 ++- src/check_result.jl | 37 ++++++++++++------------- src/testers.jl | 12 ++++----- test/check_result.jl | 55 +++++++++++++++++++------------------- 4 files changed, 55 insertions(+), 52 deletions(-) diff --git a/src/ChainRulesTestUtils.jl b/src/ChainRulesTestUtils.jl index 1d4c7781..52b1768c 100644 --- a/src/ChainRulesTestUtils.jl +++ b/src/ChainRulesTestUtils.jl @@ -12,7 +12,8 @@ using Test const _fdm = central_fdm(5, 1) export TestIterator -export test_scalar, frule_test, rrule_test, generate_well_conditioned_matrix +export check_equal, test_scalar, frule_test, rrule_test, generate_well_conditioned_matrix + include("generate_tangent.jl") include("data_generation.jl") diff --git a/src/check_result.jl b/src/check_result.jl index 36d7dfe6..04f13398 100644 --- a/src/check_result.jl +++ b/src/check_result.jl @@ -4,13 +4,14 @@ # Note that this must work well both on Differential types and Primal types """ - _check_equal(actual, expected; kwargs...) + check_equal(actual, expected; kwargs...) `@test`'s that `actual ≈ expected`, but breaks up data such that human readable results are shown on failures. +Understands things like `unthunk`ing `ChainRuleCore.Thunk`s, etc. All keyword arguments are passed to `isapprox`. """ -function _check_equal( +function check_equal( actual::Union{AbstractArray{<:Number}, Number}, expected::Union{AbstractArray{<:Number}, Number}; kwargs... @@ -19,35 +20,35 @@ function _check_equal( end for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk)) - @eval function _check_equal(actual::$T1, expected::$T2; kwargs...) - _check_equal(unthunk(actual), unthunk(expected); kwargs...) + @eval function check_equal(actual::$T1, expected::$T2; kwargs...) + check_equal(unthunk(actual), unthunk(expected); kwargs...) end end -function _check_equal(actual::AbstractArray, expected::AbstractArray; kwargs...) +function check_equal(actual::AbstractArray, expected::AbstractArray; kwargs...) if actual == expected # if equal then we don't need to be smarter @test true else @test eachindex(actual) == eachindex(expected) @testset "$(typeof(actual))[$ii]" for ii in eachindex(actual) - _check_equal(actual[ii], expected[ii]; kwargs...) + check_equal(actual[ii], expected[ii]; kwargs...) end end end -function _check_equal(actual::Composite{P}, expected::Composite{P}; kwargs...) where P +function check_equal(actual::Composite{P}, expected::Composite{P}; kwargs...) where P if actual == expected # if equal then we don't need to be smarter @test true else all_keys = union(keys(actual), keys(expected)) @testset "$P.$ii" for ii in all_keys - _check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...) + check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...) end end end -function _check_equal( +function check_equal( ::Composite{ActualPrimal}, expected::Composite{ExpectedPrimal} ) where {ActualPrimal, ExpectedPrimal} # this will certainly fail as we have another dispatch for that, but this will give as @@ -57,25 +58,25 @@ end # This catches comparisons of Composites and Tuples/NamedTuple # and gives a error messaage complaining about that -_check_equal(::C, expected::T) where {C<:Composite, T} = @test C === T -_check_equal(::T, expected::C) where {C<:Composite, T} = @test C === T +check_equal(::C, expected::T) where {C<:Composite, T} = @test C === T +check_equal(::T, expected::C) where {C<:Composite, T} = @test C === T -_check_equal(::Zero, x; kwargs...) = _check_equal(zero(x), x; kwargs...) -_check_equal(x, ::Zero; kwargs...) = _check_equal(x, zero(x); kwargs...) -_check_equal(x::Zero, y::Zero; kwargs...) = @test true +check_equal(::Zero, x; kwargs...) = check_equal(zero(x), x; kwargs...) +check_equal(x, ::Zero; kwargs...) = check_equal(x, zero(x); kwargs...) +check_equal(x::Zero, y::Zero; kwargs...) = @test true # Generic fallback, probably a tuple or something -function _check_equal(actual::A, expected::E; kwargs...) where {A, E} +function check_equal(actual::A, expected::E; kwargs...) where {A, E} if actual == expected # if equal then we don't need to be smarter @test true else c_actual = collect(actual) c_expected = collect(expected) if (c_actual isa A) && (c_expected isa E) # prevent stack-overflow - throw(MethodError, _check_equal, (actual, expected)) + throw(MethodError, check_equal, (actual, expected)) end - _check_equal(c_actual, c_expected; kwargs...) + check_equal(c_actual, c_expected; kwargs...) end end @@ -95,5 +96,5 @@ function _check_add!!_behavour(acc, val; kwargs...) # e.g. if it is immutable. We do test the `add!!` return value. # That is what people should rely on. The mutation is just to save allocations. acc_mutated = deepcopy(acc) # prevent this test changing others - _check_equal(add!!(acc_mutated, val), acc + val; kwargs...) + check_equal(add!!(acc_mutated, val), acc + val; kwargs...) end diff --git a/src/testers.jl b/src/testers.jl index d6afbb7c..99a7273c 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -132,7 +132,7 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), # check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im _, real_tangent = frule((Zero(), real(Δx)), f, z; fkwargs...) _, embedded_tangent = frule((Zero(), Δx), f, z; fkwargs...) - _check_equal(real_tangent, embedded_tangent; isapprox_kwargs...) + check_equal(real_tangent, embedded_tangent; isapprox_kwargs...) end end if z isa Complex @@ -153,7 +153,7 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), _, back = rrule(f, z) _, real_cotangent = back(real(Δu)) _, embedded_cotangent = back(Δu) - _check_equal(real_cotangent, embedded_cotangent; isapprox_kwargs...) + check_equal(real_cotangent, embedded_cotangent; isapprox_kwargs...) end end if Ω isa Complex @@ -186,12 +186,12 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm ẋs = last.(xẋs) Ω_ad, dΩ_ad = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...) - _check_equal(Ω_ad, Ω; isapprox_kwargs...) + check_equal(Ω_ad, Ω; isapprox_kwargs...) ẋs_is_ignored = ẋs .== nothing # Correctness testing via finite differencing. dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), xs, ẋs, ẋs_is_ignored) - _check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...) + check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...) # No tangent is passed in to test accumlation, so generate one @@ -225,7 +225,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm accumulated_x̄ = last.(xx̄s) y_ad, pullback = rrule(f, xs...; fkwargs...) y = f(xs...; fkwargs...) - _check_equal(y_ad, y; isapprox_kwargs...) # make sure primal is correct + check_equal(y_ad, y; isapprox_kwargs...) # make sure primal is correct ∂s = pullback(ȳ) ∂self = ∂s[1] @@ -241,7 +241,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm @test x̄_ad isa DoesNotExist # we said it wasn't differentiable. else # The main test of the actual deriviative being correct: - _check_equal(x̄_ad, x̄_fd; isapprox_kwargs...) + check_equal(x̄_ad, x̄_fd; isapprox_kwargs...) _check_add!!_behavour(accumulated_x̄, x̄_ad; isapprox_kwargs...) end end diff --git a/test/check_result.jl b/test/check_result.jl index b84e9c9d..1a4828f1 100644 --- a/test/check_result.jl +++ b/test/check_result.jl @@ -19,77 +19,78 @@ end - @testset "_check_equal" begin - check = ChainRulesTestUtils._check_equal + @testset "check_equal" begin @testset "possive cases" begin - check(1.0, 1.0) - check(1.0 + im, 1.0 + im) - check(1.0, 1.0+1e-10) # isapprox _behaviour - check((1.5, 2.5, 3.5), (1.5, 2.5, 3.5 + 1e-10)) + check_equal(1.0, 1.0) + check_equal(1.0 + im, 1.0 + im) + check_equal(1.0, 1.0+1e-10) # isapprox _behaviour + check_equal((1.5, 2.5, 3.5), (1.5, 2.5, 3.5 + 1e-10)) - check(Zero(), 0.0) + check_equal(Zero(), 0.0) - check([1.0, 2.0], [1.0, 2.0]) - check([[1.0], [2.0]], [[1.0], [2.0]]) + check_equal([1.0, 2.0], [1.0, 2.0]) + check_equal([[1.0], [2.0]], [[1.0], [2.0]]) - check(@thunk(10*0.1*[[1.0], [2.0]]), [[1.0], [2.0]]) + check_equal(@thunk(10*0.1*[[1.0], [2.0]]), [[1.0], [2.0]]) - check( + check_equal( Composite{Tuple{Float64, Float64}}(1.0, 2.0), Composite{Tuple{Float64, Float64}}(1.0, 2.0) ) D = Diagonal(randn(5)) - check( + check_equal( Composite{typeof(D)}(diag=D.diag), Composite{typeof(D)}(diag=D.diag) ) T = (a=1.0, b=2.0) - check( + check_equal( Composite{typeof(T)}(a=1.0), Composite{typeof(T)}(a=1.0, b=Zero()) ) - check( + check_equal( Composite{typeof(T)}(a=1.0), Composite{typeof(T)}(a=1.0+1e-10, b=Zero()) ) end @testset "negative case" begin - @test fails(()->check(1.0, 2.0)) - @test fails(()->check(1.0 + im, 1.0 - im)) - @test fails(()->check((1.5, 2.5, 3.5), (1.5, 2.5, 4.5))) + @test fails(()->check_equal(1.0, 2.0)) + @test fails(()->check_equal(1.0 + im, 1.0 - im)) + @test fails(()->check_equal((1.5, 2.5, 3.5), (1.5, 2.5, 4.5))) - @test fails(()->check(Zero(), 20.0)) - @test fails(()->check(10.0, Zero())) + @test fails(()->check_equal(Zero(), 20.0)) + @test fails(()->check_equal(10.0, Zero())) - @test fails(()->check([1.0, 2.0], [1.0, 3.9])) - @test fails(()->check([[1.0], [2.0]], [[1.1], [2.0]])) + @test fails(()->check_equal([1.0, 2.0], [1.0, 3.9])) + @test fails(()->check_equal([[1.0], [2.0]], [[1.1], [2.0]])) - @test fails(()->check(@thunk(10*[[1.0], [2.0]]), [[1.0], [2.0]])) + @test fails(()->check_equal(@thunk(10*[[1.0], [2.0]]), [[1.0], [2.0]])) end @testset "type negative" begin @test fails() do # these have different primals so should not be equal - check( + check_equal( Composite{Tuple{Float32, Float32}}(1f0, 2f0), Composite{Tuple{Float64, Float64}}(1.0, 2.0) ) end - @test fails(()->check((1.0, 2.0), Composite{Tuple{Float64, Float64}}(1.0, 2.0))) + @test fails() do + check_equal((1.0, 2.0), Composite{Tuple{Float64, Float64}}(1.0, 2.0)) + end end @testset "TestIterator" begin data = randn(3) iter1 = TestIterator(data, Base.HasLength(), Base.HasEltype()) iter2 = TestIterator(data, Base.HasLength(), Base.EltypeUnknown()) - check(iter2, iter1) + check_equal(iter2, iter1) iter3 = TestIterator(data .+ 1e-10, Base.HasLength(), Base.HasEltype()) - check(iter3, iter1) + check_equal(iter3, iter1) iter_bad = TestIterator(data .+ 010, Base.HasLength(), Base.HasEltype()) - @test fails(()->check(iter_bad, iter1)) + @test fails(()->check_equal(iter_bad, iter1)) end end end From dc1f787e1d0d0b31e0f354a82cb9e29b71529133 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 14 Dec 2020 18:41:22 +0000 Subject: [PATCH 13/13] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d918fd79..18ada4cc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.5.5" +version = "0.5.6" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"