Skip to content

Commit

Permalink
Merge #853
Browse files Browse the repository at this point in the history
853: Add isderiving utility r=DhairyaLGandhi a=DhairyaLGandhi

cc @ChrisRackauckas who wanted a utility like this

Co-authored-by: Dhairya Gandhi <[email protected]>
Co-authored-by: Dhairya Gandhi <[email protected]>
  • Loading branch information
3 people authored Jan 29, 2021
2 parents 86c061c + 453fe3a commit 4e26772
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
26 changes: 26 additions & 0 deletions src/lib/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,29 @@ a real.
1 0
"""
hessian(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2]

"""
isderiving()
isderiving(x)
Check whether the current function call is happening while taking the derivative.
julia> function f(x)
@show isderiving()
end
f (generic function with 1 method)
julia> f(3)
isderiving() = false
false
julia> gradient(f, 4)
isderiving() = true
(nothing,)
"""
isderiving() = false
isderiving(x) = false
@adjoint isderiving() = true, _ -> nothing
@adjoint isderiving(x) = true, x -> (nothing,)
41 changes: 40 additions & 1 deletion test/tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,43 @@
deleteat!(buff, 1)
@test length(buff) == 1
@test buff[1] === c
end
end

@testset "isderiving" begin

function f(x)
if Zygote.isderiving(x)
x^2
else
2x^2
end
end

# Test higher order derivatives
gs = gradient(4) do x
gradient(x) do y
f(y)
end[1]
end

@test gs == (2,)

struct Tester
cpu_offload::Float64
end

function Tester(p)
@show Zygote.isderiving(p)
cpu_offload = Zygote.isderiving(p) ? 0.0 : 0.2
Tester(cpu_offload)
end

function f(p)
sum(Tester(p).cpu_offload .* p)
end

p = [1.0]
gs = gradient(f, p)
@test gs[1] == [0.]

end

0 comments on commit 4e26772

Please sign in to comment.