Skip to content

Commit

Permalink
Better hessian check, Const(f) in Enzyme (#284)
Browse files Browse the repository at this point in the history
* Small changes

* Hessian check

* Correct hess_th
  • Loading branch information
gdalle authored May 30, 2024
1 parent 2583e79 commit c6aaabe
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 13 deletions.
10 changes: 8 additions & 2 deletions DifferentiationInterface/docs/src/tutorial1.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ using DifferentiationInterface
## Computing a gradient

A common use case of automatic differentiation (AD) is optimizing real-valued functions with first- or second-order methods.
Let's define a simple objective and a random input vector
Let's define a simple objective (the squared norm) and a random input vector

```@example tuto1
f(x) = sum(abs2, x)
function f(x::AbstractVector{T}) where {T}
y = zero(T)
for i in eachindex(x)
y += abs2(x[i])
end
return y
end
x = collect(1.0:5.0)
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ function DI.value_and_pushforward(
f, backend::AutoForwardOrNothingEnzyme, x, dx, ::NoPushforwardExtras
)
dx_sametype = convert(typeof(x), dx)
y, new_dy = autodiff(forward_mode(backend), f, Duplicated, Duplicated(x, dx_sametype))
y, new_dy = autodiff(
forward_mode(backend), Const(f), Duplicated, Duplicated(x, dx_sametype)
)
return y, new_dy
end

Expand All @@ -15,7 +17,9 @@ function DI.pushforward(
)
dx_sametype = convert(typeof(x), dx)
new_dy = only(
autodiff(forward_mode(backend), f, DuplicatedNoNeed, Duplicated(x, dx_sametype))
autodiff(
forward_mode(backend), Const(f), DuplicatedNoNeed, Duplicated(x, dx_sametype)
),
)
return new_dy
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function DI.value_and_pushforward(
dy_sametype = zero(y)
autodiff(
forward_mode(backend),
f!,
Const(f!),
Const,
Duplicated(y, dy_sametype),
Duplicated(x, dx_sametype),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ DI.prepare_pullback(f, ::AutoReverseOrNothingEnzyme, x, dy) = NoPullbackExtras()
function DI.value_and_pullback(
f, ::AutoReverseOrNothingEnzyme, x::Number, dy::Number, ::NoPullbackExtras
)
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
der, y = autodiff(ReverseWithPrimal, Const(f), Active, Active(x))
new_dx = dy * only(der)
return y, new_dx
end
Expand Down Expand Up @@ -43,7 +43,7 @@ function DI.value_and_pullback!(
f, dx, ::AutoReverseOrNothingEnzyme, x::AbstractArray, dy::Number, ::NoPullbackExtras
)
dx_sametype = zero_sametype!(dx, x)
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype))
_, y = autodiff(ReverseWithPrimal, Const(f), Active, Duplicated(x, dx_sametype))
dx_sametype .*= dy
return y, copyto!(dx, dx_sametype)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ function DI.value_and_pullback(
)
dy_sametype = convert(typeof(y), copy(dy))
_, new_dx = only(
autodiff(reverse_mode(backend), f!, Const, Duplicated(y, dy_sametype), Active(x))
autodiff(
reverse_mode(backend), Const(f!), Const, Duplicated(y, dy_sametype), Active(x)
),
)
return y, new_dx
end
Expand All @@ -19,7 +21,7 @@ function DI.value_and_pullback(
dy_sametype = convert(typeof(y), copy(dy))
autodiff(
reverse_mode(backend),
f!,
Const(f!),
Const,
Duplicated(y, dy_sametype),
Duplicated(x, dx_sametype),
Expand Down
12 changes: 8 additions & 4 deletions DifferentiationInterface/src/utils/check.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Check whether `backend` supports differentiation of two-argument functions.
"""
check_twoarg(backend::AbstractADType) = Bool(twoarg_support(backend))

sqnorm(x::AbstractArray) = sum(abs2, x)
hess_checker(x::AbstractArray) = abs2(x[1]) * abs2(x[2])

"""
check_hessian(backend)
Expand All @@ -28,9 +28,13 @@ Check whether `backend` supports second order differentiation by trying to compu
"""
function check_hessian(backend::AbstractADType; verbose=true)
try
x = [1.0, 3.0]
hess = hessian(sqnorm, backend, x)
return isapprox(hess, [2.0 0.0; 0.0 2.0]; rtol=1e-3)
x = [2.0, 3.0]
hess = hessian(hess_checker, backend, x)
hess_th = [
2*abs2(x[2]) 4*x[1]*x[2]
4*x[1]*x[2] 2*abs2(x[1])
]
return isapprox(hess, hess_th; rtol=1e-3)
catch exception
if verbose
@warn "Backend $backend does not support hessian" exception
Expand Down

0 comments on commit c6aaabe

Please sign in to comment.