From 162075c292251aa8df043ade2cc891e5dd6d76d1 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 13 Jul 2023 10:16:22 +0530 Subject: [PATCH] make storage in gradient inplace zero --- ext/OptimizationEnzymeExt.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 1f58e09ed..2c758e5bc 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -13,10 +13,9 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x, if f.grad === nothing function grad(res, θ, args...) - dθ = zero(res) - Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, dθ), + res .= zero(eltype(res)) + Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, res), Enzyme.DuplicatedNoNeed(zeros(1), ones(1)), args...) - res .= dθ end else grad = (G, θ, args...) -> f.grad(G, θ, p, args...) @@ -124,10 +123,9 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, if f.grad === nothing function grad(res, θ, args...) - dθ = zero(res) - Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, dθ), + res .= zero(eltype(res)) + Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, res), Enzyme.DuplicatedNoNeed(zeros(1), ones(1)), args...) - res .= dθ end else grad = (G, θ, args...) -> f.grad(G, θ, p, args...)