Skip to content

Commit

Permalink
Add dual number–based second derivatives for ForwardDiff (#310)
Browse files Browse the repository at this point in the history
* Add dual number-based implementations for ForwardDiff scalar derivatives

* Use functions from utils.jl

* Fixup

* Reorder

* Add preparation to see error

* Add derivative method

* Add derivative! and value_and_derivative! methods

* Fixup

* Drop derivative and derivative!

* Add prepare_second_derivative and value_derivative_and_second_derivative!

* Fixup

* Drop derivative stuff

* Add second_derivative and second_derivative! methods

* Fixup

* Fixup 2

---------

Co-authored-by: Guillaume Dalle <[email protected]>
  • Loading branch information
gerlero and gdalle authored Jun 7, 2024
1 parent a997154 commit e93d16a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using DifferentiationInterface:
HessianExtras,
JacobianExtras,
NoDerivativeExtras,
NoSecondDerivativeExtras,
PushforwardExtras
using ForwardDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult
using ForwardDiff:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,58 @@ function DI.pushforward!(
return dy
end

## Second derivative

function DI.prepare_second_derivative(f::F, backend::AutoForwardDiff, x) where {F}
return NoSecondDerivativeExtras()
end

function DI.second_derivative(
f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
) where {F}
T = tag_type(f, backend, x)
xdual = make_dual(T, x, one(x))
T2 = tag_type(f, backend, xdual)
ydual = f(make_dual(T2, xdual, one(xdual)))
return myderivative(T, myderivative(T2, ydual))
end

function DI.second_derivative!(
f::F, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
) where {F}
T = tag_type(f, backend, x)
xdual = make_dual(T, x, one(x))
T2 = tag_type(f, backend, xdual)
ydual = f(make_dual(T2, xdual, one(xdual)))
return myderivative!(T, der2, myderivative(T2, ydual))
end

function DI.value_derivative_and_second_derivative(
f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
) where {F}
T = tag_type(f, backend, x)
xdual = make_dual(T, x, one(x))
T2 = tag_type(f, backend, xdual)
ydual = f(make_dual(T2, xdual, one(xdual)))
y = myvalue(T, myvalue(T2, ydual))
der = myderivative(T, myvalue(T2, ydual))
der2 = myderivative(T, myderivative(T2, ydual))
return y, der, der2
end

function DI.value_derivative_and_second_derivative!(
f::F, der, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras
) where {F}
T = tag_type(f, backend, x)
xdual = make_dual(T, x, one(x))
T2 = tag_type(f, backend, xdual)
ydual = f(make_dual(T2, xdual, one(xdual)))
y = myvalue(T, myvalue(T2, ydual))
myderivative!(T, der, myvalue(T2, ydual))
myderivative!(T, der2, myderivative(T2, ydual))
return y, der, der2
end

## Gradient

struct ForwardDiffGradientExtras{C} <: GradientExtras
Expand Down

0 comments on commit e93d16a

Please sign in to comment.