diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index f041d033..1f52ad2b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -10,6 +10,7 @@ using DifferentiationInterface: HessianExtras, JacobianExtras, NoDerivativeExtras, + NoSecondDerivativeExtras, PushforwardExtras using ForwardDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult using ForwardDiff: diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index ad4aa0c6..854429ac 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -61,6 +61,58 @@ function DI.pushforward!( return dy end +## Second derivative + +function DI.prepare_second_derivative(f::F, backend::AutoForwardDiff, x) where {F} + return NoSecondDerivativeExtras() +end + +function DI.second_derivative( + f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras +) where {F} + T = tag_type(f, backend, x) + xdual = make_dual(T, x, one(x)) + T2 = tag_type(f, backend, xdual) + ydual = f(make_dual(T2, xdual, one(xdual))) + return myderivative(T, myderivative(T2, ydual)) +end + +function DI.second_derivative!( + f::F, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras +) where {F} + T = tag_type(f, backend, x) + xdual = make_dual(T, x, one(x)) + T2 = tag_type(f, backend, xdual) + ydual = f(make_dual(T2, xdual, one(xdual))) + return myderivative!(T, der2, myderivative(T2, ydual)) +end + +function DI.value_derivative_and_second_derivative( + f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras +) where {F} + T = tag_type(f, backend, x) + xdual = make_dual(T, x, one(x)) + T2 = tag_type(f, backend, xdual) + ydual = f(make_dual(T2, xdual, one(xdual))) + y = myvalue(T, myvalue(T2, ydual)) + der = myderivative(T, myvalue(T2, ydual)) + der2 = myderivative(T, myderivative(T2, ydual)) + return y, der, der2 +end + +function DI.value_derivative_and_second_derivative!( + f::F, der, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras +) where {F} + T = tag_type(f, backend, x) + xdual = make_dual(T, x, one(x)) + T2 = tag_type(f, backend, xdual) + ydual = f(make_dual(T2, xdual, one(xdual))) + y = myvalue(T, myvalue(T2, ydual)) + myderivative!(T, der, myvalue(T2, ydual)) + myderivative!(T, der2, myderivative(T2, ydual)) + return y, der, der2 +end + ## Gradient struct ForwardDiffGradientExtras{C} <: GradientExtras