From 6e580a4ee76dfbe4d8b91ddd1a5d9ef2d1d6047c Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 18 Jan 2021 11:55:01 +0000 Subject: [PATCH 1/7] ignore dev folder --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 8ce8c635..f637f003 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +dev/ + # Files generated by invoking Julia with --code-coverage *.jl.cov *.jl.*.cov From 29ed778dbbdc2a1475c60e0fb938f749dc2900be Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 18 Jan 2021 11:55:14 +0000 Subject: [PATCH 2/7] fix formatting --- src/iterator.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/iterator.jl b/src/iterator.jl index 85ec08b8..8a3072d5 100644 --- a/src/iterator.jl +++ b/src/iterator.jl @@ -10,7 +10,7 @@ The iterator wraps another iterator `data`, such as an array, that must have at many features implemented as the test iterator and have a `FiniteDifferences.to_vec` overload. By default, the iterator it has the same features as `data`. -The optional methods `eltype`, length`, and `size` are automatically defined and forwarded +The optional methods `eltype`, `length`, and `size` are automatically defined and forwarded to `data` if the type arguments indicate that they should be defined. """ struct TestIterator{T,IS,IE} From b30b80dc647c406e18aef94b8d19469c8a26b671 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 18 Jan 2021 12:11:07 +0000 Subject: [PATCH 3/7] check the frule and pullback return types --- src/testers.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/testers.jl b/src/testers.jl index ec79fc8c..57143835 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -228,6 +228,7 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e end res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...) 
res === nothing && throw(MethodError(frule, typeof((f, xs...)))) + res isa Tuple || error("The frule should return (y, ∂y), not $res.") Ω_ad, dΩ_ad = res Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...) check_equal(Ω_ad, Ω; isapprox_kwargs...) @@ -280,6 +281,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Re check_inferred && _test_inferred(pullback, ȳ) ∂s = pullback(ȳ) + ∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.") ∂self = ∂s[1] x̄s_ad = ∂s[2:end] @test ∂self === NO_FIELDS # No internal fields From c29106aa80f2f19665841d7c1fc2e2c735373e60 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 18 Jan 2021 16:39:12 +0000 Subject: [PATCH 4/7] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 09128742..12f66bcd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.6.1" +version = "0.6.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From abed20a0f920b47e05fdbc1fc74c90f8b999ad54 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 18 Jan 2021 18:01:29 +0000 Subject: [PATCH 5/7] add generic and scalar examples --- docs/src/index.md | 100 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 99 insertions(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 92de026c..44feb351 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -5,11 +5,109 @@ [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) helps you test [`ChainRulesCore.frule`](http://www.juliadiff.org/ChainRulesCore.jl/dev/api.html) and [`ChainRulesCore.rrule`](http://www.juliadiff.org/ChainRulesCore.jl/dev/api.html) methods, when adding rules for your functions in your own packages. 
- For information about ChainRules, including how to write rules, refer to the general ChainRules Documentation: [![](https://img.shields.io/badge/docs-master-blue.svg)](https://JuliaDiff.github.io/ChainRulesCore.jl/dev) [![](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaDiff.github.io/ChainRulesCore.jl/stable) +## Canonical example + +Let's suppose a custom transformation has been defined +``` +function two2three(a::Float64, b::Float64) + return 1.0, 2.0*a, 3.0*b +end +``` +along with the `frule` +``` +function ChainRulesCore.frule((Δf, Δa, Δb), ::typeof(two2three), a, b) + y = two2three(a, b) + ∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δa, 3.0*Δb) + return y, ∂y +end +``` +and `rrule` +``` +function ChainRulesCore.rrule(::typeof(two2three), a, b) + y = two2three(a, b) + function two2three_pullback(Ȳ) + return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3]) + end + return y, two2three_pullback +end +``` + +The `test_frule`/`test_rrule` helper function compares the `frule`/`rrule` outputs +to the gradients obtained by finite differencing. +They can be used for any type and number of inputs and outputs. + +### Testing the `frule` + +`frule_test` takes in the function `f` and tuples `(x, ẋ)` for each function argument `x`. +The call will test the `frule` for function `f` at the point `x` in the domain. Keep +this in mind when testing discontinuous rules for functions like +[ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally +be tested at both `x` being above and below zero. +Additionally, choosing `ẋ` in an unfortunate way (e.g. as zeros) could hide +underlying problems with the defined `frule`. + +``` +xs = (3.33, -7.77) +ẋs = (rand(), rand()) +frule_test(two2three, (xs[1], ẋs[1]), (xs[2], ẋs[2])) +``` + +### Testing the `rrule` + +`rrule_test` takes in the function `f`, sensitivities of the function outputs `ȳ`, +and tuples `(x, x̄)` for each function argument `x`. 
+`x̄` is the accumulated adjoint which should be set randomly. +The call will test the `rrule` for function `f` at the point `x`, and similarly to +`frule` some rules should be tested at multiple points in the domain. +Choosing `ȳ` in an unfortunate way (e.g. as zeros) could hide underlying problems with +the `rrule`. +``` +xs = (3.33, -7.77) +ȳs = (rand(), rand(), rand()) +x̄s = (rand(), rand()) +rrule_test(two2three, ȳs, (xs[1], x̄s[1]), (xs[2], x̄s[2])) +``` + +## Scalar example + +For functions with a single argument and a single output, such as e.g. `ReLU`, +``` +function relu(x::Real) + return max(0, x) +end +``` +with the `frule` +``` +function ChainRulesCore.frule((Δf, Δx), ::typeof(relu), x::Real) + y = relu(x) + dydx = x <= 0 ? zero(x) : one(x) + return y, dydx .* Δx +end +``` +and `rrule` defined, +``` +function ChainRulesCore.rrule(::typeof(relu), x::Real) + y = relu(x) + dydx = x <= 0 ? zero(x) : one(x) + function relu_pullback(Ȳ) + return (NO_FIELDS, Ȳ .* dydx) + end + return y, relu_pullback +end +``` + +`test_scalar` function is provided to test both the `frule` and the `rrule` with a single +call. As discussed, it should be tested at two different points in the domain. +``` +test_scalar(relu, 0.5) +test_scalar(relu, -0.5) +``` + + # API Documentation ```@autodocs From 4a6eafb0d7d97ebc629c20324743d1a25d9bda27 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 18 Jan 2021 18:05:45 +0000 Subject: [PATCH 6/7] delete a sentence --- docs/src/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index 44feb351..d3642095 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -101,7 +101,7 @@ end ``` `test_scalar` function is provided to test both the `frule` and the `rrule` with a single -call. As discussed, it should be tested at two different points in the domain. +call. 
``` test_scalar(relu, 0.5) test_scalar(relu, -0.5) From b805973d53c017ffb1bde06d68405ce248b11659 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 19 Jan 2021 10:13:20 +0000 Subject: [PATCH 7/7] code review comments --- docs/Project.toml | 1 + docs/src/index.md | 113 ++++++++++++++++++++++++++++------------------ 2 files changed, 71 insertions(+), 43 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index f3a06e61..6594306b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/docs/src/index.md b/docs/src/index.md index d3642095..21bca0cd 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -12,37 +12,46 @@ For information about ChainRules, including how to write rules, refer to the gen ## Canonical example Let's suppose a custom transformation has been defined -``` -function two2three(a::Float64, b::Float64) - return 1.0, 2.0*a, 3.0*b +```jldoctest ex; output = false +function two2three(x1::Float64, x2::Float64) + return 1.0, 2.0*x1, 3.0*x2 end + +# output +two2three (generic function with 1 method) ``` along with the `frule` -``` -function ChainRulesCore.frule((Δf, Δa, Δb), ::typeof(two2three), a, b) - y = two2three(a, b) - ∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δa, 3.0*Δb) +```jldoctest ex; output = false +using ChainRulesCore + +function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2) + y = two2three(x1, x2) + ∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δx1, 3.0*Δx2) return y, ∂y end +# output + ``` and `rrule` -``` -function ChainRulesCore.rrule(::typeof(two2three), a, b) - y = two2three(a, b) +```jldoctest ex; output = false +function ChainRulesCore.rrule(::typeof(two2three), x1, x2) + y = two2three(x1, x2) function two2three_pullback(Ȳ) return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3]) 
end return y, two2three_pullback end +# output + ``` -The `test_frule`/`test_rrule` helper function compares the `frule`/`rrule` outputs +The [`frule_test`](@ref)/[`rrule_test`](@ref) helper functions compare the `frule`/`rrule` outputs to the gradients obtained by finite differencing. They can be used for any type and number of inputs and outputs. ### Testing the `frule` -`frule_test` takes in the function `f` and tuples `(x, ẋ)` for each function argument `x`. +[`frule_test`](@ref) takes in the function `f` and tuples `(x, ẋ)` for each function argument `x`. The call will test the `frule` for function `f` at the point `x` in the domain. Keep this in mind when testing discontinuous rules for functions like [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally @@ -50,61 +59,79 @@ be tested at both `x` being above and below zero. Additionally, choosing `ẋ` in an unfortunate way (e.g. as zeros) could hide underlying problems with the defined `frule`. -``` -xs = (3.33, -7.77) -ẋs = (rand(), rand()) -frule_test(two2three, (xs[1], ẋs[1]), (xs[2], ẋs[2])) +```jldoctest ex; output = false +using ChainRulesTestUtils + +x1, x2 = (3.33, -7.77) +ẋ1, ẋ2 = (rand(), rand()) + +frule_test(two2three, (x1, ẋ1), (x2, ẋ2)) +# output +Test Summary: | Pass Total +Tuple{Float64,Float64,Float64}.1 | 1 1 +Test Summary: | Pass Total +Tuple{Float64,Float64,Float64}.2 | 1 1 +Test Summary: | Pass Total +Tuple{Float64,Float64,Float64}.3 | 1 1 +Test Passed ``` ### Testing the `rrule` -`rrule_test` takes in the function `f`, sensitivities of the function outputs `ȳ`, +[`rrule_test`](@ref) takes in the function `f`, sensitivities of the function outputs `ȳ`, and tuples `(x, x̄)` for each function argument `x`. -`x̄` is the accumulated adjoint which should be set randomly. +`x̄` is the accumulated adjoint which can be set arbitrarily. 
The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain. Choosing `ȳ` in an unfortunate way (e.g. as zeros) could hide underlying problems with the `rrule`. -``` -xs = (3.33, -7.77) +```jldoctest ex; output = false +x1, x2 = (3.33, -7.77) +x̄1, x̄2 = (rand(), rand()) ȳs = (rand(), rand(), rand()) -x̄s = (rand(), rand()) -rrule_test(two2three, ȳs, (xs[1], x̄s[1]), (xs[2], x̄s[2])) + +rrule_test(two2three, ȳs, (x1, x̄1), (x2, x̄2)) + +# output +Test Summary: | +Don't thunk only non_zero argument | No tests +Test.DefaultTestSet("Don't thunk only non_zero argument", Any[], 0, false) ``` ## Scalar example -For functions with a single argument and a single output, such as e.g. `ReLU`, -``` +For functions with a single argument and a single output, such as ReLU, +```jldoctest ex; output = false function relu(x::Real) return max(0, x) end + +# output +relu (generic function with 1 method) ``` -with the `frule` -``` -function ChainRulesCore.frule((Δf, Δx), ::typeof(relu), x::Real) - y = relu(x) - dydx = x <= 0 ? zero(x) : one(x) - return y, dydx .* Δx -end -``` -and `rrule` defined, -``` -function ChainRulesCore.rrule(::typeof(relu), x::Real) - y = relu(x) - dydx = x <= 0 ? zero(x) : one(x) - function relu_pullback(Ȳ) - return (NO_FIELDS, Ȳ .* dydx) - end - return y, relu_pullback -end +with the `frule` and `rrule` defined with the help of the `@scalar_rule` macro +```jldoctest ex; output = false +@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x) + +# output + ``` `test_scalar` function is provided to test both the `frule` and the `rrule` with a single call. 
-``` +```jldoctest ex; output = false test_scalar(relu, 0.5) test_scalar(relu, -0.5) + +# output +Test Summary: | Pass Total +relu at 0.5, with tangent 1.0 | 3 3 +Test Summary: | Pass Total +relu at 0.5, with cotangent 1.0 | 4 4 +Test Summary: | Pass Total +relu at -0.5, with tangent 1.0 | 3 3 +Test Summary: | Pass Total +relu at -0.5, with cotangent 1.0 | 4 4 ```