diff --git a/src/sophia.jl b/src/sophia.jl index 88b0812c3..b63f0c099 100644 --- a/src/sophia.jl +++ b/src/sophia.jl @@ -75,7 +75,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ gₜ = zero(θ) mₜ = zero(θ) hₜ = zero(θ) - for _ in 1:maxiters + for epoch in 1:maxiters for (i, d) in enumerate(data) if cache.f.fg !== nothing && dataiterate x = cache.f.fg(gₜ, θ, d) @@ -88,7 +88,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ cache.f.grad(gₜ, θ) x = cache.f(θ) end - opt_state = Optimization.OptimizationState(; iter = i, + opt_state = Optimization.OptimizationState(; iter = i + (epoch - 1) * length(data), u = θ, objective = first(x), grad = gₜ,