Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix intermediate results (again), add test coverage for intermediate results. #139

Merged
merged 15 commits into from
Nov 11, 2023
Merged
2 changes: 1 addition & 1 deletion src/SimulationService.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ function solve(o::OperationRequest)
callback = get_callback(o)
T = route2operation_type[o.route]
op = T(o)
o.result = solve(op; callback)
o.result = solve(op, callback = callback)
end

#-----------------------------------------------------------------------------# DataServiceModel
Expand Down
7 changes: 3 additions & 4 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ function (o::IntermediateResults)(integrator)
o.last_callback = iter
state_dict = Dict(states(f.sys) .=> u)
param_dict = Dict(parameters(f.sys) .=> p)

publish_to_rabbitmq(; iter=iter, state=state_dict, params = param_dict, id=o.id,
retcode=SciMLBase.check_error(integrator))
end
Expand All @@ -136,13 +135,13 @@ end

# Intermediate results functor for calibrate
function (o::IntermediateResults)(p,lossval, ode_sol, ts)
if o.last_callback + o.every ≤ Dates.now()
if o.last_callback + o.every == o.iter
o.last_callback = o.iter
param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p)
state_dict = Dict([state => ode_sol(first(ts))[state] for state in states(ode_sol.prob.f.sys)])
o.iter = o.iter + 1
publish_to_rabbitmq(; iter = o.iter, loss = lossval, sol_data = state_dict, timesteps = first(ts), params = param_dict, id=o.id)
end

o.iter = o.iter + 1
return false
end
#----------------------------------------------------------------------# dataframe_with_observables
Expand Down
11 changes: 8 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using SimulationService
using SimulationService: DataServiceModel, OperationRequest, Simulate, Calibrate, Ensemble, get_json

SimulationService.ENABLE_TDS[] = false
SimulationService.RABBITMQ_ENABLED[] = false
SimulationService.PORT[] = 8080 # avoid 8000 in case another server is running

# joinpath(root_of_repo, args...)
Expand Down Expand Up @@ -139,9 +140,11 @@ end
num_iterations = 100
calibrate_method = "bayesian"
ode_method = nothing
op = OperationRequest() # to test callback
op.id = "1"
o = SimulationService.Calibrate(sys, (0.0, 89.0), priors, data, num_chains, num_iterations, calibrate_method, ode_method)

dfsim, dfparam = solve(o; callback = nothing)
dfsim, dfparam = solve(o, callback = SimulationService.get_callback(op,SimulationService.Calibrate))

statenames = [states(o.sys); getproperty.(observed(o.sys), :lhs)]
@test names(dfsim) == vcat("timestamp",reduce(vcat,[string.("ensemble",i,"_", statenames) for i in 1:size(dfsim,2)÷length(statenames)]))
Expand All @@ -153,7 +156,7 @@ end

calibrate_method = "global"
o = SimulationService.Calibrate(sys, (0.0, 89.0), priors, data, num_chains, num_iterations, calibrate_method, ode_method)
dfsim, dfparam = SimulationService.solve(o; callback = nothing)
dfsim, dfparam = SimulationService.solve(o; callback = SimulationService.get_callback(op,SimulationService.Calibrate))

statenames = [states(o.sys);getproperty.(observed(o.sys), :lhs)]
@test names(dfsim) == vcat("timestamp",string.(statenames))
Expand Down Expand Up @@ -208,7 +211,9 @@ end
ode_method = nothing

o = SimulationService.Calibrate(sys, (0.0, 89.0), priors, data, num_chains, num_iterations, calibrate_method, ode_method)
dfsim, dfparam = SimulationService.solve(o; callback = nothing)
op = OperationRequest()
op.id = "1"
dfsim, dfparam = SimulationService.solve(o, callback = SimulationService.get_callback(op,SimulationService.Calibrate))

statenames = [states(o.sys);getproperty.(observed(o.sys), :lhs)]
@test names(dfsim) == vcat("timestamp",string.(statenames))
Expand Down
Loading