Skip to content

Commit

Permalink
tests for server routes, data for ensemble calib
Browse files Browse the repository at this point in the history
  • Loading branch information
jClugstor committed Nov 8, 2023
1 parent 30c3b04 commit 27f0b5e
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 14 deletions.
42 changes: 42 additions & 0 deletions examples/sir_calibrate/sir_ensemble_data.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
timestamp,S,I,R
0.0,1000.0,1.4,0.0
1.0,999.9996370253091,1.1815832550914591,0.21877971959939915
2.0,999.9993325898772,0.9988906061858863,0.4017768039369397
3.0,999.9990768992743,0.8458676890608052,0.5550554116649357
4.0,999.9988618442471,0.7175128097759396,0.6836253459769424
5.0,999.9986807045169,0.6096908187041074,0.7916284767790487
6.0,999.9985279068266,0.5189800597904012,0.8824920333830553
7.0,999.9983988212975,0.4425459227099201,0.9590552559925583
8.0,999.998289603317,0.3780393589291605,1.0236710377539167
9.0,999.9981970528227,0.3235115325278425,1.0782914146495763
10.0,999.998118499697,0.2773400191469993,1.1245414811560759
11.0,999.9980517266109,0.2381811116325973,1.1637671617564602
12.0,999.9979948771722,0.20491347994957596,1.1970916428783052
13.0,999.9979463950228,0.1765993101571986,1.2254542948197864
14.0,999.9979049843444,0.15246035849486322,1.2496346571608217
15.0,999.9978695617087,0.1318474515483701,1.2702829867428493
16.0,999.997839210662,0.11421306508022987,1.2879477242577608
17.0,999.9978131612163,0.09909920625730877,1.3030876325263074
18.0,999.9977907706666,0.08612419687258105,1.3160850324608189
19.0,999.9977714973852,0.07496680142437737,1.3272617011903274
20.0,999.9977548807467,0.06535528051153047,1.336889838741927
21.0,999.9977405318734,0.05706123449889946,1.3451982336276076
22.0,999.9977281229965,0.049892080848540556,1.3523797961549961
23.0,999.997717376844,0.04368496649518508,1.3585976566608506
24.0,999.9977080576258,0.03830233234906549,1.3639896100250464
25.0,999.9976999645526,0.03362749093344375,1.3686725445139125
26.0,999.9976929253762,0.029560393117835475,1.3727466815059388
27.0,999.9976867949074,0.02601664172236043,1.376296563370182
28.0,999.9976814503476,0.022925127215691454,1.379393422436717
29.0,999.997676785132,0.020224311706003355,1.3820989031621629
30.0,999.9976727064331,0.017860692701611357,1.3844666008652917
31.0,999.9976691354393,0.015789019054240936,1.3865418455063205
32.0,999.9976660066118,0.013971608370947141,1.388362385017237
33.0,999.997663263172,0.012375787007496572,1.3899609498204604
34.0,999.9976608546723,0.010972576831300908,1.3913665684963843
35.0,999.9976587371646,0.009736829124011859,1.3926044337114891
36.0,999.9976568730758,0.008647095651106598,1.39369603127313
37.0,999.9976552311186,0.007685446436379884,1.3946593224450141
38.0,999.9976537837061,0.006836052747901465,1.3955101635459473
39.0,999.9976525065451,0.006085021126375148,1.39626247232854
40.0,999.9976513783824,0.005420209214685627,1.396928412402837
16 changes: 8 additions & 8 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,6 @@ function get_callback(o::OperationRequest, ::Type{Simulate})
DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = Dates.Second(0)))
end

function get_callback(o::OperationRequest, ::Type{Ensemble{Simulate}})
nothing
end

function get_callback(o::OperationRequest, ::Type{Ensemble{Calibrate}})
nothing
end

# callback for Simulate requests
function solve(op::Simulate; callback)
prob = ODEProblem(op.sys, [], op.timespan)
Expand Down Expand Up @@ -305,6 +297,14 @@ function Ensemble{T}(o::OperationRequest) where {T}
Ensemble{T}(model_ids, operations, weights, sol_mappings, df)
end

function get_callback(o::OperationRequest, ::Type{Ensemble{Simulate}})
nothing
end

function get_callback(o::OperationRequest, ::Type{Ensemble{Calibrate}})
nothing
end

# Solves multiple ODEs, performs a weighted sum
# of the solutions.
function solve(o::Ensemble{Simulate}; callback)
Expand Down
56 changes: 50 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,33 @@ calibrate_payloads = JSON3.write.([
end
])

ensemble_simulate_payloads = JSON3.write.([])

simulate_ensemble_payloads = JSON3.write.([(
model_configs = map(1:4) do i
(id="model_config_id_$i", weight = i / sum(1:4), solution_mappings = (any_generic = "I", name = "R", s = "S"))
end,
local_model_files = JSON3.read.(read.([here("examples", "sir_calibrate", "sir1.json"),
here("examples", "sir_calibrate", "sir2.json"),
here("examples", "sir_calibrate", "sir3.json"),
here("examples", "sir_calibrate", "sir4.json")])),
timespan = (start = 0, var"end" = 40),
engine = "sciml",
extra = (; num_samples = 40))]
)

calibrate_ensemble_payloads = JSON3.write.([(
model_configs = map(1:4) do i
(id="model_config_id_$i", weight = i / sum(1:4), solution_mappings = (any_generic = "I", name = "R", s = "S"))
end,
local_model_files = JSON3.read.(read.([here("examples", "sir_calibrate", "sir1.json"),
here("examples", "sir_calibrate", "sir2.json"),
here("examples", "sir_calibrate", "sir3.json"),
here("examples", "sir_calibrate", "sir4.json")])),
timespan = (start = 0, var"end" = 40),
engine = "sciml",
local_csv_file = here("examples", "sir_calibrate", "sir_ensemble_data.csv"),
extra = (; num_samples = 40))]
)

#-----------------------------------------------------------------------------# utils
@testset "utils" begin
Expand Down Expand Up @@ -160,7 +185,8 @@ end
@test names(dfsim) == vcat("timestamp",string.(statenames))
@test names(dfparam) == string.(parameters(sys))
end
@testset "Ensembles" begin

@testset "ensemble-simulate" begin
amrfiles = [here("examples", "sir_calibrate", "sir1.json"),
here("examples", "sir_calibrate", "sir2.json"),
here("examples", "sir_calibrate", "sir3.json"),
Expand All @@ -173,7 +199,7 @@ end
(id="model_config_id_$i", weight = i / sum(1:4), solution_mappings = (any_generic = "I", name = "R", s = "S"))
end,
models = amrs,
ztimespan = (start = 0, var"end" = 40),
timespan = (start = 0, var"end" = 40),
engine = "sciml",
extra = (; num_samples = 40)
)
Expand All @@ -193,16 +219,21 @@ end

# bad test, need something better
@test names(sim_en_sol) == ["timestamp","S","I","R"]

end

@testset "ensemble-calibrate" begin
# create ensemble-calibrate
o = OperationRequest()
o.route = "ensemble-calibrate"
o.obj = JSON3.read(JSON3.write(obj))
o.models = amrs
o.timespan = (0,40)
o.df = sim_en_sol
o.df = df = CSV.read(here("examples", "sir_calibrate", "sir_ensemble_data.csv"), DataFrame)
en_cal = Ensemble{Calibrate}(o)
cal_sol = SimulationService.solve(en_cal,callback = nothing)
@test cal_sol[!,:Weights] [0.1, 0.2, 0.3, 0.4]

end

@testset "Real Calibrate Payload" begin
Expand Down Expand Up @@ -274,7 +305,20 @@ end
end

@testset "/ensemble-simulate" begin
@test true # TODO
for body in simulate_ensemble_payloads
res = HTTP.post("$url/ensemble-simulate", ["Content-Type" => "application/json"]; body)
@test res.status == 201
id = JSON3.read(res.body).simulation_id
test_until_done(id)
end
end
@testset "/ensemble-calibrate" begin
for body in calibrate_ensemble_payloads
res = HTTP.post("$url/ensemble-calibrate", ["Content-Type" => "application/json"]; body)
@test res.status == 201
id = JSON3.read(res.body).simulation_id
test_until_done(id)
end
end
end
end
end

0 comments on commit 27f0b5e

Please sign in to comment.