Skip to content

Commit

Permalink
Merge pull request #57 from JuliaDiff/ox/inplace
Browse files Browse the repository at this point in the history
Overhall equality checks, to give better failure messages
  • Loading branch information
oxinabox authored Dec 14, 2020
2 parents 9847d0b + dc1f787 commit 7bbd60e
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 74 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.5.5"
version = "0.5.6"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ChainRulesCore = "0.9.1"
ChainRulesCore = "0.9.13"
Compat = "3"
FiniteDifferences = "0.11.2"
julia = "1"
6 changes: 4 additions & 2 deletions src/ChainRulesTestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ using Test
const _fdm = central_fdm(5, 1)

export TestIterator
export test_scalar, frule_test, rrule_test, generate_well_conditioned_matrix
export check_equal, test_scalar, frule_test, rrule_test, generate_well_conditioned_matrix


include("generate_tangent.jl")
include("isapprox.jl")
include("data_generation.jl")
include("iterator.jl")
include("check_result.jl")
include("testers.jl")

include("deprecated.jl")
end # module
83 changes: 82 additions & 1 deletion src/check_result.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,84 @@
# For once you have the sensitivity by two methods (e.g both finite-differencing and AD)
# the code here checks it is correct.
# Goal is to only call `@isapprox` on things that render well
# Note that this must work well both on Differential types and Primal types

"""
check_equal(actual, expected; kwargs...)
`@test`'s that `actual ≈ expected`, but breaks up data such that human readable results
are shown on failures.
Understands things like `unthunk`ing `ChainRuleCore.Thunk`s, etc.
All keyword arguments are passed to `isapprox`.
"""
function check_equal(
actual::Union{AbstractArray{<:Number}, Number},
expected::Union{AbstractArray{<:Number}, Number};
kwargs...
)
@test isapprox(actual, expected; kwargs...)
end

for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, AbstractThunk))
@eval function check_equal(actual::$T1, expected::$T2; kwargs...)
check_equal(unthunk(actual), unthunk(expected); kwargs...)
end
end


function check_equal(actual::AbstractArray, expected::AbstractArray; kwargs...)
if actual == expected # if equal then we don't need to be smarter
@test true
else
@test eachindex(actual) == eachindex(expected)
@testset "$(typeof(actual))[$ii]" for ii in eachindex(actual)
check_equal(actual[ii], expected[ii]; kwargs...)
end
end
end

function check_equal(actual::Composite{P}, expected::Composite{P}; kwargs...) where P
if actual == expected # if equal then we don't need to be smarter
@test true
else
all_keys = union(keys(actual), keys(expected))
@testset "$P.$ii" for ii in all_keys
check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...)
end
end
end

function check_equal(
::Composite{ActualPrimal}, expected::Composite{ExpectedPrimal}
) where {ActualPrimal, ExpectedPrimal}
# this will certainly fail as we have another dispatch for that, but this will give as
# good error message
@test ActualPrimal === ExpectedPrimal
end

# This catches comparisons of Composites and Tuples/NamedTuple
# and gives a error messaage complaining about that
check_equal(::C, expected::T) where {C<:Composite, T} = @test C === T
check_equal(::T, expected::C) where {C<:Composite, T} = @test C === T


check_equal(::Zero, x; kwargs...) = check_equal(zero(x), x; kwargs...)
check_equal(x, ::Zero; kwargs...) = check_equal(x, zero(x); kwargs...)
check_equal(x::Zero, y::Zero; kwargs...) = @test true

# Generic fallback, probably a tuple or something
function check_equal(actual::A, expected::E; kwargs...) where {A, E}
if actual == expected # if equal then we don't need to be smarter
@test true
else
c_actual = collect(actual)
c_expected = collect(expected)
if (c_actual isa A) && (c_expected isa E) # prevent stack-overflow
throw(MethodError, check_equal, (actual, expected))
end
check_equal(c_actual, c_expected; kwargs...)
end
end

"""
_check_add!!_behavour(acc, val)
Expand All @@ -15,5 +96,5 @@ function _check_add!!_behavour(acc, val; kwargs...)
# e.g. if it is immutable. We do test the `add!!` return value.
# That is what people should rely on. The mutation is just to save allocations.
acc_mutated = deepcopy(acc) # prevent this test changing others
@test isapprox(add!!(acc_mutated, val), acc + val; kwargs...)
check_equal(add!!(acc_mutated, val), acc + val; kwargs...)
end
4 changes: 3 additions & 1 deletion src/isapprox.jl → src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# TODO: reconsider these https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/7
# TODO remove these in version 0.6
# We are silently deprecating them as there is no alternative we are providing

Base.isapprox(a, b::Union{AbstractZero, AbstractThunk}; kwargs...) = isapprox(b, a; kwargs...)
Base.isapprox(d_ad::AbstractThunk, d_fd; kwargs...) = isapprox(extern(d_ad), d_fd; kwargs...)
Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) = error("Tried to differentiate w.r.t. a `DoesNotExist`")
Expand Down
9 changes: 0 additions & 9 deletions src/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,6 @@ end

# For testing purposes:

Base.isapprox(iter1::TestIterator, iter2::TestIterator) = false
function Base.isapprox(
iter1::TestIterator{T1,IS,IE},
iter2::TestIterator{T2,IS,IE};
kwargs...,
) where {T1,T2,IS,IE}
return isapprox(iter1.data, iter2.data; kwargs...)
end

function rand_tangent(rng::AbstractRNG, x::TestIterator{<:Any,IS,IE}) where {IS,IE}
∂data = rand_tangent(rng, x.data)
return TestIterator{typeof(∂data),IS,IE}(∂data)
Expand Down
92 changes: 44 additions & 48 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ at input point `z` to confirm that there are correct `frule` and `rrule`s provid
All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`.
"""
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
# To simplify some of the calls we make later lets group the kwargs for reuse
rule_test_kwargs = (; rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

_ensure_not_running_on_functor(f, "test_scalar")
# z = x + im * y
# Ω = u(x, y) + im * v(x, y)
Expand All @@ -123,48 +127,40 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
Δx = one(z)
@testset "$f at $z, with tangent $Δx" begin
# check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode
frule_test(f, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
frule_test(f, (z, Δx); rule_test_kwargs...)
if z isa Complex
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
@test isapprox(
frule((Zero(), real(Δx)), f, z; fkwargs...)[2],
frule((Zero(), Δx), f, z; fkwargs...)[2],
rtol=rtol,
atol=atol,
kwargs...,
)
_, real_tangent = frule((Zero(), real(Δx)), f, z; fkwargs...)
_, embedded_tangent = frule((Zero(), Δx), f, z; fkwargs...)
check_equal(real_tangent, embedded_tangent; isapprox_kwargs...)
end
end
if z isa Complex
Δy = one(z) * im
@testset "$f at $z, with tangent $Δy" begin
# check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode
frule_test(f, (z, Δy); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
frule_test(f, (z, Δy); rule_test_kwargs...)
end
end

# test jacobian transpose using reverse mode
Δu = one(Ω)
@testset "$f at $z, with cotangent $Δu" begin
# check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode
rrule_test(f, Δu, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
rrule_test(f, Δu, (z, Δx); rule_test_kwargs...)
if Ω isa Complex
# check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im
back = rrule(f, z)[2]
@test isapprox(
extern(back(real(Δu))[2]),
extern(back(Δu)[2]),
rtol=rtol,
atol=atol,
kwargs...,
)
_, back = rrule(f, z)
_, real_cotangent = back(real(Δu))
_, embedded_cotangent = back(Δu)
check_equal(real_cotangent, embedded_cotangent; isapprox_kwargs...)
end
end
if Ω isa Complex
Δv = one(Ω) * im
@testset "$f at $z, with cotangent $Δv" begin
# check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode
rrule_test(f, Δv, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
rrule_test(f, Δv, (z, Δx); rule_test_kwargs...)
end
end
end
Expand All @@ -181,25 +177,22 @@ end
All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`.
"""
function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

_ensure_not_running_on_functor(f, "frule_test")
xs, ẋs = first.(xẋs), last.(xẋs)

xs = first.(xẋs)
ẋs = last.(xẋs)
Ω_ad, dΩ_ad = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
# if equality check fails, check approximate equality
# use collect so can do vector equality
# TODO: add isapprox replacement that works for more types
@test Ω_ad == Ω || isapprox(collect(Ω_ad), collect(Ω); rtol=rtol, atol=atol)
check_equal(Ω_ad, Ω; isapprox_kwargs...)

ẋs_is_ignored = ẋs .== nothing
# Correctness testing via finite differencing.
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), xs, ẋs, ẋs_is_ignored)
@test isapprox(
collect(extern.(dΩ_ad)), # Use collect so can use vector equality
collect(dΩ_fd);
rtol=rtol,
atol=atol,
kwargs...
)
check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...)


# No tangent is passed in to test accumlation, so generate one
# See: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/66
Expand All @@ -222,43 +215,46 @@ end
All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`.
"""
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

_ensure_not_running_on_functor(f, "rrule_test")

# Check correctness of evaluation.
xs = first.(xx̄s)
x̄s_acc = last.(xx̄s)
accumulated_x̄ = last.(xx̄s)
y_ad, pullback = rrule(f, xs...; fkwargs...)
y = f(xs...; fkwargs...)
# if equality check fails, check approximate equality
# use collect so can do vector equality
# TODO: add isapprox replacement that works for more types
@test y_ad == y || isapprox(collect(y_ad), collect(y); rtol=rtol, atol=atol)
@assert !(isa(ȳ, Thunk))
check_equal(y_ad, y; isapprox_kwargs...) # make sure primal is correct

∂s = pullback(ȳ)
∂self = ∂s[1]
x̄s_ad = ∂s[2:end]
@test ∂self === NO_FIELDS # No internal fields

x̄s_is_dne = x̄s_acc .== nothing
# Correctness testing via finite differencing.
x̄s_is_dne = accumulated_x̄ .== nothing
x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne)
for (x̄_acc, x̄_ad, x̄_fd) in zip(x̄s_acc, x̄s_ad, x̄s_fd)
if x̄_acc === nothing
for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd)
if accumulated_x̄ === nothing # then we marked this argument as not differentiable
@assert x̄_fd === nothing # this is how `_make_j′vp_call` works
@test x̄_ad isa DoesNotExist # we said it wasn't differentiable.
else
# The main test of the actual deriviative being correct:
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)

_check_add!!_behavour(x̄_acc, x̄_ad; rtol=rtol, atol=atol, kwargs...)
check_equal(x̄_ad, x̄_fd; isapprox_kwargs...)
_check_add!!_behavour(accumulated_x̄, x̄_ad; isapprox_kwargs...)
end
end

if count(!, x̄s_is_dne) == 1
# for functions with pullbacks that only produce a single non-DNE adjoint, that
# single adjoint should not be `Thunk`ed. InplaceableThunk is fine.
i = findfirst(!, x̄s_is_dne)
@test !(isa(x̄s_ad[i], Thunk))
check_thunking_is_appropriate(x̄s_ad)
end

function check_thunking_is_appropriate(x̄s)
@testset "Don't thunk only non_zero argument" begin
num_zeros = count(x->x isa AbstractZero, x̄s)
num_thunks = count(x->x isa Thunk, x̄s)
if num_zeros + num_thunks == length(x̄s)
@test num_thunks !== 1
end
end
end
Loading

2 comments on commit 7bbd60e

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/26374

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.6 -m "<description of version>" 7bbd60eb21ee5a42954a35587a55605dc8992ba2
git push origin v0.5.6

Please sign in to comment.