From 5c04890c24e5f431c674a2d782adcb1469676bde Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sat, 12 Dec 2020 20:43:34 +0530 Subject: [PATCH 1/5] add isderiving --- src/lib/utils.jl | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 5d4634814..c4ad81ae4 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -111,3 +111,26 @@ a real. 1 0 """ hessian(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2] + +""" + isderiving() + +Check whether the current function call is happening while taking the derivative. + + + julia> function f(x) + @show isderiving() + end + + f (generic function with 1 method) + + julia> f(3) + isderiving() = false + false + + julia> gradient(f, 4) + isderiving() = true + (nothing,) +""" +isderiving() = false +@adjoint isderiving() = true, _ -> nothing From 5c585681d0d3f5fc8c8611ddaa0c3f7382fe37a5 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sun, 13 Dec 2020 17:05:39 +0530 Subject: [PATCH 2/5] make isderiving for higher order differentiation --- src/lib/utils.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index c4ad81ae4..82e02f3db 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -133,4 +133,6 @@ Check whether the current function call is happening while taking the derivative (nothing,) """ isderiving() = false +isderiving(x) = false @adjoint isderiving() = true, _ -> nothing +@adjoint isderiving(x) = true, x -> (nothing,) From 89458dbeb460d38389ef52ff7ccb57e924ee770b Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sun, 13 Dec 2020 18:13:25 +0530 Subject: [PATCH 3/5] add 850 as a test --- test/tools.jl | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/test/tools.jl b/test/tools.jl index 921b8a56f..a27bcd451 100644 --- a/test/tools.jl +++ b/test/tools.jl @@ -22,4 +22,43 @@ deleteat!(buff, 1) @test length(buff) == 1 @test buff[1] === c -end \ No newline at end of file +end + +@testset "isderiving" begin + + function f(x) + if isderiving(x) + x^2 + else + 2x^2 + end + end + + # Test higher order derivatives + gs = gradient(4) do x + gradient(x) do y + f(y) + end[1] + end + + @test gs == (2,) + + struct Tester + cpu_offload::Float64 + end + + function Tester(p) + @show isderiving(p) + cpu_offload = isderiving(p) ? 0.0 : 0.2 + Tester(cpu_offload) + end + + function f(p) + sum(Tester(p).cpu_offload .* p) + end + + p = [1.0] + gs = gradient(f, p) + @test gs[1] == [0.] + +end From f022bbac45d4c52768b525c66cfad532d9f34cb7 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Sun, 13 Dec 2020 18:20:33 +0530 Subject: [PATCH 4/5] qualify names in test --- test/tools.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/tools.jl b/test/tools.jl index a27bcd451..717612284 100644 --- a/test/tools.jl +++ b/test/tools.jl @@ -27,7 +27,7 @@ end @testset "isderiving" begin function f(x) - if isderiving(x) + if Zygote.isderiving(x) x^2 else 2x^2 @@ -48,8 +48,8 @@ end end function Tester(p) - @show isderiving(p) - cpu_offload = isderiving(p) ? 0.0 : 0.2 + @show Zygote.isderiving(p) + cpu_offload = Zygote.isderiving(p) ? 0.0 : 0.2 Tester(cpu_offload) end From 453fe3a496595bc67bce6ecb32911cd1ba5fa6fe Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 16 Dec 2020 16:40:15 +0530 Subject: [PATCH 5/5] Add the missing isderiving API to docstring --- src/lib/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 82e02f3db..e51ff2d3c 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -114,6 +114,7 @@ hessian(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2] """ isderiving() + isderiving(x) Check whether the current function call is happening while taking the derivative.