Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inter results fixes #136

Merged
merged 3 commits into from
Nov 8, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,21 @@ function amr_get(df::DataFrame, sys::ODESystem, ::Val{:data})
end

#--------------------------------------------------------------------# IntermediateResults callback
# Publish intermediate results to RabbitMQ with at least `every` seconds in between callbacks
# Publish intermediate results to RabbitMQ with at least `every` iterations in between callbacks
mutable struct IntermediateResults
last_callback::Dates.DateTime # Track the last time the callback was called
every::Dates.TimePeriod # Callback frequency e.g. `Dates.Second(5)`
last_callback::Int # Track the last iteration the callback was called
every::Int # Callback frequency
id::String
iter::Int # Track how many iterations of the calibration have happened
function IntermediateResults(id::String; every=Dates.Second(0))
new(typemin(Dates.DateTime), every, id, 0)
function IntermediateResults(id::String; every = 10)
new(0, every, id, 0)
end
end

function (o::IntermediateResults)(integrator)
if o.last_callback + o.every ≤ Dates.now()
o.last_callback = Dates.now()
(; iter, f, t, u, p) = integrator

(; iter, f, t, u, p) = integrator
if o.last_callback + o.every == iter
o.last_callback = iter
state_dict = Dict(states(f.sys) .=> u)
param_dict = Dict(parameters(f.sys) .=> p)

Expand Down Expand Up @@ -170,7 +169,7 @@ function Simulate(o::OperationRequest)
end

function get_callback(o::OperationRequest, ::Type{Simulate})
DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = Dates.Second(0)))
DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = 10))
end

# callback for Simulate requests
Expand All @@ -195,7 +194,7 @@ end

# callback for Calibrate requests
function get_callback(o::OperationRequest, ::Type{Calibrate})
IntermediateResults(o.id,every = Dates.Second(0))
IntermediateResults(o.id,every = 10)
end

function Calibrate(o::OperationRequest)
Expand All @@ -211,7 +210,7 @@ function Calibrate(o::OperationRequest)
if :extra in keys(o.obj)
extrakeys = keys(o.obj.extra)
:num_chains in extrakeys && (num_chains = o.obj.extra.num_chains)
:num_iterations in extrakeys && (num_iterations = o.obj.extra.num_iterations)
:num_iterations in extrakeys && (num_iterations = o.obj.extra.num_iterations) # only for bayesian?
:calibrate_method in extrakeys && (calibrate_method = o.obj.extra.calibrate_method)
end
Calibrate(sys, o.timespan, priors, data, num_chains, num_iterations, calibrate_method, ode_method)
Expand All @@ -224,8 +223,8 @@ function solve(o::Calibrate; callback)
# bayesian datafit
if o.calibrate_method == "bayesian"
p_posterior = EasyModelAnalysis.bayesian_datafit(prob, o.priors, o.data;
nchains = 2,
niter = 100,
nchains = o.num_chains,
niter = o.num_iterations,
mcmcensemble = SimulationService.EasyModelAnalysis.Turing.MCMCSerial())

pvalues = last.(p_posterior)
Expand Down
Loading