# Bayesian Inference of ODE
For this tutorial, we will show how to do Bayesian inference to infer the parameters of the Lotka-Volterra equations using each of the three backends:
- Turing.jl
- Stan.jl
- DynamicHMC.jl
## Setup
First, let's set up our ODE and the data. For the data, we will simply solve the ODE and take that solution at some known parameters as the dataset. This looks like the following:
```julia
using DiffEqBayes, ParameterizedFunctions, OrdinaryDiffEq, RecursiveArrayTools,
      Distributions

f1 = @ode_def LotkaVolterra begin
    dx = a * x - x * y
    dy = -3 * y + x * y
end a

p = [1.5]
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
prob1 = ODEProblem(f1, u0, tspan, p)

σ = 0.01 # noise, fixed for now
t = collect(1.0:10.0) # observation times
sol = solve(prob1, Tsit5())
priors = [Normal(1.5, 1)]
randomized = VectorOfArray([(sol(t[i]) + σ * randn(2)) for i in 1:length(t)])
data = convert(Array, randomized)
```
```
2×10 Matrix{Float64}:
 2.76967  6.77947  0.973077  1.89297   …  4.35303   3.25201  1.02374
 0.24743  2.01714  1.90963   0.328245     0.330333  4.54354  0.914196
```
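Before running inference, it is worth checking the synthetic observations against the true trajectory. A minimal sketch, continuing from the setup above and assuming Plots.jl is available (it is not loaded in this tutorial):

```julia
using Plots

# Overlay the noisy observations on the true solution curves
plt = plot(sol; xlabel = "t", label = ["x(t)" "y(t)"])
scatter!(plt, t, data[1, :]; label = "x observations")
scatter!(plt, t, data[2, :]; label = "y observations")
```

With σ = 0.01 the points should sit almost exactly on the curves, which makes the sharp posteriors below unsurprising.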
## Inference Methods
### Stan
```julia
using CmdStan # required for using the Stan backend
bayesian_result_stan = stan_inference(prob1, t, data, priors)
```
```
Chains MCMC chain (1000×3×1 Array{Float64, 3}):

Iterations = 1:1:1000
Number of chains = 1
Samples per chain = 1000
parameters = sigma1.1, sigma1.2, theta_1
internals =

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

    sigma1.1    0.2721    0.0849    0.0035   638.8385   563.5389    1.0008     ⋯
    sigma1.2    0.2584    0.0819    0.0034   808.0022   409.9780    0.9997     ⋯
     theta_1    1.5011    0.0061    0.0002   794.1227   538.3494    0.9995     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

    sigma1.1    0.1545    0.2101    0.2567    0.3159    0.4728
    sigma1.2    0.1428    0.2040    0.2405    0.2948    0.4660
     theta_1    1.4887    1.4973    1.5008    1.5046    1.5122
```
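The Stan backend returns an MCMCChains.jl `Chains` object, so the usual chain accessors apply beyond the printed summary. A sketch, continuing from the call above (the `Array(chain[:param])` accessor is the standard MCMCChains idiom, but check your installed version):

```julia
using Statistics

# Pull the raw posterior draws for theta_1 out of the chain and
# compare their mean to the true parameter a = 1.5
theta_draws = vec(Array(bayesian_result_stan[:theta_1]))
mean(theta_draws)
```

From the summary table this mean is about 1.5011, in close agreement with the true value used to generate the data.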
### Turing
```julia
bayesian_result_turing = turing_inference(prob1, Tsit5(), t, data, priors)
```
```
Chains MCMC chain (1000×14×1 Array{Float64, 3}):

Iterations = 501:1:1500
Number of chains = 1
Samples per chain = 1000
Wall duration = 24.0 seconds
Compute duration = 24.0 seconds
parameters = theta[1], σ[1]
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

    theta[1]    1.5002    0.0033    0.0001   841.4280   653.7758    1.0081     ⋯
        σ[1]    0.1491    0.0337    0.0015   454.0794   407.4851    1.0007     ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

    theta[1]    1.4941    1.4981    1.5003    1.5023    1.5069
        σ[1]    0.0969    0.1246    0.1451    0.1685    0.2285
```
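Since `turing_inference` also returns an MCMCChains.jl `Chains` object, diagnostic plots come for free via the StatsPlots integration. A sketch, assuming StatsPlots.jl is installed (it is not loaded elsewhere in this tutorial):

```julia
using StatsPlots

# Draws trace and density plots for every sampled parameter
# (theta[1] and σ[1] in this model)
plot(bayesian_result_turing)
```

Well-mixed traces and unimodal densities here corroborate the near-1.0 `rhat` values in the summary above.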
### DynamicHMC
We can use DynamicHMC.jl as the backend for sampling with the `dynamichmc_inference` function. It is used similarly:
```julia
bayesian_result_hmc = dynamichmc_inference(prob1, Tsit5(), t, data, priors)
```
```
(posterior = NamedTuple{(:parameters, :σ), Tuple{Vector{Float64}, Vector{Float64}}}[(parameters = [1.5003136888662878], σ = [0.0038704619355187982, 0.00796539836360419]), (parameters = [1.5001639743139807], σ = [0.0038384769565244156, 0.007831833029925286]), (parameters = [1.5004604776030652], σ = [0.005710731758673134, 0.007941792825807141]), (parameters = [1.5001818235936895], σ = [0.005550850066733869, 0.007209878919234763]), (parameters = [1.5003568086954406], σ = [0.005499247850074649, 0.0073808319454661585]), (parameters = [1.5004556541605785], σ = [0.005705435048697979, 0.013691492185691916]), (parameters = [1.5001728151319507], σ = [0.007340491083391076, 0.006481661832441887]), (parameters = [1.500093697419652], σ = [0.0073811000048381314, 0.009738088681308984]), (parameters = [1.5001474470880567], σ = [0.004559576972520138, 0.008422608520333717]), (parameters = [1.5000709739687534], σ = [0.004513932751365236, 0.008862986949708654]) … (parameters = [1.5005555406074598], σ = [0.004501674689405071, 0.01367774969031757]), (parameters = [1.5001136617861988], σ = [0.004715993631704208, 0.010847018292803188]), (parameters = [1.5000747949115345], σ = [0.005904325779407816, 0.00780141041243877]), (parameters = [1.5002305959326152], σ = [0.005903084753781323, 0.007801190339375383]), (parameters = [1.5003451195077495], σ = [0.0055873380119094785, 0.007980661413593463]), (parameters = [1.5004152921523313], σ = [0.0050666870741009565, 0.009947401693528203]), (parameters = [1.5001206019252376], σ = [0.005778551011559513, 0.010642801231135517]), (parameters = [1.500227236311067], σ = [0.007405663125060141, 0.016154808615718046]), (parameters = [1.5002421707897142], σ = [0.007060915665011238, 0.010493588640767729]), (parameters = [1.5003452259354877], σ = [0.005823640990164413, 0.012466554255325864])], posterior_matrix = [0.4056742121552477 0.4055744183429035 … 0.40562654227011163 0.405695232251231; -5.5543814158805125 -5.56267961696446 … -4.953180538293332 -5.14582961316009; -4.832648322607873 -4.849558692938626 … -4.556990813955373 -4.384705880222884], tree_statistics = DynamicHMC.TreeStatisticsNUTS[DynamicHMC.TreeStatisticsNUTS(56.09972449227871, 10, turning at positions -448:575, 0.9768083033619173, 1023, DynamicHMC.Directions(0x1e03123f)), DynamicHMC.TreeStatisticsNUTS(55.99880450344021, 6, turning at positions -80:-83, 0.8537851024545738, 103, DynamicHMC.Directions(0xebbbcc94)), DynamicHMC.TreeStatisticsNUTS(54.072804499858165, 10, turning at positions -355:668, 0.9381893645576496, 1023, DynamicHMC.Directions(0x1337fa9c)), DynamicHMC.TreeStatisticsNUTS(53.87222433180597, 8, turning at positions -186:-189, 0.9787327708494868, 271, DynamicHMC.Directions(0x7f789252)), DynamicHMC.TreeStatisticsNUTS(51.523036490487954, 6, turning at positions 26:29, 0.7456545764186002, 79, DynamicHMC.Directions(0x912413cd)), DynamicHMC.TreeStatisticsNUTS(53.55941083649569, 10, turning at positions -723:300, 0.8761578413274247, 1023, DynamicHMC.Directions(0xac1c492c)), DynamicHMC.TreeStatisticsNUTS(51.72994949917636, 10, reached maximum depth without divergence or turning, 0.9880073946239357, 1023, DynamicHMC.Directions(0x4e5e1a25)), DynamicHMC.TreeStatisticsNUTS(52.03789048013739, 10, turning at positions -846:177, 0.9683416283366267, 1023, DynamicHMC.Directions(0x0e35ccb1)), DynamicHMC.TreeStatisticsNUTS(54.54194896222248, 10, turning at positions -299:724, 0.9935304401238156, 1023, DynamicHMC.Directions(0x08c436d4)), DynamicHMC.TreeStatisticsNUTS(54.40971507648113, 8, turning at positions 148:151, 0.8986644808502857, 307, DynamicHMC.Directions(0xc174e363)) … DynamicHMC.TreeStatisticsNUTS(51.35828589947691, 9, turning at positions 600:603, 0.9168687131294502, 727, DynamicHMC.Directions(0xf8685383)), DynamicHMC.TreeStatisticsNUTS(53.363804773110715, 8, turning at positions -171:-174, 0.9954852052405282, 399, DynamicHMC.Directions(0xbde4e8e1)), DynamicHMC.TreeStatisticsNUTS(53.05530094277249, 9, turning at positions -459:-466, 0.9931981649442928, 703, DynamicHMC.Directions(0x721c5ced)), DynamicHMC.TreeStatisticsNUTS(54.744301334441154, 2, turning at positions -2:1, 0.9817076554729621, 3, DynamicHMC.Directions(0xf8367f65)), DynamicHMC.TreeStatisticsNUTS(55.86751318826507, 8, turning at positions -142:113, 0.9881576918935685, 255, DynamicHMC.Directions(0x1d4a2a71)), DynamicHMC.TreeStatisticsNUTS(52.9994392007261, 9, turning at positions -727:-730, 0.8449343586262494, 891, DynamicHMC.Directions(0x3cf6a4a1)), DynamicHMC.TreeStatisticsNUTS(55.7214772489555, 9, turning at positions 573:576, 0.95480860747906, 859, DynamicHMC.Directions(0x91a07ee4)), DynamicHMC.TreeStatisticsNUTS(51.85185448003083, 9, turning at positions 508:515, 0.9133301927905927, 615, DynamicHMC.Directions(0x25004b9b)), DynamicHMC.TreeStatisticsNUTS(53.38824008089312, 10, turning at positions -765:258, 0.9990487640420308, 1023, DynamicHMC.Directions(0xcde4e102)), DynamicHMC.TreeStatisticsNUTS(54.828172198141125, 10, turning at positions -774:249, 0.9963595816891021, 1023, DynamicHMC.Directions(0x9c3e3cf9))], κ = Gaussian kinetic energy (Diagonal), √diag(M⁻¹): [0.023859290281126454, 0.23043070693521112, 0.2134030394617685], ϵ = 0.0029131839917604947)
```
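Unlike the other two backends, `dynamichmc_inference` returns a plain `NamedTuple` rather than a `Chains` object, so summarizing the posterior requires a small manual reduction over the `posterior` field shown above. Continuing from the call above:

```julia
using Statistics

# Each entry of `posterior` is a (parameters, σ) NamedTuple;
# collect the draws for the single parameter a and summarize them
theta_samples = [s.parameters[1] for s in bayesian_result_hmc.posterior]
mean(theta_samples), std(theta_samples)
```

From the samples printed above, the mean is very close to the true value a = 1.5.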
## More Information
For a better idea of the summary statistics and plotting, you can take a look at the benchmarks.