Skip to content

Commit

Permalink
Intermediate results report timesteps (#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
jClugstor authored Nov 10, 2023
1 parent a03db16 commit 8f263a0
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,19 +128,19 @@ function (o::IntermediateResults)(integrator)
state_dict = Dict(states(f.sys) .=> u)
param_dict = Dict(parameters(f.sys) .=> p)

publish_to_rabbitmq(; iter=iter, time=t, state=state_dict, params = param_dict, id=o.id,
publish_to_rabbitmq(; iter=iter, state=state_dict, params = param_dict, id=o.id,
retcode=SciMLBase.check_error(integrator))
end
EasyModelAnalysis.DifferentialEquations.u_modified!(integrator, false)
end

# Intermediate results functor for calibrate
function (o::IntermediateResults)(p,lossval,ode_sol)
function (o::IntermediateResults)(p,lossval, ode_sol, ts)
if o.last_callback + o.every Dates.now()
param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p)
state_dict = Dict([state => ode_sol[state] for state in states(ode_sol.prob.f.sys)])
state_dict = Dict([state => ode_sol(first(ts))[state] for state in states(ode_sol.prob.f.sys)])
o.iter = o.iter + 1
publish_to_rabbitmq(; iter = o.iter, loss = lossval, sol_data = state_dict, params = param_dict, id=o.id)
publish_to_rabbitmq(; iter = o.iter, loss = lossval, sol_data = state_dict, timesteps = first(ts), params = param_dict, id=o.id)
end

return false
Expand Down Expand Up @@ -249,10 +249,10 @@ function solve(o::Calibrate; callback)
elseif o.calibrate_method == "local" || o.calibrate_method == "global"
if o.calibrate_method == "local"
init_params = Pair.(EasyModelAnalysis.ModelingToolkit.Num.(first.(o.priors)), Statistics.mean.(last.(o.priors)))
fit = EasyModelAnalysis.datafit(prob, init_params, o.data, solve_kws = (callback = callback,))
fit = EasyModelAnalysis.datafit(prob, init_params, o.data, loss = sciml_service_l2loss, solve_kws = (callback = callback,))
else
init_params = Pair.(EasyModelAnalysis.ModelingToolkit.Num.(first.(o.priors)), tuple.(minimum.(last.(o.priors)), maximum.(last.(o.priors))))
fit = EasyModelAnalysis.global_datafit(prob, init_params, o.data, solve_kws = (callback = callback,))
fit = EasyModelAnalysis.global_datafit(prob, init_params, o.data, loss = sciml_service_l2loss, solve_kws = (callback = callback,))
end

newprob = EasyModelAnalysis.DifferentialEquations.remake(prob, p=fit)
Expand Down Expand Up @@ -383,3 +383,19 @@ const route2operation_type = Dict(
"ensemble-simulate" => Ensemble{Simulate},
"ensemble-calibrate" => Ensemble{Calibrate}
)

function sciml_service_l2loss(pvals, (prob, pkeys, data)::Tuple{Vararg{Any, 3}})
p = Pair.(pkeys, pvals)
ts = first.(last.(data))
lastt = maximum(last.(ts))
timeseries = last.(last.(data))
datakeys = first.(data)

prob = DifferentialEquations.remake(prob, tspan = (prob.tspan[1], lastt), p = p)
sol = solve(prob)
tot_loss = 0.0
for i in 1:length(ts)
tot_loss += sum((sol(ts[i]; idxs = datakeys[i]) .- timeseries[i]) .^ 2)
end
return tot_loss, sol, ts
end

0 comments on commit 8f263a0

Please sign in to comment.