From c6aaabea1f24e895fae9b8ef46581e133ac9eabb Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 30 May 2024 07:01:50 +0200 Subject: [PATCH] Better hessian check, Const(f) in Enzyme (#284) * Small changes * Hessian check * Correct hess_th --- DifferentiationInterface/docs/src/tutorial1.md | 10 ++++++++-- .../forward_onearg.jl | 8 ++++++-- .../forward_twoarg.jl | 2 +- .../reverse_onearg.jl | 4 ++-- .../reverse_twoarg.jl | 6 ++++-- DifferentiationInterface/src/utils/check.jl | 12 ++++++++---- 6 files changed, 29 insertions(+), 13 deletions(-) diff --git a/DifferentiationInterface/docs/src/tutorial1.md b/DifferentiationInterface/docs/src/tutorial1.md index 79e45127..0e3cd120 100644 --- a/DifferentiationInterface/docs/src/tutorial1.md +++ b/DifferentiationInterface/docs/src/tutorial1.md @@ -9,10 +9,16 @@ using DifferentiationInterface ## Computing a gradient A common use case of automatic differentiation (AD) is optimizing real-valued functions with first- or second-order methods. -Let's define a simple objective and a random input vector +Let's define a simple objective (the squared norm) and a random input vector ```@example tuto1 -f(x) = sum(abs2, x) +function f(x::AbstractVector{T}) where {T} + y = zero(T) + for i in eachindex(x) + y += abs2(x[i]) + end + return y +end x = collect(1.0:5.0) ``` diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 513372d3..a04ba42c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -6,7 +6,9 @@ function DI.value_and_pushforward( f, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras ) dx_sametype = convert(typeof(x), dx) - y, new_dy = autodiff(forward_mode(backend), f, Duplicated, Duplicated(x, dx_sametype)) + y, new_dy = autodiff( + forward_mode(backend), Const(f), Duplicated, Duplicated(x, dx_sametype) + ) return y, new_dy end @@ -15,7 +17,9 @@ function DI.pushforward( ) dx_sametype = convert(typeof(x), dx) new_dy = only( - autodiff(forward_mode(backend), f, DuplicatedNoNeed, Duplicated(x, dx_sametype)) + autodiff( + forward_mode(backend), Const(f), DuplicatedNoNeed, Duplicated(x, dx_sametype) + ), ) return new_dy end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index 621260ae..5de9a9c6 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -9,7 +9,7 @@ function DI.value_and_pushforward( dy_sametype = zero(y) autodiff( forward_mode(backend), - f!, + Const(f!), Const, Duplicated(y, dy_sametype), Duplicated(x, dx_sametype), diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index e72cd1c0..a859562f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -7,7 +7,7 @@ DI.prepare_pullback(f, ::AutoReverseOrNothingEnzyme, x, dy) = NoPullbackExtras() function DI.value_and_pullback( f, ::AutoReverseOrNothingEnzyme, x::Number, dy::Number, ::NoPullbackExtras ) - der, y = autodiff(ReverseWithPrimal, f, Active, Active(x)) + der, y = autodiff(ReverseWithPrimal, Const(f), Active, Active(x)) new_dx = dy * only(der) return y, new_dx end @@ -43,7 +43,7 @@ function DI.value_and_pullback!( f, dx, ::AutoReverseOrNothingEnzyme, x::AbstractArray, dy::Number, ::NoPullbackExtras ) dx_sametype = zero_sametype!(dx, x) - _, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype)) + _, y = autodiff(ReverseWithPrimal, Const(f), Active, Duplicated(x, dx_sametype)) dx_sametype .*= dy return y, copyto!(dx, dx_sametype) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index 3eee64b1..818e1896 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -7,7 +7,9 @@ function DI.value_and_pullback( ) dy_sametype = convert(typeof(y), copy(dy)) _, new_dx = only( - autodiff(reverse_mode(backend), f!, Const, Duplicated(y, dy_sametype), Active(x)) + autodiff( + reverse_mode(backend), Const(f!), Const, Duplicated(y, dy_sametype), Active(x) + ), ) return y, new_dx end @@ -19,7 +21,7 @@ function DI.value_and_pullback( dy_sametype = convert(typeof(y), copy(dy)) autodiff( reverse_mode(backend), - f!, + Const(f!), Const, Duplicated(y, dy_sametype), Duplicated(x, dx_sametype), diff --git a/DifferentiationInterface/src/utils/check.jl b/DifferentiationInterface/src/utils/check.jl index 91214b20..1db3d242 100644 --- a/DifferentiationInterface/src/utils/check.jl +++ b/DifferentiationInterface/src/utils/check.jl @@ -16,7 +16,7 @@ Check whether `backend` supports differentiation of two-argument functions. """ check_twoarg(backend::AbstractADType) = Bool(twoarg_support(backend)) -sqnorm(x::AbstractArray) = sum(abs2, x) +hess_checker(x::AbstractArray) = abs2(x[1]) * abs2(x[2]) """ check_hessian(backend) @@ -28,9 +28,13 @@ Check whether `backend` supports second order differentiation by trying to compu """ function check_hessian(backend::AbstractADType; verbose=true) try - x = [1.0, 3.0] - hess = hessian(sqnorm, backend, x) - return isapprox(hess, [2.0 0.0; 0.0 2.0]; rtol=1e-3) + x = [2.0, 3.0] + hess = hessian(hess_checker, backend, x) + hess_th = [ + 2*abs2(x[2]) 4*x[1]*x[2] + 4*x[1]*x[2] 2*abs2(x[1]) + ] + return isapprox(hess, hess_th; rtol=1e-3) catch exception if verbose @warn "Backend $backend does not support hessian" exception