Skip to content

Commit

Permalink
Implement gradient_and_hvp (#588)
Browse files Browse the repository at this point in the history
* Implement `gradient_and_hvp`

* Use inner

* Da fixes

* Add symbolic backends

* Fix

* Fix JuliaFormatter to v1
  • Loading branch information
gdalle authored Oct 16, 2024
1 parent 495d988 commit 40b629d
Show file tree
Hide file tree
Showing 17 changed files with 576 additions and 140 deletions.
1 change: 1 addition & 0 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ FastDifferentiation = "0.4.1"
FiniteDiff = "2.23.1"
FiniteDifferences = "0.12.31"
ForwardDiff = "0.10.36"
JuliaFormatter = "1"
LinearAlgebra = "<0.0.1,1"
Mooncake = "0.4.0"
PolyesterForwardDiff = "0.1.2"
Expand Down
2 changes: 2 additions & 0 deletions DifferentiationInterface/docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ prepare_hvp
prepare_hvp_same_point
hvp
hvp!
gradient_and_hvp
gradient_and_hvp!
```

### Hessian
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/docs/src/explanation/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Several variants of each operator are defined:
| [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) |
| [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) |
| [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) |
| [`hvp`](@ref) | [`hvp!`](@ref) | - | - |
| [`hvp`](@ref) | [`hvp!`](@ref) | [`gradient_and_hvp`](@ref) | [`gradient_and_hvp!`](@ref) |

## Mutation and signatures

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,20 +396,23 @@ end

## HVP

struct FastDifferentiationHVPPrep{E2,E2!} <: HVPPrep
struct FastDifferentiationHVPPrep{E2,E2!,E1} <: HVPPrep
hvp_exe::E2
hvp_exe!::E2!
gradient_prep::E1
end

function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, tx::NTuple)
function DI.prepare_hvp(f, backend::AutoFastDifferentiation, x, tx::NTuple)
x_var = make_variables(:x, size(x)...)
y_var = f(x_var)

x_vec_var = vec(x_var)
hv_vec_var, v_vec_var = hessian_times_v(y_var, x_vec_var)
hvp_exe = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=false)
hvp_exe! = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=true)
return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!)

gradient_prep = DI.prepare_gradient(f, backend, x)
return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!, gradient_prep)
end

function DI.hvp(
Expand Down Expand Up @@ -439,6 +442,28 @@ function DI.hvp!(
return tg
end

function DI.gradient_and_hvp(
    f, prep::FastDifferentiationHVPPrep, backend::AutoFastDifferentiation, x, tx::NTuple
)
    # Evaluate the two pieces independently: the gradient from the preparation
    # stored alongside the HVP executables, the HVP from the executables themselves.
    grad = DI.gradient(f, prep.gradient_prep, backend, x)
    tg = DI.hvp(f, prep, backend, x, tx)
    return grad, tg
end

function DI.gradient_and_hvp!(
    f,
    grad,
    tg::NTuple,
    prep::FastDifferentiationHVPPrep,
    backend::AutoFastDifferentiation,
    x,
    tx::NTuple,
)
    # In-place variant: `grad` and the tangents `tg` are distinct buffers,
    # each filled by its dedicated in-place operator.
    DI.gradient!(f, grad, prep.gradient_prep, backend, x)
    DI.hvp!(f, tg, prep, backend, x, tx)
    return grad, tg
end

## Hessian

struct FastDifferentiationHessianPrep{G,E2,E2!} <: HessianPrep
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,32 @@ function DI.hvp!(
return DI.hvp!(f, tg, prep, SecondOrder(backend, backend), x, tx, contexts...)
end

function DI.gradient_and_hvp(
    f::F,
    prep::HVPPrep,
    backend::AutoForwardDiff,
    x,
    tx::NTuple,
    contexts::Vararg{Context,C},
) where {F,C}
    # ForwardDiff performs second-order differentiation by nesting over itself.
    nested_backend = SecondOrder(backend, backend)
    return DI.gradient_and_hvp(f, prep, nested_backend, x, tx, contexts...)
end

function DI.gradient_and_hvp!(
    f::F,
    grad,
    tg::NTuple,
    prep::HVPPrep,
    backend::AutoForwardDiff,
    x,
    tx::NTuple,
    contexts::Vararg{Context,C},
) where {F,C}
    # In-place variant: same nesting trick as the out-of-place version.
    nested_backend = SecondOrder(backend, backend)
    return DI.gradient_and_hvp!(f, grad, tg, prep, nested_backend, x, tx, contexts...)
end

## Hessian

### Unprepared, only when chunk size and tag are not specified
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ function DI.prepare_hvp(
T = tag_type(f, tagged_outer_backend, x)
xdual = make_dual(T, x, tx)
gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...)
# TODO: get rid of closure?
function inner_gradient(x, unannotated_contexts...)
annotated_contexts = rewrap(unannotated_contexts...)
return DI.gradient(f, gradient_prep, inner(backend), x, annotated_contexts...)
Expand Down Expand Up @@ -73,3 +74,34 @@ function DI.hvp!(
)
return tg
end

function DI.gradient_and_hvp(
    f::F,
    prep::ForwardDiffOverSomethingHVPPrep,
    ::SecondOrder{<:AutoForwardDiff},
    x,
    tx::NTuple,
    contexts::Vararg{Context,C},
) where {F,C}
    # Pushing the inner gradient forward yields the gradient itself as the
    # primal value and the Hessian-vector products as the tangent outputs.
    return DI.value_and_pushforward(
        prep.inner_gradient,
        prep.outer_pushforward_prep,
        prep.tagged_outer_backend,
        x,
        tx,
        contexts...,
    )
end

function DI.gradient_and_hvp!(
    f::F,
    grad,
    tg::NTuple,
    prep::ForwardDiffOverSomethingHVPPrep,
    ::SecondOrder{<:AutoForwardDiff},
    x,
    tx::NTuple,
    contexts::Vararg{Context,C},
) where {F,C}
    # The primal returned by `value_and_pushforward!` is freshly allocated,
    # so it must be copied into the caller-provided `grad` buffer.
    primal_grad, _ = DI.value_and_pushforward!(
        prep.inner_gradient,
        tg,
        prep.outer_pushforward_prep,
        prep.tagged_outer_backend,
        x,
        tx,
        contexts...,
    )
    copyto!(grad, primal_grad)
    return grad, tg
end
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,32 @@ function DI.hvp!(
return DI.hvp!(f, tg, prep, single_threaded(backend), x, tx, contexts...)
end

function DI.gradient_and_hvp(
    f,
    prep::HVPPrep,
    backend::AutoPolyesterForwardDiff,
    x,
    tx::NTuple,
    contexts::Vararg{Context,C},
) where {C}
    # Second-order operators fall back on the single-threaded ForwardDiff backend.
    fallback_backend = single_threaded(backend)
    return DI.gradient_and_hvp(f, prep, fallback_backend, x, tx, contexts...)
end

function DI.gradient_and_hvp!(
    f,
    grad,
    tg::NTuple,
    prep::HVPPrep,
    backend::AutoPolyesterForwardDiff,
    x,
    tx::NTuple,
    contexts::Vararg{Context,C},
) where {C}
    # In-place variant: same single-threaded fallback as the out-of-place version.
    fallback_backend = single_threaded(backend)
    return DI.gradient_and_hvp!(f, grad, tg, prep, fallback_backend, x, tx, contexts...)
end

## Second derivative

function DI.prepare_second_derivative(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,22 @@ function DI.hvp!(
return tg
end

function DI.gradient_and_hvp(
    f, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple
)
    # Evaluate gradient and HVP from their respective precompiled preparations.
    grad = DI.gradient(f, prep.gradient_prep, backend, x)
    return grad, DI.hvp(f, prep, backend, x, tx)
end

function DI.gradient_and_hvp!(
    f, grad, tg::NTuple, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple
)
    # In-place variant: fill the distinct `grad` and `tg` buffers separately.
    DI.gradient!(f, grad, prep.gradient_prep, backend, x)
    DI.hvp!(f, tg, prep, backend, x, tx)
    return grad, tg
end

## Second derivative

struct SymbolicsOneArgSecondDerivativePrep{D,E1,E1!} <: SecondDerivativePrep
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,29 @@ function DI.hvp!(
return DI.hvp!(f, tg, prep, SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...)
end

function DI.gradient_and_hvp(
    f, prep::HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{Constant,C}
) where {C}
    # Zygote handles second order via forward-over-reverse:
    # ForwardDiff on the outside, Zygote on the inside.
    forward_over_reverse = SecondOrder(AutoForwardDiff(), backend)
    return DI.gradient_and_hvp(f, prep, forward_over_reverse, x, tx, contexts...)
end

function DI.gradient_and_hvp!(
    f,
    grad,
    tg::NTuple,
    prep::HVPPrep,
    backend::AutoZygote,
    x,
    tx::NTuple,
    contexts::Vararg{Constant,C},
) where {C}
    # In-place variant: same forward-over-reverse combination as out-of-place.
    forward_over_reverse = SecondOrder(AutoForwardDiff(), backend)
    return DI.gradient_and_hvp!(f, grad, tg, prep, forward_over_reverse, x, tx, contexts...)
end

## Hessian

function DI.prepare_hessian(f, ::AutoZygote, x, contexts::Vararg{Constant,C}) where {C}
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export jacobian!, jacobian

export second_derivative!, second_derivative
export value_derivative_and_second_derivative, value_derivative_and_second_derivative!
export hvp!, hvp
export hvp!, hvp, gradient_and_hvp, gradient_and_hvp!
export hessian!, hessian
export value_gradient_and_hessian, value_gradient_and_hessian!

Expand Down
46 changes: 32 additions & 14 deletions DifferentiationInterface/src/fallbacks/no_prep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ for op in [
elseif op == :hessian
:value_gradient_and_hessian
elseif op == :hvp
nothing
:gradient_and_hvp
else
Symbol("value_and_", op)
end
Expand Down Expand Up @@ -138,26 +138,44 @@ for op in [
prep = $prep_op(f, backend, x, seed, contexts...)
return $op!(f, result, prep, backend, x, seed, contexts...)
end

op == :hvp && continue

@eval function $val_and_op(
f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C}
) where {F,C}
prep = $prep_op(f, backend, x, seed, contexts...)
return $val_and_op(f, prep, backend, x, seed, contexts...)
end
@eval function $val_and_op!(
f::F,
result::NTuple,
backend::AbstractADType,
x,
seed::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
prep = $prep_op(f, backend, x, seed, contexts...)
return $val_and_op!(f, result, prep, backend, x, seed, contexts...)

if op in (:pushforward, :pullback)
@eval function $val_and_op!(
f::F,
result::NTuple,
backend::AbstractADType,
x,
seed::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
prep = $prep_op(f, backend, x, seed, contexts...)
return $val_and_op!(f, result, prep, backend, x, seed, contexts...)
end
elseif op == :hvp
@eval function $val_and_op!(
f::F,
result1,
result2::NTuple,
backend::AbstractADType,
x,
seed::NTuple,
contexts::Vararg{Context,C},
) where {F,C}
prep = $prep_op(f, backend, x, seed, contexts...)
return $val_and_op!(
f, result1, result2, prep, backend, x, seed, contexts...
)
end
end

op == :hvp && continue

@eval function $op(
f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C}
) where {F,C}
Expand Down
Loading

0 comments on commit 40b629d

Please sign in to comment.