diff --git a/.gitignore b/.gitignore index 8ce8c635..f637f003 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +dev/ + # Files generated by invoking Julia with --code-coverage *.jl.cov *.jl.*.cov diff --git a/Project.toml b/Project.toml index 09128742..12f66bcd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.6.1" +version = "0.6.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/Project.toml b/docs/Project.toml index f3a06e61..6594306b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/docs/src/index.md b/docs/src/index.md index 92de026c..21bca0cd 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -5,11 +5,136 @@ [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) helps you test [`ChainRulesCore.frule`](http://www.juliadiff.org/ChainRulesCore.jl/dev/api.html) and [`ChainRulesCore.rrule`](http://www.juliadiff.org/ChainRulesCore.jl/dev/api.html) methods, when adding rules for your functions in your own packages. - For information about ChainRules, including how to write rules, refer to the general ChainRules Documentation: [![](https://img.shields.io/badge/docs-master-blue.svg)](https://JuliaDiff.github.io/ChainRulesCore.jl/dev) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaDiff.github.io/ChainRulesCore.jl/stable) +## Canonical example + +Let's suppose a custom transformation has been defined +```jldoctest ex; output = false +function two2three(x1::Float64, x2::Float64) + return 1.0, 2.0*x1, 3.0*x2 +end + +# output +two2three (generic function with 1 method) +``` +along with the `frule` +```jldoctest ex; output = false +using ChainRulesCore + +function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2) + y = two2three(x1, x2) + ∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δx1, 3.0*Δx2) + return y, ∂y +end +# output + +``` +and `rrule` +```jldoctest ex; output = false +function ChainRulesCore.rrule(::typeof(two2three), x1, x2) + y = two2three(x1, x2) + function two2three_pullback(Ȳ) + return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3]) + end + return y, two2three_pullback +end +# output + +``` + +The [`frule_test`](@ref)/[`rrule_test`](@ref) helper function compares the `frule`/`rrule` outputs +to the gradients obtained by finite differencing. +They can be used for any type and number of inputs and outputs. + +### Testing the `frule` + +[`frule_test`](@ref) takes in the function `f` and tuples `(x, ẋ)` for each function argument `x`. +The call will test the `frule` for function `f` at the point `x` in the domain. Keep +this in mind when testing discontinuous rules for functions like +[ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally +be tested at both `x` being above and below zero. +Additionally, choosing `ẋ` in an unfortunate way (e.g. as zeros) could hide +underlying problems with the defined `frule`. + +```jldoctest ex; output = false +using ChainRulesTestUtils + +x1, x2 = (3.33, -7.77) +ẋ1, ẋ2 = (rand(), rand()) + +frule_test(two2three, (x1, ẋ1), (x2, ẋ2)) +# output +Test Summary: | Pass Total +Tuple{Float64,Float64,Float64}.1 | 1 1 +Test Summary: | Pass Total +Tuple{Float64,Float64,Float64}.2 | 1 1 +Test Summary: | Pass Total +Tuple{Float64,Float64,Float64}.3 | 1 1 +Test Passed +``` + +### Testing the `rrule` + +[`rrule_test`](@ref) takes in the function `f`, sensitivities of the function outputs `ȳ`, +and tuples `(x, x̄)` for each function argument `x`. +`x̄` is the accumulated adjoint which can be set arbitrarily. +The call will test the `rrule` for function `f` at the point `x`, and similarly to +`frule` some rules should be tested at multiple points in the domain. +Choosing `ȳ` in an unfortunate way (e.g. as zeros) could hide underlying problems with +the `rrule`. +```jldoctest ex; output = false +x1, x2 = (3.33, -7.77) +x̄1, x̄2 = (rand(), rand()) +ȳs = (rand(), rand(), rand()) + +rrule_test(two2three, ȳs, (x1, x̄1), (x2, x̄2)) + +# output +Test Summary: | +Don't thunk only non_zero argument | No tests +Test.DefaultTestSet("Don't thunk only non_zero argument", Any[], 0, false) +``` + +## Scalar example + +For functions with a single argument and a single output, such as e.g. ReLU, +```jldoctest ex; output = false +function relu(x::Real) + return max(0, x) +end + +# output +relu (generic function with 1 method) +``` +with the `frule` and `rrule` defined with the help of `@scalar_rule` macro +```jldoctest ex; output = false +@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x) + +# output + +``` + +`test_scalar` function is provided to test both the `frule` and the `rrule` with a single +call. +```jldoctest ex; output = false +test_scalar(relu, 0.5) +test_scalar(relu, -0.5) + +# output +Test Summary: | Pass Total +relu at 0.5, with tangent 1.0 | 3 3 +Test Summary: | Pass Total +relu at 0.5, with cotangent 1.0 | 4 4 +Test Summary: | Pass Total +relu at -0.5, with tangent 1.0 | 3 3 +Test Summary: | Pass Total +relu at -0.5, with cotangent 1.0 | 4 4 +``` + + # API Documentation ```@autodocs diff --git a/src/iterator.jl b/src/iterator.jl index 85ec08b8..8a3072d5 100644 --- a/src/iterator.jl +++ b/src/iterator.jl @@ -10,7 +10,7 @@ The iterator wraps another iterator `data`, such as an array, that must have at many features implemented as the test iterator and have a `FiniteDifferences.to_vec` overload. By default, the iterator it has the same features as `data`. -The optional methods `eltype`, length`, and `size` are automatically defined and forwarded +The optional methods `eltype`, `length`, and `size` are automatically defined and forwarded to `data` if the type arguments indicate that they should be defined. """ struct TestIterator{T,IS,IE} diff --git a/src/testers.jl b/src/testers.jl index ec79fc8c..57143835 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -228,6 +228,7 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e end res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) res === nothing && throw(MethodError(frule, typeof((f, xs...)))) + res isa Tuple || error("The frule should return (y, ∂y), not $res.") Ω_ad, dΩ_ad = res Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...) check_equal(Ω_ad, Ω; isapprox_kwargs...) @@ -280,6 +281,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Re check_inferred && _test_inferred(pullback, ȳ) ∂s = pullback(ȳ) + ∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.") ∂self = ∂s[1] x̄s_ad = ∂s[2:end] @test ∂self === NO_FIELDS # No internal fields