Skip to content

Commit

Permalink
allow for local and global fit options
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jul 14, 2023
1 parent b30f2bf commit 8584c07
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 31 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OpenAPI = "d5e62ea6-ddf3-4d43-8e4c-ad5e6c8bfd7d"
Oxygen = "df9a0d86-3283-4920-82dc-4555fc0d1d8b"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
SwaggerMarkdown = "1b6eb727-ad4b-44eb-9669-b9596a6e760f"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Expand Down
25 changes: 19 additions & 6 deletions examples/request-calibrate-no-integration.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
{
"model": "{\"name\": \"Giordano2020 - SIDARTHE model of COVID-19 spread in Italy\",\"schema\": \"https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/petrinet_v0.5/petrinet/petrinet_schema.json\",\"schema_name\": \"petrinet\",\"description\": \"Giordano2020 - SIDARTHE model of COVID-19 spread in Italy\",\"model_version\": \"0.1\",\"properties\": {},\"model\": { \"states\": [ { \"id\": \"Susceptible\", \"name\": \"Susceptible\", \"grounding\": { \"identifiers\": { \"ido\": \"0000514\" }, \"modifiers\": {} }, \"units\": { \"expression\": \"1\", \"expression_mathml\": \"<cn>1</cn>\" } }, { \"id\": \"Diagnosed\", \"name\": \"Diagnosed\", \"grounding\": { \"identifiers\": { \"ido\": \"0000511\" }, \"modifiers\": { \"diagnosis\": \"ncit:C15220\" } }, \"units\": { \"expression\": \"1\", \"expression_mathml\": \"<cn>1</cn>\" } }, { \"id\": \"Infected\", \"name\": \"Infected\", \"grounding\": { \"identifiers\": { \"ido\": \"0000511\" }, \"modifiers\": {} }, \"units\": { \"expression\": \"1\", \"expression_mathml\": \"<cn>1</cn>\" } }, { \"id\": \"Ailing\", \"name\": \"Ailing\", \"grounding\": { \"identifiers\": { \"ido\": \"0000511\" }, \"modifiers\": { \"disease_severity\": \"ncit:C25269\", \"diagnosis\": \"ncit:C113725\" } }, \"units\": { \"expression\": \"1\", \"expression_mathml\": \"<cn>1</cn>\" } }, { \"id\": \"Recognized\", \"name\": \"Recognized\", \"grounding\": { \"identifiers\": { \"ido\": \"0000511\" }, \"modifiers\": { \"diagnosis\": \"ncit:C15220\" } }, \"units\": { \"expression\": \"1\", \"expression_mathml\": \"<cn>1</cn>\" } }, { \"id\": \"Healed\", \"name\": \"Healed\", \"grounding\": { \"identifiers\": { \"ido\": \"0000592\" }, \"modifiers\": {} }, \"units\": { \"expression\": \"1\", \"expression_mathml\": \"<cn>1</cn>\" } }, { \"id\": \"Threatened\", \"name\": \"Threatened\", \"grounding\": { \"identifiers\": { \"ido\": \"0000511\" }, \"modifiers\": { \"disease_severity\": \"ncit:C25467\" } }, \"units\": { \"expression\": \"1\", \"expression_mathml\": \"<cn>1</cn>\" } }, { \"id\": \"Extinct\", \"name\": \"Extinct\", \"grounding\": { \"identifiers\": { \"ncit\": \"C28554\" }, \"modifiers\": {} }, \"units\": { \"expression\": \"1\", \"expression_mathml\": \"<cn>1</cn>\" } } ], \"transitions\": [ { \"id\": \"t1\", \"input\": [ \"Diagnosed\", \"Susceptible\" ], \"output\": [ \"Diagnosed\", \"Infected\" ], \"properties\": { \"name\": \"t1\" } }, { \"id\": \"t2\", \"input\": [ \"Ailing\", \"Susceptible\" ], \"output\": [ \"Ailing\", \"Infected\" ], \"properties\": { \"name\": \"t2\" } }, { \"id\": \"t3\", \"input\": [ \"Recognized\", \"Susceptible\" ], \"output\": [ \"Recognized\", \"Infected\" ], \"properties\": { \"name\": \"t3\" } }, { \"id\": \"t4\", \"input\": [ \"Infected\", \"Susceptible\" ], \"output\": [ \"Infected\", \"Infected\" ], \"properties\": { \"name\": \"t4\" } }, { \"id\": \"t5\", \"input\": [ \"Infected\" ], \"output\": [ \"Diagnosed\" ], \"properties\": { \"name\": \"t5\" } }, { \"id\": \"t6\", \"input\": [ \"Infected\" ], \"output\": [ \"Ailing\" ], \"properties\": { \"name\": \"t6\" } }, { \"id\": \"t7\", \"input\": [ \"Infected\" ], \"output\": [ \"Healed\" ], \"properties\": { \"name\": \"t7\" } }, { \"id\": \"t8\", \"input\": [ \"Diagnosed\" ], \"output\": [ \"Recognized\" ], \"properties\": { \"name\": \"t8\" } }, { \"id\": \"t9\", \"input\": [ \"Diagnosed\" ], \"output\": [ \"Healed\" ], \"properties\": { \"name\": \"t9\" } }, { \"id\": \"t10\", \"input\": [ \"Ailing\" ], \"output\": [ \"Recognized\" ], \"properties\": { \"name\": \"t10\" } }, { \"id\": \"t11\", \"input\": [ \"Ailing\" ], \"output\": [ \"Healed\" ], \"properties\": { \"name\": \"t11\" } }, { \"id\": \"t12\", \"input\": [ \"Ailing\" ], \"output\": [ \"Threatened\" ], \"properties\": { \"name\": \"t12\" } }, { \"id\": \"t13\", \"input\": [ \"Recognized\" ], \"output\": [ \"Threatened\" ], \"properties\": { \"name\": \"t13\" } }, { \"id\": \"t14\", \"input\": [ \"Recognized\" ], \"output\": [ \"Healed\" ], \"properties\": { \"name\": \"t14\" } }, { \"id\": \"t15\", \"input\": [ \"Threatened\" ], \"output\": [ \"Extinct\" ], \"properties\": { \"name\": \"t15\" } }, { \"id\": \"t16\", \"input\": [ \"Threatened\" ], \"output\": [ \"Healed\" ], \"properties\": { \"name\": \"t16\" } } ] }, \"semantics\": { \"ode\": { \"rates\": [ { \"target\": \"t1\", \"expression\": \"Diagnosed*Susceptible*beta\", \"expression_mathml\": \"<apply><times/><ci>Diagnosed</ci><ci>Susceptible</ci><ci>beta</ci></apply>\" }, { \"target\": \"t2\", \"expression\": \"Ailing*Susceptible*gamma\", \"expression_mathml\": \"<apply><times/><ci>Ailing</ci><ci>Susceptible</ci><ci>gamma</ci></apply>\" }, { \"target\": \"t3\", \"expression\": \"Recognized*Susceptible*delta\", \"expression_mathml\": \"<apply><times/><ci>Recognized</ci><ci>Susceptible</ci><ci>delta</ci></apply>\" }, { \"target\": \"t4\", \"expression\": \"Infected*Susceptible*alpha\", \"expression_mathml\": \"<apply><times/><ci>Infected</ci><ci>Susceptible</ci><ci>alpha</ci></apply>\" }, { \"target\": \"t5\", \"expression\": \"Infected*epsilon\", \"expression_mathml\": \"<apply><times/><ci>Infected</ci><ci>epsilon</ci></apply>\" }, { \"target\": \"t6\", \"expression\": \"Infected*zeta\", \"expression_mathml\": \"<apply><times/><ci>Infected</ci><ci>zeta</ci></apply>\" }, { \"target\": \"t7\", \"expression\": \"Infected*lambda\", \"expression_mathml\": \"<apply><times/><ci>Infected</ci><ci>lambda</ci></apply>\" }, { \"target\": \"t8\", \"expression\": \"Diagnosed*eta\", \"expression_mathml\": \"<apply><times/><ci>Diagnosed</ci><ci>eta</ci></apply>\" }, { \"target\": \"t9\", \"expression\": \"Diagnosed*rho\", \"expression_mathml\": \"<apply><times/><ci>Diagnosed</ci><ci>rho</ci></apply>\" }, { \"target\": \"t10\", \"expression\": \"Ailing*theta\", \"expression_mathml\": \"<apply><times/><ci>Ailing</ci><ci>theta</ci></apply>\" }, { \"target\": \"t11\", \"expression\": \"Ailing*kappa\", \"expression_mathml\": \"<apply><times/><ci>Ailing</ci><ci>kappa</ci></apply>\" }, { \"target\": \"t12\", \"expression\": \"Ailing*mu\", \"expression_mathml\": \"<apply><times/><ci>Ailing</ci><ci>mu</ci></apply>\" }, { \"target\": \"t13\", \"expression\": \"Recognized*nu\", \"expression_mathml\": \"<apply><times/><ci>Recognized</ci><ci>nu</ci></apply>\" }, { \"target\": \"t14\", \"expression\": \"Recognized*xi\", \"expression_mathml\": \"<apply><times/><ci>Recognized</ci><ci>xi</ci></apply>\" }, { \"target\": \"t15\", \"expression\": \"Threatened*tau\", \"expression_mathml\": \"<apply><times/><ci>Threatened</ci><ci>tau</ci></apply>\" }, { \"target\": \"t16\", \"expression\": \"Threatened*sigma\", \"expression_mathml\": \"<apply><times/><ci>Threatened</ci><ci>sigma</ci></apply>\" } ], \"initials\": [ { \"target\": \"Susceptible\", \"expression\": \"0.999996300000000\", \"expression_mathml\": \"<cn>0.99999629999999995</cn>\" }, { \"target\": \"Diagnosed\", \"expression\": \"3.33333333000000e-7\", \"expression_mathml\": \"<cn>3.33333333e-7</cn>\" }, { \"target\": \"Infected\", \"expression\": \"3.33333333000000e-6\", \"expression_mathml\": \"<cn>3.3333333299999999e-6</cn>\" }, { \"target\": \"Ailing\", \"expression\": \"1.66666666000000e-8\", \"expression_mathml\": \"<cn>1.6666666599999999e-8</cn>\" }, { \"target\": \"Recognized\", \"expression\": \"3.33333333000000e-8\", \"expression_mathml\": \"<cn>3.33333333e-8</cn>\" }, { \"target\": \"Healed\", \"expression\": \"0.0\", \"expression_mathml\": \"<cn>0.0</cn>\" }, { \"target\": \"Threatened\", \"expression\": \"0.0\", \"expression_mathml\": \"<cn>0.0</cn>\" }, { \"target\": \"Extinct\", \"expression\": \"0.0\", \"expression_mathml\": \"<cn>0.0</cn>\" } ], \"parameters\": [ { \"id\": \"beta\", \"value\": 0.011, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.008799999999999999, \"maximum\": 0.0132 } } }, { \"id\": \"gamma\", \"value\": 0.456, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.3648, \"maximum\": 0.5472 } } }, { \"id\": \"delta\", \"value\": 0.011, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.008799999999999999, \"maximum\": 0.0132 } } }, { \"id\": \"alpha\", \"value\": 0.57, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.45599999999999996, \"maximum\": 0.6839999999999999 } } }, { \"id\": \"epsilon\", \"value\": 0.171, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.1368, \"maximum\": 0.20520000000000002 } } }, { \"id\": \"zeta\", \"value\": 0.125, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.1, \"maximum\": 0.15 } } }, { \"id\": \"lambda\", \"value\": 0.034, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.027200000000000002, \"maximum\": 0.0408 } } }, { \"id\": \"eta\", \"value\": 0.125, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.1, \"maximum\": 0.15 } } }, { \"id\": \"rho\", \"value\": 0.034, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.027200000000000002, \"maximum\": 0.0408 } } }, { \"id\": \"theta\", \"value\": 0.371, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.2968, \"maximum\": 0.4452 } } }, { \"id\": \"kappa\", \"value\": 0.017, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.013600000000000001, \"maximum\": 0.0204 } } }, { \"id\": \"mu\", \"value\": 0.017, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.013600000000000001, \"maximum\": 0.0204 } } }, { \"id\": \"nu\", \"value\": 0.027, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.0216, \"maximum\": 0.0324 } } }, { \"id\": \"xi\", \"value\": 0.017, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.013600000000000001, \"maximum\": 0.0204 } } }, { \"id\": \"tau\", \"value\": 0.01, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.008, \"maximum\": 0.012 } } }, { \"id\": \"sigma\", \"value\": 0.017, \"distribution\": { \"type\": \"StandardUniform1\", \"parameters\": { \"minimum\": 0.013600000000000001, \"maximum\": 0.0204 } } } ], \"observables\": [ { \"id\": \"Cases\", \"name\": \"Cases\", \"expression\": \"Diagnosed + Recognized + Threatened\", \"expression_mathml\": \"<apply><plus/><ci>Diagnosed</ci><ci>Recognized</ci><ci>Threatened</ci></apply>\" }, { \"id\": \"Hospitalizations\", \"name\": \"Hospitalizations\", \"expression\": \"Recognized + Threatened\", \"expression_mathml\": \"<apply><plus/><ci>Recognized</ci><ci>Threatened</ci></apply>\" }, { \"id\": \"Deaths\", \"name\": \"Deaths\", \"expression\": \"Extinct\", \"expression_mathml\": \"<ci>Extinct</ci>\" } ], \"time\": { \"id\": \"t\", \"units\": { \"expression\": \"day\", \"expression_mathml\": \"<ci>day</ci>\" } } } }, \"metadata\": { \"annotations\": { \"license\": \"CC0\", \"authors\": [], \"references\": [ \"pubmed:32322102\" ], \"time_scale\": null, \"time_start\": null, \"time_end\": null, \"locations\": [], \"pathogens\": [ \"ncbitaxon:2697049\" ], \"diseases\": [ \"doid:0080600\" ], \"hosts\": [ \"ncbitaxon:9606\" ], \"model_types\": [ \"mamo:0000028\" ] } }}",
"dataset": {
"id": "2ea2d39f-866f-46f6-beec-972ed2136ed5",
"filename": "dataset.csv"
},
"timespan": {"start": 101, "end": 190}
"engine": "sciml",
"model_config_id": "c1cd941a-047d-11ee-be56",
"dataset": {
"id": "cd339570-047d-11ee-be55",
"filename": "dataset.csv",
"mappings": {
"postive_tests": "infected"
}
},
"timespan": {
"start": 0,
"end": 90
},
"extra": {
"num_chains": 4,
"num_iterations": 100,
"ode_method": "default",
"calibrate_method": "bayesian"
}
}
1 change: 1 addition & 0 deletions src/SimulationService.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import SwaggerMarkdown
import SymbolicUtils
import UUIDs
import YAML
import Statistics

export start!, stop!

Expand Down
85 changes: 62 additions & 23 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,41 +58,80 @@ struct Calibrate <: Operation
sys::ODESystem
timespan::Tuple{Float64, Float64}
priors::Vector{Pair{SymbolicUtils.BasicSymbolic{Real}, Uniform{Float64}}}
data::Any # ???
data::Any
num_chains::Int
num_iterations::Int
calibrate_method::String
ode_method::Any
end

function Calibrate(o::OperationRequest)
sys = amr_get(o.model, ODESystem)
priors = amr_get(o.model, sys, Val(:priors))
data = amr_get(o.df, sys, Val(:data))
Calibrate(sys, o.timespan, priors, data)

num_chains = 4
num_iterations = 100
calibrate_method = "bayesian"
ode_method = nothing

if :extra in keys(o)
extrakeys = keys(o.extra)
:num_chains in extrakeys && (num_chains = o.extra.num_chains)
:num_iterations in extrakeys && (num_iterations = o.extra.num_iterations)
:calibrate_method in extrakeys && (calibrate_method = o.extra.calibrate_method)
end
Calibrate(sys, o.timespan, priors, data, num_chains, num_iterations, calibrate_method, ode_method)
end

function solve(o::Calibrate; callback)
prob = ODEProblem(o.sys, [], o.timespan)
statenames = [states(o.sys);getproperty.(observed(o.sys), :lhs)]
p_posterior = EasyModelAnalysis.bayesian_datafit(prob, o.priors, o.data;
nchains = 2,
niter = 100,
mcmcensemble = SimulationService.EasyModelAnalysis.Turing.MCMCSerial())

pvalues = last.(p_posterior)

probs = [EasyModelAnalysis.remake(prob, p = Pair.(first.(p_posterior), getindex.(pvalues,i))) for i in 1:length(p_posterior[1][2])]
enprob = EasyModelAnalysis.EnsembleProblem(probs)
ensol = solve(enprob, saveat = 1)
outs = map(1:length(probs)) do i
mats = stack(ensol[i][statenames])'
headers = string.("ensemble",i,"_", statenames)
mats, headers
end
dfsim = DataFrame(hcat(ensol[1].t, reduce(hcat, first.(outs))), :auto)
rename!(dfsim, ["timestamp";reduce(vcat, last.(outs))])

dfparam = DataFrame(last.(p_posterior), :auto)
rename!(dfparam, Symbol.(first.(p_posterior)))

dfsim, dfparam
if o.calibrate_method == "bayesian"
p_posterior = EasyModelAnalysis.bayesian_datafit(prob, o.priors, o.data;
nchains = 2,
niter = 100,
mcmcensemble = SimulationService.EasyModelAnalysis.Turing.MCMCSerial())

pvalues = last.(p_posterior)

probs = [EasyModelAnalysis.remake(prob, p = Pair.(first.(p_posterior), getindex.(pvalues,i))) for i in 1:length(p_posterior[1][2])]
enprob = EasyModelAnalysis.EnsembleProblem(probs)
ensol = solve(enprob, saveat = 1)
outs = map(1:length(probs)) do i
mats = stack(ensol[i][statenames])'
headers = string.("ensemble",i,"_", statenames)
mats, headers
end
dfsim = DataFrame(hcat(ensol[1].t, reduce(hcat, first.(outs))), :auto)
rename!(dfsim, ["timestamp";reduce(vcat, last.(outs))])

dfparam = DataFrame(last.(p_posterior), :auto)
rename!(dfparam, Symbol.(first.(p_posterior)))

dfsim, dfparam
elseif o.calibrate_method == "local" || o.calibrate_method == "global"
if o.calibrate_method == "local"
init_params = Pair.(EasyModelAnalysis.ModelingToolkit.Num.(first.(o.priors)), Statistics.mean.(last.(o.priors)))
fit = EasyModelAnalysis.datafit(prob, init_params, o.data)
else
init_params = Pair.(EasyModelAnalysis.ModelingToolkit.Num.(first.(o.priors)), tuple.(minimum.(last.(o.priors)), maximum.(last.(o.priors))))
fit = global_datafit(prob, init_params, o.data)
end

newprob = remake(prob, p=fit)
sol = solve(newprob)
dfsim = DataFrame(hcat(sol.t,stack(sol[statenames])'), :auto)
rename!(dfsim, ["timestamp";string.(statenames)])

dfparam = DataFrame(Matrix(last.(fit)'), :auto)
rename!(dfparam, Symbol.(first.(fit)))

dfsim, dfparam
else
error("$(o.calibrate_method) is not a valid choice of calibration method")
end
end

#-----------------------------------------------------------------------------# Ensemble
Expand Down
Loading

0 comments on commit 8584c07

Please sign in to comment.