Skip to content

Commit

Permalink
Follow new complex rule conventions (#44)
Browse files Browse the repository at this point in the history
* Bump version number of ChainRulesCore

* Test derivative in scalar rrule is conjugated

* Document assumptions of test_scalar

* Test test_scalar passes for conjugated rrule

* Bump FD version bound

* Increment version number

* Reimplement test_scalar to check Jacobian

* Pass float not int

* Forward kwargs

* Check real and complex 1 give approx same result

* Add comment explaining thunk usage
  • Loading branch information
sethaxen authored Jun 27, 2020
1 parent 7954800 commit 171ff19
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 25 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesTestUtils"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "0.3.1"
version = "0.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ChainRulesCore = "0.8"
ChainRulesCore = "0.9"
Compat = "3"
FiniteDifferences = "0.9, 0.10"
FiniteDifferences = "0.10"
julia = "1"
75 changes: 54 additions & 21 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,38 +60,71 @@ function _make_fdm_call(fdm, f, ȳ, xs, ignores)
end

"""
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), kwargs...)
test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), kwargs...)
Given a function `f` with scalar input and scalar output, perform finite differencing checks,
at input point `x` to confirm that there are correct `frule` and `rrule`s provided.
at input point `z` to confirm that there are correct `frule` and `rrule`s provided.
# Arguments
- `f`: Function for which the `frule` and `rrule` should be tested.
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
- `z`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
`fkwargs` are passed to `f` as keyword arguments.
All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`.
"""
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
_ensure_not_running_on_functor(f, "test_scalar")
# z = x + im * y
# Ω = u(x, y) + im * v(x, y)
Ω = f(z; fkwargs...)

# test jacobian using forward mode
Δx = one(z)
@testset "$f at $z, with tangent $Δx" begin
# check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode
frule_test(f, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
if z isa Complex
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
@test isapprox(
frule((Zero(), real(Δx)), f, z; fkwargs...)[2],
frule((Zero(), Δx), f, z; fkwargs...)[2],
rtol=rtol,
atol=atol,
kwargs...,
)
end
end
if z isa Complex
Δy = one(z) * im
@testset "$f at $z, with tangent $Δy" begin
# check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode
frule_test(f, (z, Δy); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
end
end

r_res = rrule(f, x; fkwargs...)
f_res = frule((Zero(), 1), f, x; fkwargs...)
@test r_res !== nothing # Check the rule was defined
@test f_res !== nothing
r_fx, prop_rule = r_res
f_fx, f_∂x = f_res
@testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in (
(rrule, r_fx, prop_rule(1)),
(frule, f_fx, f_∂x)
)
@test fx == f(x; fkwargs...) # Check we still get the normal value, right

if rule == rrule
∂self, ∂x = ∂x
@test ∂self === NO_FIELDS
# test jacobian transpose using reverse mode
Δu = one(Ω)
@testset "$f at $z, with cotangent $Δu" begin
# check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode
rrule_test(f, Δu, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
if Ω isa Complex
# check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im
back = rrule(f, z)[2]
@test isapprox(
extern(back(real(Δu))[2]),
extern(back(Δu)[2]),
rtol=rtol,
atol=atol,
kwargs...,
)
end
end
if Ω isa Complex
Δv = one(Ω) * im
@testset "$f at $z, with cotangent $Δv" begin
# check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode
rrule_test(f, Δv, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
end
@test isapprox(∂x, fdm(x -> f(x; fkwargs...), x); rtol=rtol, atol=atol, kwargs...)
end
end

Expand Down Expand Up @@ -147,7 +180,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
# use collect so can do vector equality
@test isapprox(collect(y_ad), collect(y); rtol=rtol, atol=atol)
@assert !(isa(ȳ, Thunk))

∂s = pullback(ȳ)
∂self = ∂s[1]
x̄s_ad = ∂s[2:end]
Expand Down
21 changes: 20 additions & 1 deletion test/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ futestkws(x; err = true) = err ? error() : x

fbtestkws(x, y; err = true) = err ? error() : x

sinconj(x) = sin(x)

@testset "testers.jl" begin
@testset "test_scalar" begin
double(x) = 2x
@scalar_rule(double(x), 2)
test_scalar(double, 2)
test_scalar(double, 2.0)
end

@testset "unary: identity(x)" begin
Expand All @@ -30,6 +32,23 @@ fbtestkws(x, y; err = true) = err ? error() : x
end
end

@testset "test derivative conjugated in pullback" begin
ChainRulesCore.frule((_, Δx), ::typeof(sinconj), x) = (sin(x), cos(x) * Δx)

# define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative
# in the rrule
function ChainRulesCore.rrule(::typeof(sinconj), x)
# usually we would not thunk for a single output, because it will of course be
# used, but we do here to ensure that test_scalar works even if a scalar rrule
# thunks
sinconj_pullback(ΔΩ) = (NO_FIELDS, @thunk(conj(cos(x)) * ΔΩ))
return sin(x), sinconj_pullback
end

rrule_test(sinconj, randn(ComplexF64), (randn(ComplexF64), randn(ComplexF64)))
test_scalar(sinconj, randn(ComplexF64))
end

@testset "binary: fst(x, y)" begin
fst(x, y) = x
ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx)
Expand Down

2 comments on commit 171ff19

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/17099

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.0 -m "<description of version>" 171ff19e63b6c560cac0b3247a7677bfd80157bf
git push origin v0.4.0

Please sign in to comment.