Skip to content

Commit

Permalink
make storage in gradient inplace zero
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Jul 13, 2023
1 parent 51b5f68 commit 162075c
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,

if f.grad === nothing
function grad(res, θ, args...)
= zero(res)
Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, ),
res .= zero(eltype(res))
Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, res),

Check warning on line 17 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L16-L17

Added lines #L16 - L17 were not covered by tests
Enzyme.DuplicatedNoNeed(zeros(1), ones(1)), args...)
res .=
end
else
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
Expand Down Expand Up @@ -124,10 +123,9 @@ function Optimization.instantiate_function(f::OptimizationFunction{true},

if f.grad === nothing
function grad(res, θ, args...)
= zero(res)
Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, ),
res .= zero(eltype(res))
Enzyme.autodiff(Enzyme.Reverse, _f, Enzyme.Duplicated(θ, res),

Check warning on line 127 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L126-L127

Added lines #L126 - L127 were not covered by tests
Enzyme.DuplicatedNoNeed(zeros(1), ones(1)), args...)
res .=
end
else
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)

Check warning on line 131 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L131

Added line #L131 was not covered by tests
Expand Down

0 comments on commit 162075c

Please sign in to comment.