diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 5d4634814..e51ff2d3c 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -111,3 +111,29 @@ a real. 1 0 """ hessian(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2] + +""" + isderiving() + isderiving(x) + +Check whether the current function call is happening while taking the derivative. + + + julia> function f(x) + @show isderiving() + end + + f (generic function with 1 method) + + julia> f(3) + isderiving() = false + false + + julia> gradient(f, 4) + isderiving() = true + (nothing,) +""" +isderiving() = false +isderiving(x) = false +@adjoint isderiving() = true, _ -> nothing +@adjoint isderiving(x) = true, x -> (nothing,) diff --git a/test/tools.jl b/test/tools.jl index 921b8a56f..717612284 100644 --- a/test/tools.jl +++ b/test/tools.jl @@ -22,4 +22,43 @@ deleteat!(buff, 1) @test length(buff) == 1 @test buff[1] === c -end \ No newline at end of file +end + +@testset "isderiving" begin + + function f(x) + if Zygote.isderiving(x) + x^2 + else + 2x^2 + end + end + + # Test higher order derivatives + gs = gradient(4) do x + gradient(x) do y + f(y) + end[1] + end + + @test gs == (2,) + + struct Tester + cpu_offload::Float64 + end + + function Tester(p) + @show Zygote.isderiving(p) + cpu_offload = Zygote.isderiving(p) ? 0.0 : 0.2 + Tester(cpu_offload) + end + + function f(p) + sum(Tester(p).cpu_offload .* p) + end + + p = [1.0] + gs = gradient(f, p) + @test gs[1] == [0.] + +end