diff --git a/Project.toml b/Project.toml index ee0f36f7..479b70ad 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.2.6" +version = "0.2.7" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/testers.jl b/src/testers.jl index 5b2ad6e1..255f6417 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -137,40 +137,22 @@ end `fkwargs` are passed to `f` as keyword arguments. All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`. """ -function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) - _ensure_not_running_on_functor(f, "rrule_test") - - # Check correctness of evaluation. - fx, pullback = rrule(f, x; fkwargs...) - @test collect(fx) ≈ collect(f(x; fkwargs...)) # use collect so can do vector equality - (∂self, x̄_ad) = if fx isa Tuple - # If the function returned multiple values, - # then it must have multiple seeds for propagating backwards - pullback(ȳ...) - else - pullback(ȳ) - end - - @test ∂self === NO_FIELDS # No internal fields - # Correctness testing via finite differencing. - x̄_fd = only(j′vp(fdm, x -> f(x; fkwargs...), ȳ, x)) # j′vp returns a tuple, but `f` is a unary function. - @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) -end - -# case where `f` takes multiple arguments function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...) _ensure_not_running_on_functor(f, "rrule_test") # Check correctness of evaluation. xs, x̄s = collect(zip(xx̄s...)) - y, pullback = rrule(f, xs...; fkwargs...) - @test f(xs...; fkwargs...) == y - + y_ad, pullback = rrule(f, xs...; fkwargs...) + y = f(xs...; fkwargs...) + # use collect so can do vector equality + @test isapprox(collect(y_ad), collect(y); rtol=rtol, atol=atol) @assert !(isa(ȳ, Thunk)) - ∂s = pullback(ȳ) + # If the function returned multiple values, + # then it must have multiple seeds for propagating backwards + ∂s = (y_ad isa Tuple) ? pullback(ȳ...) : pullback(ȳ) ∂self = ∂s[1] x̄s_ad = ∂s[2:end] - @test ∂self === NO_FIELDS + @test ∂self === NO_FIELDS # No internal fields # Correctness testing via finite differencing. x̄s_fd = _make_fdm_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s .== nothing)