diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 84786f31..2b2cafe0 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -53,6 +53,7 @@ FastDifferentiation = "0.4.1" FiniteDiff = "2.23.1" FiniteDifferences = "0.12.31" ForwardDiff = "0.10.36" +JuliaFormatter = "1" LinearAlgebra = "<0.0.1,1" Mooncake = "0.4.0" PolyesterForwardDiff = "0.1.2" diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index 08208a22..33f1ce8f 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -93,6 +93,8 @@ prepare_hvp prepare_hvp_same_point hvp hvp! +gradient_and_hvp +gradient_and_hvp! ``` ### Hessian diff --git a/DifferentiationInterface/docs/src/explanation/operators.md b/DifferentiationInterface/docs/src/explanation/operators.md index d8e257c5..942e0b4a 100644 --- a/DifferentiationInterface/docs/src/explanation/operators.md +++ b/DifferentiationInterface/docs/src/explanation/operators.md @@ -58,7 +58,7 @@ Several variants of each operator are defined: | [`jacobian`](@ref) | [`jacobian!`](@ref) | [`value_and_jacobian`](@ref) | [`value_and_jacobian!`](@ref) | | [`pushforward`](@ref) | [`pushforward!`](@ref) | [`value_and_pushforward`](@ref) | [`value_and_pushforward!`](@ref) | | [`pullback`](@ref) | [`pullback!`](@ref) | [`value_and_pullback`](@ref) | [`value_and_pullback!`](@ref) | -| [`hvp`](@ref) | [`hvp!`](@ref) | - | - | +| [`hvp`](@ref) | [`hvp!`](@ref) | [`gradient_and_hvp`](@ref) | [`gradient_and_hvp!`](@ref) | ## Mutation and signatures diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index b9ce86f3..a39bcde2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -396,12 +396,13 @@ end ## HVP -struct FastDifferentiationHVPPrep{E2,E2!} <: HVPPrep +struct FastDifferentiationHVPPrep{E2,E2!,E1} <: HVPPrep hvp_exe::E2 hvp_exe!::E2! + gradient_prep::E1 end -function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, tx::NTuple) +function DI.prepare_hvp(f, backend::AutoFastDifferentiation, x, tx::NTuple) x_var = make_variables(:x, size(x)...) y_var = f(x_var) @@ -409,7 +410,9 @@ function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, tx::NTuple) hv_vec_var, v_vec_var = hessian_times_v(y_var, x_vec_var) hvp_exe = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=false) hvp_exe! = make_function(hv_vec_var, vcat(x_vec_var, v_vec_var); in_place=true) - return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!) + + gradient_prep = DI.prepare_gradient(f, backend, x) + return FastDifferentiationHVPPrep(hvp_exe, hvp_exe!, gradient_prep) end function DI.hvp( @@ -439,6 +442,28 @@ function DI.hvp!( return tg end +function DI.gradient_and_hvp( + f, prep::FastDifferentiationHVPPrep, backend::AutoFastDifferentiation, x, tx::NTuple +) + tg = DI.hvp(f, prep, backend, x, tx) + grad = DI.gradient(f, prep.gradient_prep, backend, x) + return grad, tg +end + +function DI.gradient_and_hvp!( + f, + grad, + tg::NTuple, + prep::FastDifferentiationHVPPrep, + backend::AutoFastDifferentiation, + x, + tx::NTuple, +) + DI.hvp!(f, tg, prep, backend, x, tx) + DI.gradient!(f, grad, prep.gradient_prep, backend, x) + return grad, tg +end + ## Hessian struct FastDifferentiationHessianPrep{G,E2,E2!} <: HessianPrep diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 3bf9211e..3f317cf5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -541,6 +541,32 @@ function DI.hvp!( return DI.hvp!(f, tg, prep, SecondOrder(backend, backend), x, tx, contexts...) end +function DI.gradient_and_hvp( + f::F, + prep::HVPPrep, + backend::AutoForwardDiff, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + return DI.gradient_and_hvp(f, prep, SecondOrder(backend, backend), x, tx, contexts...) +end + +function DI.gradient_and_hvp!( + f::F, + grad, + tg::NTuple, + prep::HVPPrep, + backend::AutoForwardDiff, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + return DI.gradient_and_hvp!( + f, grad, tg, prep, SecondOrder(backend, backend), x, tx, contexts... + ) +end + ## Hessian ### Unprepared, only when chunk size and tag are not specified diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl index b69ad90b..e0c84459 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/secondorder.jl @@ -32,6 +32,7 @@ function DI.prepare_hvp( T = tag_type(f, tagged_outer_backend, x) xdual = make_dual(T, x, tx) gradient_prep = DI.prepare_gradient(f, inner(backend), xdual, contexts...) + # TODO: get rid of closure? function inner_gradient(x, unannotated_contexts...) annotated_contexts = rewrap(unannotated_contexts...) return DI.gradient(f, gradient_prep, inner(backend), x, annotated_contexts...) @@ -73,3 +74,34 @@ function DI.hvp!( ) return tg end + +function DI.gradient_and_hvp( + f::F, + prep::ForwardDiffOverSomethingHVPPrep, + ::SecondOrder{<:AutoForwardDiff}, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep + return DI.value_and_pushforward( + inner_gradient, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts... + ) +end + +function DI.gradient_and_hvp!( + f::F, + grad, + tg::NTuple, + prep::ForwardDiffOverSomethingHVPPrep, + ::SecondOrder{<:AutoForwardDiff}, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; tagged_outer_backend, inner_gradient, outer_pushforward_prep) = prep + new_grad, _ = DI.value_and_pushforward!( + inner_gradient, tg, outer_pushforward_prep, tagged_outer_backend, x, tx, contexts... + ) + return copyto!(grad, new_grad), tg +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index 6d6f93e6..331baa8d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -309,6 +309,32 @@ function DI.hvp!( return DI.hvp!(f, tg, prep, single_threaded(backend), x, tx, contexts...) end +function DI.gradient_and_hvp( + f, + prep::HVPPrep, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {C} + return DI.gradient_and_hvp(f, prep, single_threaded(backend), x, tx, contexts...) +end + +function DI.gradient_and_hvp!( + f, + grad, + tg::NTuple, + prep::HVPPrep, + backend::AutoPolyesterForwardDiff, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {C} + return DI.gradient_and_hvp!( + f, grad, tg, prep, single_threaded(backend), x, tx, contexts... + ) +end + ## Second derivative function DI.prepare_second_derivative( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 60c22231..2ad158a1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -317,6 +317,22 @@ function DI.hvp!( return tg end +function DI.gradient_and_hvp( + f, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple +) + tg = DI.hvp(f, prep, backend, x, tx) + grad = DI.gradient(f, prep.gradient_prep, backend, x) + return grad, tg +end + +function DI.gradient_and_hvp!( + f, grad, tg::NTuple, prep::SymbolicsOneArgHVPPrep, backend::AutoSymbolics, x, tx::NTuple +) + DI.hvp!(f, tg, prep, backend, x, tx) + DI.gradient!(f, grad, prep.gradient_prep, backend, x) + return grad, tg +end + ## Second derivative struct SymbolicsOneArgSecondDerivativePrep{D,E1,E1!} <: SecondDerivativePrep diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index e73be43e..9056d18f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -171,6 +171,29 @@ function DI.hvp!( return DI.hvp!(f, tg, prep, SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...) end +function DI.gradient_and_hvp( + f, prep::HVPPrep, backend::AutoZygote, x, tx::NTuple, contexts::Vararg{Constant,C} +) where {C} + return DI.gradient_and_hvp( + f, prep, SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... + ) +end + +function DI.gradient_and_hvp!( + f, + grad, + tg::NTuple, + prep::HVPPrep, + backend::AutoZygote, + x, + tx::NTuple, + contexts::Vararg{Constant,C}, +) where {C} + return DI.gradient_and_hvp!( + f, grad, tg, prep, SecondOrder(AutoForwardDiff(), backend), x, tx, contexts... + ) +end + ## Hessian function DI.prepare_hessian(f, ::AutoZygote, x, contexts::Vararg{Constant,C}) where {C} diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 9831aba8..a059428b 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -84,7 +84,7 @@ export jacobian!, jacobian export second_derivative!, second_derivative export value_derivative_and_second_derivative, value_derivative_and_second_derivative! -export hvp!, hvp +export hvp!, hvp, gradient_and_hvp, gradient_and_hvp! export hessian!, hessian export value_gradient_and_hessian, value_gradient_and_hessian! diff --git a/DifferentiationInterface/src/fallbacks/no_prep.jl b/DifferentiationInterface/src/fallbacks/no_prep.jl index 509a51da..0f2ecd4e 100644 --- a/DifferentiationInterface/src/fallbacks/no_prep.jl +++ b/DifferentiationInterface/src/fallbacks/no_prep.jl @@ -14,7 +14,7 @@ for op in [ elseif op == :hessian :value_gradient_and_hessian elseif op == :hvp - nothing + :gradient_and_hvp else Symbol("value_and_", op) end @@ -138,26 +138,44 @@ for op in [ prep = $prep_op(f, backend, x, seed, contexts...) return $op!(f, result, prep, backend, x, seed, contexts...) end - - op == :hvp && continue - @eval function $val_and_op( f::F, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} ) where {F,C} prep = $prep_op(f, backend, x, seed, contexts...) return $val_and_op(f, prep, backend, x, seed, contexts...) end - @eval function $val_and_op!( - f::F, - result::NTuple, - backend::AbstractADType, - x, - seed::NTuple, - contexts::Vararg{Context,C}, - ) where {F,C} - prep = $prep_op(f, backend, x, seed, contexts...) - return $val_and_op!(f, result, prep, backend, x, seed, contexts...) + + if op in (:pushforward, :pullback) + @eval function $val_and_op!( + f::F, + result::NTuple, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + prep = $prep_op(f, backend, x, seed, contexts...) + return $val_and_op!(f, result, prep, backend, x, seed, contexts...) + end + elseif op == :hvp + @eval function $val_and_op!( + f::F, + result1, + result2::NTuple, + backend::AbstractADType, + x, + seed::NTuple, + contexts::Vararg{Context,C}, + ) where {F,C} + prep = $prep_op(f, backend, x, seed, contexts...) + return $val_and_op!( + f, result1, result2, prep, backend, x, seed, contexts... + ) + end end + + op == :hvp && continue + @eval function $op( f!::F, y, backend::AbstractADType, x, seed::NTuple, contexts::Vararg{Context,C} ) where {F,C} diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 67ee0003..4f53d81d 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -50,27 +50,23 @@ $(document_preparation("hvp"; same_point=true)) """ function hvp! end -## Preparation +""" + gradient_and_hvp(f, [prep,] backend, x, tx, [contexts...]) -> (grad, tg) -struct ForwardOverForwardHVPPrep{E<:PushforwardPrep} <: HVPPrep - # pushforward of many pushforwards in theory, but pushforward of gradient in practice - outer_pushforward_prep::E -end +Compute the gradient and the Hessian-vector product of `f` at point `x` with a tuple of tangents `tx`. -struct ForwardOverReverseHVPPrep{E<:PushforwardPrep} <: HVPPrep - # pushforward of gradient - outer_pushforward_prep::E -end +$(document_preparation("hvp"; same_point=true)) +""" +function gradient_and_hvp end -struct ReverseOverForwardHVPPrep{E<:GradientPrep} <: HVPPrep - # gradient of pushforward - outer_gradient_prep::E -end +""" + gradient_and_hvp!(f, grad, tg, [prep,] backend, x, tx, [contexts...]) -> (grad, tg) -struct ReverseOverReverseHVPPrep{E<:PullbackPrep} <: HVPPrep - # pullback of gradient - outer_pullback_prep::E -end +Compute the gradient and the Hessian-vector product of `f` at point `x` with a tuple of tangents `tx`, overwriting `grad` and `tg`. + +$(document_preparation("hvp"; same_point=true)) +""" +function gradient_and_hvp! end function prepare_hvp( f::F, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C} @@ -78,6 +74,13 @@ function prepare_hvp( return _prepare_hvp_aux(hvp_mode(backend), f, backend, x, tx, contexts...) end +## Forward over forward + +struct ForwardOverForwardHVPPrep{E2<:PushforwardPrep} <: HVPPrep + # pushforward of many pushforwards in theory, but pushforward of gradient in practice + outer_pushforward_prep::E2 +end + function _prepare_hvp_aux( ::ForwardOverForward, f::F, @@ -94,63 +97,46 @@ function _prepare_hvp_aux( return ForwardOverForwardHVPPrep(outer_pushforward_prep) end -function _prepare_hvp_aux( - ::ForwardOverReverse, +function hvp( f::F, + prep::ForwardOverForwardHVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) - outer_pushforward_prep = prepare_pushforward( - shuffled_gradient, outer(backend), x, tx, new_contexts... - ) - return ForwardOverReverseHVPPrep(outer_pushforward_prep) -end - -function _prepare_hvp_aux( - ::ReverseOverForward, - f::F, - backend::AbstractADType, - x, - tx::NTuple, - contexts::Vararg{Context,C}, -) where {F,C} - rewrap = Rewrap(contexts...) - new_contexts = ( - Constant(f), - Constant(inner(backend)), - Constant(first(tx)), - Constant(rewrap), - contexts..., - ) - outer_gradient_prep = prepare_gradient( - shuffled_single_pushforward, outer(backend), x, new_contexts... + return pushforward( + shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) - return ReverseOverForwardHVPPrep(outer_gradient_prep) end -function _prepare_hvp_aux( - ::ReverseOverReverse, +function hvp!( f::F, + tg::NTuple, + prep::ForwardOverForwardHVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} + (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) - outer_pullback_prep = prepare_pullback( - shuffled_gradient, outer(backend), x, tx, new_contexts... + return pushforward!( + shuffled_gradient, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., ) - return ReverseOverReverseHVPPrep(outer_pullback_prep) end -## One argument - -function hvp( +function gradient_and_hvp( f::F, prep::ForwardOverForwardHVPPrep, backend::AbstractADType, @@ -161,14 +147,16 @@ function hvp( (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) - return pushforward( + return value_and_pushforward( shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) end -function hvp( +function gradient_and_hvp!( f::F, - prep::ForwardOverReverseHVPPrep, + grad, + tg::NTuple, + prep::ForwardOverForwardHVPPrep, backend::AbstractADType, x, tx::NTuple, @@ -177,57 +165,61 @@ function hvp( (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) - return pushforward( - shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... + new_grad, _ = value_and_pushforward!( + shuffled_gradient, + tg, + outer_pushforward_prep, + outer(backend), + x, + tx, + new_contexts..., ) + return copyto!(grad, new_grad), tg end -function hvp( +## Forward over reverse + +struct ForwardOverReverseHVPPrep{E2<:PushforwardPrep} <: HVPPrep + # pushforward of gradient + outer_pushforward_prep::E2 +end + +function _prepare_hvp_aux( + ::ForwardOverReverse, f::F, - prep::ReverseOverForwardHVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_gradient_prep) = prep rewrap = Rewrap(contexts...) - tg = map(tx) do dx - gradient( - shuffled_single_pushforward, - outer_gradient_prep, - outer(backend), - x, - Constant(f), - Constant(inner(backend)), - Constant(dx), - Constant(rewrap), - contexts..., - ) - end - return tg + new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + outer_pushforward_prep = prepare_pushforward( + shuffled_gradient, outer(backend), x, tx, new_contexts... + ) + return ForwardOverReverseHVPPrep(outer_pushforward_prep) end function hvp( f::F, - prep::ReverseOverReverseHVPPrep, + prep::ForwardOverReverseHVPPrep, backend::AbstractADType, x, tx::NTuple, contexts::Vararg{Context,C}, ) where {F,C} - (; outer_pullback_prep) = prep + (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) - return pullback( - shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... + return pushforward( + shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... ) end function hvp!( f::F, tg::NTuple, - prep::ForwardOverForwardHVPPrep, + prep::ForwardOverReverseHVPPrep, backend::AbstractADType, x, tx::NTuple, @@ -247,8 +239,25 @@ function hvp!( ) end -function hvp!( +function gradient_and_hvp( + f::F, + prep::ForwardOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; outer_pushforward_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + return value_and_pushforward( + shuffled_gradient, outer_pushforward_prep, outer(backend), x, tx, new_contexts... + ) +end + +function gradient_and_hvp!( f::F, + grad, tg::NTuple, prep::ForwardOverReverseHVPPrep, backend::AbstractADType, @@ -259,7 +268,7 @@ function hvp!( (; outer_pushforward_prep) = prep rewrap = Rewrap(contexts...) new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) - return pushforward!( + new_grad, _ = value_and_pushforward!( shuffled_gradient, tg, outer_pushforward_prep, @@ -268,6 +277,64 @@ function hvp!( tx, new_contexts..., ) + return copyto!(grad, new_grad), tg +end + +## Reverse over forward + +struct ReverseOverForwardHVPPrep{E2<:GradientPrep,E1<:GradientPrep} <: HVPPrep + # gradient of pushforward + outer_gradient_prep::E2 + gradient_prep::E1 +end + +function _prepare_hvp_aux( + ::ReverseOverForward, + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + rewrap = Rewrap(contexts...) + new_contexts = ( + Constant(f), + Constant(inner(backend)), + Constant(first(tx)), + Constant(rewrap), + contexts..., + ) + outer_gradient_prep = prepare_gradient( + shuffled_single_pushforward, outer(backend), x, new_contexts... + ) + gradient_prep = prepare_gradient(f, inner(backend), x, contexts...) + return ReverseOverForwardHVPPrep(outer_gradient_prep, gradient_prep) +end + +function hvp( + f::F, + prep::ReverseOverForwardHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; outer_gradient_prep) = prep + rewrap = Rewrap(contexts...) + tg = map(tx) do dx + gradient( + shuffled_single_pushforward, + outer_gradient_prep, + outer(backend), + x, + Constant(f), + Constant(inner(backend)), + Constant(dx), + Constant(rewrap), + contexts..., + ) + end + return tg end function hvp!( @@ -298,6 +365,73 @@ function hvp!( return tg end +function gradient_and_hvp( + f::F, + prep::ReverseOverForwardHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + tg = hvp(f, prep, backend, x, tx, contexts...) + grad = gradient(f, prep.gradient_prep, inner(backend), x, contexts...) + return grad, tg +end + +function gradient_and_hvp!( + f::F, + grad, + tg::NTuple, + prep::ReverseOverForwardHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + hvp!(f, tg, prep, backend, x, tx, contexts...) + gradient!(f, grad, prep.gradient_prep, inner(backend), x, contexts...) + return grad, tg +end + +## Reverse over reverse + +struct ReverseOverReverseHVPPrep{E2<:PullbackPrep} <: HVPPrep + # pullback of gradient + outer_pullback_prep::E2 +end + +function _prepare_hvp_aux( + ::ReverseOverReverse, + f::F, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + rewrap = Rewrap(contexts...) + new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + outer_pullback_prep = prepare_pullback( + shuffled_gradient, outer(backend), x, tx, new_contexts... + ) + return ReverseOverReverseHVPPrep(outer_pullback_prep) +end + +function hvp( + f::F, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; outer_pullback_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + return pullback( + shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... + ) +end + function hvp!( f::F, tg::NTuple, @@ -314,3 +448,38 @@ function hvp!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... ) end + +function gradient_and_hvp( + f::F, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; outer_pullback_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + return value_and_pullback( + shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... + ) +end + +function gradient_and_hvp!( + f::F, + grad, + tg::NTuple, + prep::ReverseOverReverseHVPPrep, + backend::AbstractADType, + x, + tx::NTuple, + contexts::Vararg{Context,C}, +) where {F,C} + (; outer_pullback_prep) = prep + rewrap = Rewrap(contexts...) + new_contexts = (Constant(f), Constant(inner(backend)), Constant(rewrap), contexts...) + new_grad, _ = value_and_pullback!( + shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... + ) + return copyto!(grad, new_grad), tg +end diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 24dbf51b..d5d6953e 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -51,6 +51,7 @@ ForwardDiff = "0.10.36" Functors = "0.4" JET = "0.4 - 0.8, 0.9" JLArrays = "0.1" +JuliaFormatter = "1" LinearAlgebra = "<0.0.1,1" Lux = "1.1.0" LuxTestUtils = "1.3.1" diff --git a/DifferentiationInterfaceTest/src/tests/allocs_eval.jl b/DifferentiationInterfaceTest/src/tests/allocs_eval.jl index 5128827f..94b5df67 100644 --- a/DifferentiationInterfaceTest/src/tests/allocs_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/allocs_eval.jl @@ -7,8 +7,10 @@ for op in ALL_OPS op! = Symbol(op, "!") val_prefix = if op == :second_derivative "value_derivative_and_" - elseif op in [:hessian, :hvp] + elseif op == :hessian "value_gradient_and_" + elseif op == :hvp + "gradient_and_" else "value_and_" end @@ -200,7 +202,11 @@ for op in ALL_OPS (; f, x, tang, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) (subset == :full) && test_noallocs(skip, $prep_op, f, ba, x, tang, contexts...) + (subset == :full) && + test_noallocs(skip, $val_and_op, f, ba, x, tang, contexts...) (subset == :full) && test_noallocs(skip, $op, f, ba, x, tang, contexts...) + (subset != :none) && + test_noallocs(skip, $val_and_op, f, prep, ba, x, tang, contexts...) (subset != :none) && test_noallocs(skip, $op, f, prep, ba, x, tang, contexts...) return nothing end @@ -214,8 +220,23 @@ for op in ALL_OPS (subset == :full) && test_noallocs(skip, $prep_op, f, ba, x, tang, contexts...) (subset == :full) && test_noallocs(skip, $op!, f, res2_sim, ba, x, tang, contexts...) + (subset == :full) && test_noallocs( + skip, $val_and_op!, f, res1_sim, res2_sim, ba, x, tang, contexts... + ) (subset != :none) && test_noallocs(skip, $op!, f, res2_sim, prep, ba, x, tang, contexts...) + (subset != :none) && test_noallocs( + skip, + $val_and_op!, + f, + res1_sim, + res2_sim, + prep, + ba, + x, + tang, + contexts..., + ) return nothing end end diff --git a/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl b/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl index 2f5d6733..586fd70e 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark_eval.jl @@ -18,8 +18,10 @@ for op in ALL_OPS op! = Symbol(op, "!") val_prefix = if op == :second_derivative "value_derivative_and_" - elseif op in [:hessian, :hvp] + elseif op == :hessian "value_gradient_and_" + elseif op == :hvp + "gradient_and_" else "value_and_" end @@ -561,11 +563,11 @@ for op in ALL_OPS @eval function benchmark_aux(ba::AbstractADType, scen::$S1out; subset::Symbol) (; f, x, tang, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - prepared_valop = @be +(1, 1) # TODO: fix + prepared_valop = @be prep $val_and_op(f, _, ba, x, tang, contexts...) prepared_op = @be prep $op(f, _, ba, x, tang, contexts...) if subset == :full preparation = @be $prep_op(f, ba, x, tang, contexts...) - unprepared_valop = @be +(1, 1) # TODO: fix + unprepared_valop = @be $val_and_op(f, ba, x, tang, contexts...) unprepared_op = @be $op(f, ba, x, tang, contexts...) return BenchmarkResult(; prepared_valop, @@ -584,10 +586,12 @@ for op in ALL_OPS cc = CallCounter(f) prep = $prep_op(cc, ba, x, tang, contexts...) preparation = reset_count!(cc) - prepared_valop = -1 # TODO: fix + $val_and_op(cc, prep, ba, x, tang, contexts...) + prepared_valop = reset_count!(cc) $op(cc, prep, ba, x, tang, contexts...) prepared_op = reset_count!(cc) - unprepared_valop = -1 # TODO: fix + $val_and_op(cc, ba, x, tang, contexts...) + unprepared_valop = reset_count!(cc) $op(cc, ba, x, tang, contexts...) unprepared_op = reset_count!(cc) return CallsResult(; @@ -596,15 +600,19 @@ for op in ALL_OPS end @eval function benchmark_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) - (; f, x, tang, res2, contexts) = deepcopy(scen) + (; f, x, tang, res1, res2, contexts) = deepcopy(scen) prep = $prep_op(f, ba, x, tang, contexts...) - prepared_valop = @be +(1, 1) # TODO: fix + prepared_valop = @be (mysimilar(res1), mysimilar(res2), prep) $val_and_op!( + f, _[1], _[2], _[3], ba, x, tang, contexts... + ) prepared_op = @be (mysimilar(res2), prep) $op!( f, _[1], _[2], ba, x, tang, contexts... ) if subset == :full preparation = @be $prep_op(f, ba, x, tang, contexts...) - unprepared_valop = @be +(1, 1) # TODO: fix + unprepared_valop = @be (mysimilar(res1), mysimilar(res2)) $val_and_op!( + f, _[1], _[2], ba, x, tang, contexts... + ) unprepared_op = @be mysimilar(res2) $op!(f, _, ba, x, tang, contexts...) return BenchmarkResult(; prepared_valop, @@ -619,14 +627,18 @@ for op in ALL_OPS end @eval function calls_aux(ba::AbstractADType, scen::$S1in; subset::Symbol) - (; f, x, tang, res2, contexts) = deepcopy(scen) + (; f, x, tang, res1, res2, contexts) = deepcopy(scen) cc = CallCounter(f) prep = $prep_op(cc, ba, x, tang, contexts...) preparation = reset_count!(cc) - prepared_valop = -1 # TODO: fix + $val_and_op!( + cc, mysimilar(res1), mysimilar(res2), prep, ba, x, tang, contexts... + ) + prepared_valop = reset_count!(cc) $op!(cc, mysimilar(res2), prep, ba, x, tang, contexts...) prepared_op = reset_count!(cc) - unprepared_valop = -1 # TODO: fix + $val_and_op!(cc, mysimilar(res1), mysimilar(res2), ba, x, tang, contexts...) + unprepared_valop = reset_count!(cc) $op!(cc, mysimilar(res2), ba, x, tang, contexts...) unprepared_op = reset_count!(cc) return CallsResult(; diff --git a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl index 7d990fe9..da91c0e6 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness_eval.jl @@ -2,8 +2,10 @@ for op in ALL_OPS op! = Symbol(op, "!") val_prefix = if op == :second_derivative "value_derivative_and_" - elseif op in [:hessian, :hvp] + elseif op == :hessian "value_gradient_and_" + elseif op == :hvp + "gradient_and_" else "value_and_" end @@ -669,28 +671,40 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, tang, res2, contexts) = new_scen = deepcopy(scen) + (; f, x, y, tang, res1, res2, contexts) = new_scen = deepcopy(scen) xrand, tangrand = myrandom(x), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) - prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepprep = $prep_op!( - f, - $prep_op(f, ba, xrand, tangrand, contextsrand...), - ba, - xrand, - tangrand, - contextsrand..., - ) - prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) - preptup_cands_noval = [(), (prep,), (prepprep,), (prep_same,)] - for preptup_noval in preptup_cands_noval + preptup_cands_val, preptup_cands_noval = map(1:2) do _ + prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) + prepprep = $prep_op!( + f, + $prep_op(f, ba, xrand, tangrand, contextsrand...), + ba, + xrand, + tangrand, + contextsrand..., + ) + prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) + [(), (prep,), (prepprep,), (prep_same,)] + end + for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res2_out1_noval = $op(f, preptup_noval..., ba, x, tang, contexts...) res2_out2_noval = $op(f, preptup_noval..., ba, x, tang, contexts...) + res1_out1_val, res2_out1_val = $val_and_op( + f, preptup_noval..., ba, x, tang, contexts... + ) + res1_out2_val, res2_out2_val = $val_and_op( + f, preptup_noval..., ba, x, tang, contexts... + ) let (≈)(x, y) = isapprox(x, y; atol, rtol) @test isempty(preptup_noval) || only(preptup_noval) isa $P @test all(res2_out1_noval .≈ scen.res2) @test all(res2_out2_noval .≈ scen.res2) + @test res1_out1_val ≈ scen.res1 + @test res1_out2_val ≈ scen.res1 + @test all(res2_out1_val .≈ scen.res2) + @test all(res2_out2_val .≈ scen.res2) end end scenario_intact && @test new_scen == scen @@ -706,36 +720,68 @@ for op in ALL_OPS scenario_intact::Bool, sparsity::Bool, ) - (; f, x, y, tang, res2, contexts) = new_scen = deepcopy(scen) + (; f, x, y, tang, res1, res2, contexts) = new_scen = deepcopy(scen) xrand, tangrand = myrandom(x), myrandom(tang) rewrap = Rewrap(contexts...) contextsrand = rewrap(map(myrandom ∘ unwrap, contexts)...) - prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) - prepprep = $prep_op!( - f, - $prep_op(f, ba, xrand, tangrand, contextsrand...), - ba, - xrand, - tangrand, - contextsrand..., - ) - prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) - preptup_cands_noval = [(), (prep,), (prepprep,), (prep_same,)] - for preptup_noval in preptup_cands_noval + preptup_cands_val, preptup_cands_noval = map(1:2) do _ + prep = $prep_op(f, ba, xrand, tangrand, contextsrand...) + prepprep = $prep_op!( + f, + $prep_op(f, ba, xrand, tangrand, contextsrand...), + ba, + xrand, + tangrand, + contextsrand..., + ) + prep_same = $prep_op_same(f, ba, x, tangrand, contexts...) + [(), (prep,), (prepprep,), (prep_same,)] + end + for (preptup_val, preptup_noval) in zip(preptup_cands_val, preptup_cands_noval) res2_in1_noval = mysimilar(res2) res2_in2_noval = mysimilar(res2) + res1_in1_val, res2_in1_val = mysimilar(res1), mysimilar(res2) + res1_in2_val, res2_in2_val = mysimilar(res1), mysimilar(res2) res2_out1_noval = $op!( f, res2_in1_noval, preptup_noval..., ba, x, tang, contexts... ) res2_out2_noval = $op!( f, res2_in2_noval, preptup_noval..., ba, x, tang, contexts... ) + res1_out1_val, res2_out1_val = $val_and_op!( + f, + res1_in1_val, + res2_in1_val, + preptup_noval..., + ba, + x, + tang, + contexts..., + ) + res1_out2_val, res2_out2_val = $val_and_op!( + f, + res1_in2_val, + res2_in2_val, + preptup_noval..., + ba, + x, + tang, + contexts..., + ) let (≈)(x, y) = isapprox(x, y; atol, rtol) @test isempty(preptup_noval) || only(preptup_noval) isa $P @test all(res2_in1_noval .≈ scen.res2) @test all(res2_in2_noval .≈ scen.res2) @test all(res2_out1_noval .≈ scen.res2) @test all(res2_out2_noval .≈ scen.res2) + @test res1_in1_val ≈ scen.res1 + @test res1_in2_val ≈ scen.res1 + @test res1_out1_val ≈ scen.res1 + @test res1_out2_val ≈ scen.res1 + @test all(res2_in1_val .≈ scen.res2) + @test all(res2_in2_val .≈ scen.res2) + @test all(res2_out1_val .≈ scen.res2) + @test all(res2_out2_val .≈ scen.res2) end end scenario_intact && @test new_scen == scen diff --git a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl index d66971df..bf02fb56 100644 --- a/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl +++ b/DifferentiationInterfaceTest/src/tests/type_stability_eval.jl @@ -2,8 +2,10 @@ for op in ALL_OPS op! = Symbol(op, "!") val_prefix = if op == :second_derivative "value_derivative_and_" - elseif op in [:hessian, :hvp] + elseif op == :hessian "value_gradient_and_" + elseif op == :hvp + "gradient_and_" else "value_and_" end @@ -325,9 +327,15 @@ for op in ALL_OPS (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $op(f, ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, ba, x, tang, contexts...) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $op(f, prep, ba, x, tang, contexts...) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op(f, prep, ba, x, tang, contexts...) return nothing end @@ -346,9 +354,19 @@ for op in ALL_OPS (subset == :full) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $op!(f, mysimilar(res2), ba, x, tang, contexts...) + (subset == :full) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!( + f, mysimilar(res1), mysimilar(res2), ba, x, tang, contexts... + ) (subset in (:prepared, :full)) && @test_opt ignored_modules = ignored_modules function_filter = function_filter $op!(f, mysimilar(res2), prep, ba, x, tang, contexts...) + (subset in (:prepared, :full)) && + @test_opt ignored_modules = ignored_modules function_filter = + function_filter $val_and_op!( + f, mysimilar(res1), mysimilar(res2), prep, ba, x, tang, contexts... + ) return nothing end end