From e93d16a68a0da772a12399dac7a2b169b1dc9201 Mon Sep 17 00:00:00 2001 From: Gabriel Gerlero Date: Fri, 7 Jun 2024 12:05:37 -0300 Subject: [PATCH] =?UTF-8?q?Add=20dual=20number=E2=80=93based=20second=20de?= =?UTF-8?q?rivatives=20for=20`ForwardDiff`=20(#310)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add dual number-based implementations for ForwardDiff scalar derivatives * Use functions from utils.jl * Fixup * Reorder * Add preparation to see error * Add derivative method * Add derivative! and value_and_derivative! methods * Fixup * Drop derivative and derivative! * Add prepare_second_derivative and value_derivative_and_second_derivative! * Fixup * Drop derivative stuff * Add second_derivative and second_derivative! methods * Fixup * Fixup 2 --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- .../DifferentiationInterfaceForwardDiffExt.jl | 1 + .../onearg.jl | 52 +++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index f041d033..1f52ad2b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -10,6 +10,7 @@ using DifferentiationInterface: HessianExtras, JacobianExtras, NoDerivativeExtras, + NoSecondDerivativeExtras, PushforwardExtras using ForwardDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult using ForwardDiff: diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index ad4aa0c6..854429ac 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -61,6 +61,58 @@ function DI.pushforward!( return dy end +## Second derivative + +function DI.prepare_second_derivative(f::F, backend::AutoForwardDiff, x) where {F} + return NoSecondDerivativeExtras() +end + +function DI.second_derivative( + f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras +) where {F} + T = tag_type(f, backend, x) + xdual = make_dual(T, x, one(x)) + T2 = tag_type(f, backend, xdual) + ydual = f(make_dual(T2, xdual, one(xdual))) + return myderivative(T, myderivative(T2, ydual)) +end + +function DI.second_derivative!( + f::F, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras +) where {F} + T = tag_type(f, backend, x) + xdual = make_dual(T, x, one(x)) + T2 = tag_type(f, backend, xdual) + ydual = f(make_dual(T2, xdual, one(xdual))) + return myderivative!(T, der2, myderivative(T2, ydual)) +end + +function DI.value_derivative_and_second_derivative( + f::F, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras +) where {F} + T = tag_type(f, backend, x) + xdual = make_dual(T, x, one(x)) + T2 = tag_type(f, backend, xdual) + ydual = f(make_dual(T2, xdual, one(xdual))) + y = myvalue(T, myvalue(T2, ydual)) + der = myderivative(T, myvalue(T2, ydual)) + der2 = myderivative(T, myderivative(T2, ydual)) + return y, der, der2 +end + +function DI.value_derivative_and_second_derivative!( + f::F, der, der2, backend::AutoForwardDiff, x, ::NoSecondDerivativeExtras +) where {F} + T = tag_type(f, backend, x) + xdual = make_dual(T, x, one(x)) + T2 = tag_type(f, backend, xdual) + ydual = f(make_dual(T2, xdual, one(xdual))) + y = myvalue(T, myvalue(T2, ydual)) + myderivative!(T, der, myvalue(T2, ydual)) + myderivative!(T, der2, myderivative(T2, ydual)) + return y, der, der2 +end + ## Gradient struct ForwardDiffGradientExtras{C} <: GradientExtras