diff --git a/DESCRIPTION b/DESCRIPTION index 2e1613b7..68bb1bfc 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -17,6 +17,8 @@ Language: en-GB Config/testthat/edition: 3 URL: https://github.com/mrc-ide/dust2 BugReports: https://github.com/mrc-ide/dust2/issues +Imports: + mcstate2 LinkingTo: cpp11, mcstate2 diff --git a/inst/include/dust2/cpu.hpp b/inst/include/dust2/cpu.hpp index 81f5a399..1f50d3f2 100644 --- a/inst/include/dust2/cpu.hpp +++ b/inst/include/dust2/cpu.hpp @@ -39,6 +39,7 @@ class dust_cpu { // Ignore errors for now. real_type * state_data = state_.data(); real_type * state_next_data = state_next_.data(); + // Later we parallelise this and track errors carefully. for (size_t i = 0; i < n_particles_; ++i) { const auto offset = i * n_state_; run_particle(time_, dt_, n_steps, @@ -50,6 +51,12 @@ class dust_cpu { if (n_steps % 2 == 1) { std::swap(state_, state_next_); } + // Time management here is going to require some effort once we + // support interesting dt so that we always land on times with no + // non-integer bits, but for now we require that dt is 1 so this + // is easy. We need this to hold within run_particle too, so it's + // possible that's where the calculation here will be done. + time_ = time_ + n_steps * dt_; } void set_state_initial() { diff --git a/inst/include/dust2/r/cpu.hpp b/inst/include/dust2/r/cpu.hpp index 857ff8c7..d5ba7a69 100644 --- a/inst/include/dust2/r/cpu.hpp +++ b/inst/include/dust2/r/cpu.hpp @@ -33,7 +33,12 @@ SEXP dust2_cpu_alloc(cpp11::list r_pars, seed, deterministic); cpp11::external_pointer> ptr(obj, true, false); - return cpp11::writable::list({ptr}); + // Later, we'll export a bit more back from the model (in particular + // models need to provide information about how they organise + // variables, ode models report computed control, etc. + auto size = T::size(shared); + + return cpp11::writable::list{ptr, cpp11::as_sexp(size)}; } template diff --git a/tests/testthat/test-walk.R b/tests/testthat/test-walk.R index edb5a3c0..51042fc8 100644 --- a/tests/testthat/test-walk.R +++ b/tests/testthat/test-walk.R @@ -1,9 +1,20 @@ -test_that("...", { +test_that("can run simple walk model", { pars <- list(sd = 1, random_initial = TRUE) obj <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 42, FALSE) + expect_type(obj[[1]], "externalptr") + expect_equal(obj[[2]], 1) + ptr <- obj[[1]] - dust2_cpu_walk_rng_state(ptr) - dust2_cpu_walk_state(ptr) - dust2_cpu_walk_run_steps(ptr, 3) - dust2_cpu_walk_state(ptr) + expect_type(dust2_cpu_walk_rng_state(ptr), "raw") + expect_length(dust2_cpu_walk_rng_state(ptr), 32 * 10) + + expect_equal(dust2_cpu_walk_state(ptr), rep(0, 10)) + expect_equal(dust2_cpu_walk_time(ptr), 0) + + expect_null(dust2_cpu_walk_run_steps(ptr, 3)) + s <- dust2_cpu_walk_state(ptr) + + r <- mcstate2::mcstate_rng$new(seed = 42, n_streams = 10) + expect_equal(s, colSums(r$normal(3, 0, 1))) + expect_equal(dust2_cpu_walk_time(ptr), 3) })