From 27f0b5e6dc5ec43f935e4cac544c89fd90c74844 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Wed, 8 Nov 2023 12:11:30 -0500 Subject: [PATCH] tests for server routes, data for ensemble calib --- examples/sir_calibrate/sir_ensemble_data.csv | 42 +++++++++++++++ src/operations.jl | 16 +++--- test/runtests.jl | 56 +++++++++++++++++--- 3 files changed, 100 insertions(+), 14 deletions(-) create mode 100644 examples/sir_calibrate/sir_ensemble_data.csv diff --git a/examples/sir_calibrate/sir_ensemble_data.csv b/examples/sir_calibrate/sir_ensemble_data.csv new file mode 100644 index 0000000..68cbe6a --- /dev/null +++ b/examples/sir_calibrate/sir_ensemble_data.csv @@ -0,0 +1,42 @@ +timestamp,S,I,R +0.0,1000.0,1.4,0.0 +1.0,999.9996370253091,1.1815832550914591,0.21877971959939915 +2.0,999.9993325898772,0.9988906061858863,0.4017768039369397 +3.0,999.9990768992743,0.8458676890608052,0.5550554116649357 +4.0,999.9988618442471,0.7175128097759396,0.6836253459769424 +5.0,999.9986807045169,0.6096908187041074,0.7916284767790487 +6.0,999.9985279068266,0.5189800597904012,0.8824920333830553 +7.0,999.9983988212975,0.4425459227099201,0.9590552559925583 +8.0,999.998289603317,0.3780393589291605,1.0236710377539167 +9.0,999.9981970528227,0.3235115325278425,1.0782914146495763 +10.0,999.998118499697,0.2773400191469993,1.1245414811560759 +11.0,999.9980517266109,0.2381811116325973,1.1637671617564602 +12.0,999.9979948771722,0.20491347994957596,1.1970916428783052 +13.0,999.9979463950228,0.1765993101571986,1.2254542948197864 +14.0,999.9979049843444,0.15246035849486322,1.2496346571608217 +15.0,999.9978695617087,0.1318474515483701,1.2702829867428493 +16.0,999.997839210662,0.11421306508022987,1.2879477242577608 +17.0,999.9978131612163,0.09909920625730877,1.3030876325263074 +18.0,999.9977907706666,0.08612419687258105,1.3160850324608189 +19.0,999.9977714973852,0.07496680142437737,1.3272617011903274 +20.0,999.9977548807467,0.06535528051153047,1.336889838741927 +21.0,999.9977405318734,0.05706123449889946,1.3451982336276076 +22.0,999.9977281229965,0.049892080848540556,1.3523797961549961 +23.0,999.997717376844,0.04368496649518508,1.3585976566608506 +24.0,999.9977080576258,0.03830233234906549,1.3639896100250464 +25.0,999.9976999645526,0.03362749093344375,1.3686725445139125 +26.0,999.9976929253762,0.029560393117835475,1.3727466815059388 +27.0,999.9976867949074,0.02601664172236043,1.376296563370182 +28.0,999.9976814503476,0.022925127215691454,1.379393422436717 +29.0,999.997676785132,0.020224311706003355,1.3820989031621629 +30.0,999.9976727064331,0.017860692701611357,1.3844666008652917 +31.0,999.9976691354393,0.015789019054240936,1.3865418455063205 +32.0,999.9976660066118,0.013971608370947141,1.388362385017237 +33.0,999.997663263172,0.012375787007496572,1.3899609498204604 +34.0,999.9976608546723,0.010972576831300908,1.3913665684963843 +35.0,999.9976587371646,0.009736829124011859,1.3926044337114891 +36.0,999.9976568730758,0.008647095651106598,1.39369603127313 +37.0,999.9976552311186,0.007685446436379884,1.3946593224450141 +38.0,999.9976537837061,0.006836052747901465,1.3955101635459473 +39.0,999.9976525065451,0.006085021126375148,1.39626247232854 +40.0,999.9976513783824,0.005420209214685627,1.396928412402837 diff --git a/src/operations.jl b/src/operations.jl index 0cf648d..9b0750c 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -173,14 +173,6 @@ function get_callback(o::OperationRequest, ::Type{Simulate}) DiscreteCallback((args...) -> true, IntermediateResults(o.id,every = Dates.Second(0))) end -function get_callback(o::OperationRequest, ::Type{Ensemble{Simulate}}) - nothing -end - -function get_callback(o::OperationRequest, ::Type{Ensemble{Calibrate}}) - nothing -end - # callback for Simulate requests function solve(op::Simulate; callback) prob = ODEProblem(op.sys, [], op.timespan) @@ -305,6 +297,14 @@ function Ensemble{T}(o::OperationRequest) where {T} Ensemble{T}(model_ids, operations, weights, sol_mappings, df) end +function get_callback(o::OperationRequest, ::Type{Ensemble{Simulate}}) + nothing +end + +function get_callback(o::OperationRequest, ::Type{Ensemble{Calibrate}}) + nothing +end + # Solves multiple ODEs, performs a weighted sum # of the solutions. function solve(o::Ensemble{Simulate}; callback) diff --git a/test/runtests.jl b/test/runtests.jl index 64c5290..bdd0eb0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -36,8 +36,33 @@ calibrate_payloads = JSON3.write.([ end ]) -ensemble_simulate_payloads = JSON3.write.([]) + +simulate_ensemble_payloads = JSON3.write.([( + model_configs = map(1:4) do i + (id="model_config_id_$i", weight = i / sum(1:4), solution_mappings = (any_generic = "I", name = "R", s = "S")) + end, + local_model_files = JSON3.read.(read.([here("examples", "sir_calibrate", "sir1.json"), + here("examples", "sir_calibrate", "sir2.json"), + here("examples", "sir_calibrate", "sir3.json"), + here("examples", "sir_calibrate", "sir4.json")])), + timespan = (start = 0, var"end" = 40), + engine = "sciml", + extra = (; num_samples = 40))] + ) +calibrate_ensemble_payloads = JSON3.write.([( + model_configs = map(1:4) do i + (id="model_config_id_$i", weight = i / sum(1:4), solution_mappings = (any_generic = "I", name = "R", s = "S")) + end, + local_model_files = JSON3.read.(read.([here("examples", "sir_calibrate", "sir1.json"), + here("examples", "sir_calibrate", "sir2.json"), + here("examples", "sir_calibrate", "sir3.json"), + here("examples", "sir_calibrate", "sir4.json")])), + timespan = (start = 0, var"end" = 40), + engine = "sciml", + local_csv_file = here("examples", "sir_calibrate", "sir_ensemble_data.csv"), + extra = (; num_samples = 40))] + ) #-----------------------------------------------------------------------------# utils @testset "utils" begin @@ -160,7 +185,8 @@ end @test names(dfsim) == vcat("timestamp",string.(statenames)) @test names(dfparam) == string.(parameters(sys)) end - @testset "Ensembles" begin + + @testset "ensemble-simulate" begin amrfiles = [here("examples", "sir_calibrate", "sir1.json"), here("examples", "sir_calibrate", "sir2.json"), here("examples", "sir_calibrate", "sir3.json"), @@ -173,7 +199,7 @@ end (id="model_config_id_$i", weight = i / sum(1:4), solution_mappings = (any_generic = "I", name = "R", s = "S")) end, models = amrs, - ztimespan = (start = 0, var"end" = 40), + timespan = (start = 0, var"end" = 40), engine = "sciml", extra = (; num_samples = 40) ) @@ -193,16 +219,21 @@ end # bad test, need something better @test names(sim_en_sol) == ["timestamp","S","I","R"] + + end + + @testset "ensemble-calibrate" begin # create ensemble-calibrate o = OperationRequest() o.route = "ensemble-calibrate" o.obj = JSON3.read(JSON3.write(obj)) o.models = amrs o.timespan = (0,40) - o.df = sim_en_sol + o.df = df = CSV.read(here("examples", "sir_calibrate", "sir_ensemble_data.csv"), DataFrame) en_cal = Ensemble{Calibrate}(o) cal_sol = SimulationService.solve(en_cal,callback = nothing) @test cal_sol[!,:Weights] ≈ [0.1, 0.2, 0.3, 0.4] + end @testset "Real Calibrate Payload" begin @@ -274,7 +305,20 @@ end end @testset "/ensemble-simulate" begin - @test true # TODO + for body in simulate_ensemble_payloads + res = HTTP.post("$url/ensemble-simulate", ["Content-Type" => "application/json"]; body) + @test res.status == 201 + id = JSON3.read(res.body).simulation_id + test_until_done(id) + end + end + @testset "/ensemble-calibrate" begin + for body in calibrate_ensemble_payloads + res = HTTP.post("$url/ensemble-calibrate", ["Content-Type" => "application/json"]; body) + @test res.status == 201 + id = JSON3.read(res.body).simulation_id + test_until_done(id) + end end end -end +end \ No newline at end of file