Skip to content

Commit

Permalink
fixed callback to return false, added opt counter
Browse files Browse the repository at this point in the history
  • Loading branch information
jClugstor committed Oct 23, 2023
1 parent e801793 commit d420d6f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/SimulationService.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ const RABBITMQ_ROUTE = Ref{String}()
const RABBITMQ_HOST = Ref{String}()
const RABBITMQ_PORT = Ref{Int}()

# I don't like this, but needed for now to count optimization iterations
opt_callback_counter = Ref{Int}()

function __init__()
if Threads.nthreads() == 1
@warn "SimulationService.jl expects `Threads.nthreads() > 1`. Use e.g. `julia --threads=auto`."
Expand Down
5 changes: 4 additions & 1 deletion src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ function get_callback(o::OperationRequest, ::Type{Calibrate})
function (p,lossval,ode_sol)
param_dict = Dict(parameters(ode_sol.prob.f.sys) .=> ode_sol.prob.p)
state_dict = Dict([state => ode_sol[state] for state in states(ode_sol.prob.f.sys)])
publish_to_rabbitmq(; loss = lossval, sol_data = state_dict, params = param_dict, id=o.id)
opt_callback_counter[] = opt_callback_counter[] + 1
publish_to_rabbitmq(; loss = lossval, sol_data = state_dict, params = param_dict, id=o.id, iteration = opt_callback_counter[])
return false
end
end

Expand All @@ -214,6 +216,7 @@ function solve(o::Calibrate; callback)
prob = ODEProblem(o.sys, [], o.timespan)
statenames = [states(o.sys);getproperty.(observed(o.sys), :lhs)]

opt_iter_counter[] = 0
# bayesian datafit
if o.calibrate_method == "bayesian"
p_posterior = EasyModelAnalysis.bayesian_datafit(prob, o.priors, o.data;
Expand Down

0 comments on commit d420d6f

Please sign in to comment.