diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 2c758e5bc..b419b0d8d 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -9,38 +9,34 @@ isdefined(Base, :get_extension) ? (using Enzyme) : (using ..Enzyme) function Optimization.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoEnzyme, p, num_cons = 0) - _f = (θ, y, args...) -> (y .= first(f.f(θ, p, args...)); return nothing) + _f = (f, θ, args...) -> first(f(θ, p, args...)) if f.grad === nothing function grad(res, θ, args...) res .= zero(eltype(res)) - Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, res), - Enzyme.DuplicatedNoNeed(zeros(1), ones(1)), args...) + Enzyme.autodiff(Enzyme.Reverse, Const(_f), Const(f.f), Enzyme.Duplicated(θ, res), args...) end else grad = (G, θ, args...) -> f.grad(G, θ, p, args...) end if f.hess === nothing - function g(θ, bθ, y, by, args...) - Enzyme.autodiff_deferred(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, bθ), Enzyme.DuplicatedNoNeed(y, by), args...) + function g(θ, bθ, _f, f, args...) + Enzyme.autodiff_deferred(Enzyme.Reverse, Const(_f), Const(f), Enzyme.Duplicated(θ, bθ), args...) return nothing end function hess(res, θ, args...) - y = Vector{Float64}(undef, 1) - vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0))) bθ = zeros(length(θ)) - by = ones(1) vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ)) Enzyme.autodiff(Enzyme.Forward, g, Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicated(bθ, vdbθ), - Const(y), - Const(by), + Const(_f), + Const(f.f), args...) for i in eachindex(θ) @@ -52,16 +48,17 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x, end if f.hv === nothing + function f2(x, v, _f, f, args...)::Float64 + dx = zeros(length(x)) + Enzyme.autodiff_deferred(Enzyme.Reverse, Const(_f), + Const(f), + Enzyme.Duplicated(x, dx), + args...) + Float64(dot(dx, v)) + end hv = function (H, θ, v, args...) - function f2(x, v, args...)::Float64 - dx = zeros(length(x)) - Enzyme.autodiff_deferred(Enzyme.Reverse, _f, - Enzyme.Duplicated(x, dx), - Enzyme.DuplicatedNoNeed(zeros(1), ones(1)), - args...) - Float64(dot(dx, v)) - end - H .= Enzyme.gradient(Enzyme.Forward, x -> f2(x, v, args...), θ) + H .= zero(eltype(H)) + Enzyme.autodiff(Enzyme.Forward, f2, Duplicated(θ, H), v, Const(_f), Const(f.f), args...) end else hv = f.hv @@ -119,39 +116,35 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, cache::Optimization.ReInitCache, adtype::AutoEnzyme, num_cons = 0) - _f = (θ, y, args...) -> (y .= first(f.f(θ, cache.p, args...)); return nothing) + _f = (f, θ, args...) -> first(f(θ, cache.p, args...)) if f.grad === nothing function grad(res, θ, args...) res .= zero(eltype(res)) - Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, res), - Enzyme.DuplicatedNoNeed(zeros(1), ones(1)), args...) + Enzyme.autodiff(Enzyme.Reverse, Const(_f), Const(f.f), Enzyme.Duplicated(θ, res), args...) end else grad = (G, θ, args...) -> f.grad(G, θ, p, args...) end if f.hess === nothing - function g(θ, bθ, y, by, args...) - Enzyme.autodiff_deferred(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, bθ), - Enzyme.DuplicatedNoNeed(y, by), args...) + function g(θ, bθ, _f, f, args...) + Enzyme.autodiff_deferred(Enzyme.Reverse, Const(_f), Const(f), Enzyme.Duplicated(θ, bθ), + args...) return nothing end function hess(res, θ, args...) - y = Vector{Float64}(undef, 1) - vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0))) bθ = zeros(length(θ)) - by = ones(1) vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ)) Enzyme.autodiff(Enzyme.Forward, g, Enzyme.BatchDuplicated(θ, vdθ), Enzyme.BatchDuplicated(bθ, vdbθ), - Const(y), - Const(by), + Const(_f), + Const(f.f), args...) for i in eachindex(θ) @@ -163,16 +156,17 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, end if f.hv === nothing + function f2(x, v, _f, f, args...)::Float64 + dx = zeros(length(x)) + Enzyme.autodiff_deferred(Enzyme.Reverse, Const(_f), + Const(f), + Enzyme.Duplicated(x, dx), + args...) + Float64(dot(dx, v)) + end hv = function (H, θ, v, args...) - function f2(x, v, args...)::Float64 - dx = zeros(length(x)) - Enzyme.autodiff_deferred(Enzyme.Reverse, _f, - Enzyme.Duplicated(x, dx), - Enzyme.DuplicatedNoNeed(zeros(1), ones(1)), - args...) - Float64(dot(dx, v)) - end - H .= Enzyme.gradient(Enzyme.Forward, x -> f2(x, v, args...), θ) + H .= zero(eltype(H)) + Enzyme.autodiff(Enzyme.Forward, f2, Duplicated(θ, H), v, Const(_f), Const(f.f), args...) end else hv = f.hv