From 9f19048d7e6f5930c3a4f862ea5862447d60b07c Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 8 May 2024 16:13:39 +0100 Subject: [PATCH 1/2] Allow updating parameters --- R/cpp11.R | 4 +++ inst/include/dust2/cpu.hpp | 11 +++++++ inst/include/dust2/r/cpu.hpp | 35 +++++++++++++++++++--- src/cpp11.cpp | 8 +++++ src/walk.cpp | 6 ++++ tests/testthat/test-walk.R | 57 ++++++++++++++++++++++++++++++++++-- 6 files changed, 114 insertions(+), 7 deletions(-) diff --git a/R/cpp11.R b/R/cpp11.R index a1fd1c4d..12244c3e 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -31,3 +31,7 @@ dust2_cpu_walk_rng_state <- function(ptr) { dust2_cpu_walk_set_time <- function(ptr, r_time) { .Call(`_dust2_dust2_cpu_walk_set_time`, ptr, r_time) } + +dust2_cpu_walk_update_pars <- function(ptr, pars, grouped) { + .Call(`_dust2_dust2_cpu_walk_update_pars`, ptr, pars, grouped) +} diff --git a/inst/include/dust2/cpu.hpp b/inst/include/dust2/cpu.hpp index e8f2ec3f..fde127ed 100644 --- a/inst/include/dust2/cpu.hpp +++ b/inst/include/dust2/cpu.hpp @@ -113,6 +113,17 @@ class dust_cpu { rng_.import_state(); } + template + void update_shared(size_t i, Fn fn) { + // TODO: check that size was not modified, error if so (quite a + // bit later). + fn(shared_[i]); + } + + auto n_groups() const { + return n_groups_; + } + private: size_t n_state_; size_t n_particles_; diff --git a/inst/include/dust2/r/cpu.hpp b/inst/include/dust2/r/cpu.hpp index a468d2fb..249acde5 100644 --- a/inst/include/dust2/r/cpu.hpp +++ b/inst/include/dust2/r/cpu.hpp @@ -31,7 +31,8 @@ SEXP dust2_cpu_alloc(cpp11::list r_pars, size_t size = 0; cpp11::sexp group_names = R_NilValue; - if (n_groups == 0) { + const auto grouped = n_groups > 0; + if (!grouped) { shared.push_back(T::build_shared(r_pars)); internal.push_back(T::build_internal(r_pars)); size = T::size(shared[0]); @@ -62,10 +63,13 @@ SEXP dust2_cpu_alloc(cpp11::list r_pars, cpp11::external_pointer> ptr(obj, true, false); // Later, we'll export a bit more back from the model (in particular - // models need to provide information about how they organise - // variables, ode models report computed control, etc. + // models need to provide information about ~how they organise + // variables~, ode models report computed control, etc. - return cpp11::writable::list{ptr, cpp11::as_sexp(size), group_names}; + cpp11::sexp r_size = cpp11::as_sexp(size); + cpp11::sexp r_grouped = cpp11::as_sexp(grouped); + + return cpp11::writable::list{ptr, r_size, r_grouped, group_names}; } template @@ -123,5 +127,28 @@ SEXP dust2_cpu_set_time(cpp11::sexp ptr, cpp11::sexp r_time) { return R_NilValue; } +template +SEXP dust2_cpu_update_pars(cpp11::sexp ptr, cpp11::list r_pars, + bool grouped) { + auto *obj = cpp11::as_cpp>>(ptr).get(); + if (grouped) { + const auto n_groups = obj->n_groups(); + if (r_pars.size() != static_cast(n_groups)) { + cpp11::stop("Expected 'pars' to have length %d to match 'n_groups'", + static_cast(n_groups)); + } + for (size_t i = 0; i < n_groups; ++i) { + obj->update_shared(i, [&] (auto& shared) { + T::update_shared(r_pars[i], shared); + }); + } + } else { + obj->update_shared(0, [&] (auto& shared) { + T::update_shared(r_pars, shared); + }); + } + return R_NilValue; +} + } } diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 5948e5e7..38ad5623 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -61,6 +61,13 @@ extern "C" SEXP _dust2_dust2_cpu_walk_set_time(SEXP ptr, SEXP r_time) { return cpp11::as_sexp(dust2_cpu_walk_set_time(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_time))); END_CPP11 } +// walk.cpp +SEXP dust2_cpu_walk_update_pars(cpp11::sexp ptr, cpp11::list pars, bool grouped); +extern "C" SEXP _dust2_dust2_cpu_walk_update_pars(SEXP ptr, SEXP pars, SEXP grouped) { + BEGIN_CPP11 + return cpp11::as_sexp(dust2_cpu_walk_update_pars(cpp11::as_cpp>(ptr), cpp11::as_cpp>(pars), cpp11::as_cpp>(grouped))); + END_CPP11 +} extern "C" { static const R_CallMethodDef CallEntries[] = { @@ -72,6 +79,7 @@ static const R_CallMethodDef CallEntries[] = { {"_dust2_dust2_cpu_walk_set_time", (DL_FUNC) &_dust2_dust2_cpu_walk_set_time, 2}, {"_dust2_dust2_cpu_walk_state", (DL_FUNC) &_dust2_dust2_cpu_walk_state, 1}, {"_dust2_dust2_cpu_walk_time", (DL_FUNC) &_dust2_dust2_cpu_walk_time, 1}, + {"_dust2_dust2_cpu_walk_update_pars", (DL_FUNC) &_dust2_dust2_cpu_walk_update_pars, 3}, {NULL, NULL, 0} }; } diff --git a/src/walk.cpp b/src/walk.cpp index 214c22d9..33621f0e 100644 --- a/src/walk.cpp +++ b/src/walk.cpp @@ -54,3 +54,9 @@ SEXP dust2_cpu_walk_rng_state(cpp11::sexp ptr) { SEXP dust2_cpu_walk_set_time(cpp11::sexp ptr, cpp11::sexp r_time) { return dust2::r::dust2_cpu_set_time(ptr, r_time); } + +[[cpp11::register]] +SEXP dust2_cpu_walk_update_pars(cpp11::sexp ptr, cpp11::list pars, + bool grouped) { + return dust2::r::dust2_cpu_update_pars(ptr, pars, grouped); +} diff --git a/tests/testthat/test-walk.R b/tests/testthat/test-walk.R index bf4735d4..e93b1af6 100644 --- a/tests/testthat/test-walk.R +++ b/tests/testthat/test-walk.R @@ -1,10 +1,11 @@ test_that("can run simple walk model", { pars <- list(sd = 1, random_initial = TRUE) obj <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 0, 42, FALSE) - expect_length(obj, 3) + expect_length(obj, 4) expect_type(obj[[1]], "externalptr") expect_equal(obj[[2]], 1) - expect_null(obj[[3]]) + expect_false(obj[[3]]) + expect_null(obj[[4]]) ptr <- obj[[1]] expect_type(dust2_cpu_walk_rng_state(ptr), "raw") @@ -171,7 +172,8 @@ test_that("return names passed in with groups", { pars <- lapply(1:4, function(sd) list(sd = sd, random_initial = TRUE)) names(pars) <- letters[1:4] obj <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 4, 42, FALSE) - expect_equal(obj[[3]], letters[1:4]) + expect_true(obj[[3]]) + expect_equal(obj[[4]], letters[1:4]) }) @@ -225,3 +227,52 @@ test_that("can set time", { expect_error(dust2_cpu_walk_set_time(ptr, 0.5), "Expected 'time' to be integer-like") }) + + +test_that("can update parameters", { + pars1 <- list(sd = 1, random_initial = TRUE) + pars2 <- list(sd = 10) + obj <- dust2_cpu_walk_alloc(pars1, 0, 1, 10, 0, 42, FALSE) + ptr <- obj[[1]] + + expect_null(dust2_cpu_walk_run_steps(ptr, 1)) + s1 <- dust2_cpu_walk_state(ptr) + + expect_null(dust2_cpu_walk_update_pars(ptr, pars2, FALSE)) + expect_null(dust2_cpu_walk_run_steps(ptr, 1)) + s2 <- dust2_cpu_walk_state(ptr) + + r <- mcstate2::mcstate_rng$new(seed = 42, n_streams = 10) + expect_equal(s1, drop(r$normal(1, 0, 1))) + expect_equal(s2, s1 + drop(r$normal(1, 0, 10))) +}) + + +test_that("can update parameters for grouped models", { + pars1 <- lapply(1:4, function(sd) list(sd = sd, random_initial = TRUE)) + pars2 <- lapply(1:4, function(sd) list(sd = 10 * sd)) + + obj <- dust2_cpu_walk_alloc(pars1, 0, 1, 10, 4, 42, FALSE) + ptr <- obj[[1]] + + expect_null(dust2_cpu_walk_run_steps(ptr, 1)) + s1 <- dust2_cpu_walk_state(ptr) + + expect_null(dust2_cpu_walk_update_pars(ptr, pars2, TRUE)) + expect_null(dust2_cpu_walk_run_steps(ptr, 1)) + s2 <- dust2_cpu_walk_state(ptr) + + r <- mcstate2::mcstate_rng$new(seed = 42, n_streams = 40) + expect_equal(s1, drop(r$normal(1, 0, 1)) * rep(1:4, each = 10)) + expect_equal(s2, s1 + drop(r$normal(1, 0, 10)) * rep(1:4, each = 10)) +}) + + +test_that("can update parameters for grouped models", { + pars1 <- lapply(1:4, function(sd) list(sd = sd, random_initial = TRUE)) + pars2 <- lapply(1:5, function(sd) list(sd = 10 * sd)) + obj <- dust2_cpu_walk_alloc(pars1, 0, 1, 10, 4, 42, FALSE) + ptr <- obj[[1]] + expect_error(dust2_cpu_walk_update_pars(ptr, pars2, TRUE), + "Expected 'pars' to have length 4 to match 'n_groups'"); +}) From 776c640cc646a0c7be75fb6bfd213505de222d24 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Fri, 10 May 2024 16:36:34 +0100 Subject: [PATCH 2/2] Update tests/testthat/test-walk.R Co-authored-by: Wes Hinsley --- tests/testthat/test-walk.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test-walk.R b/tests/testthat/test-walk.R index e93b1af6..ff3a3456 100644 --- a/tests/testthat/test-walk.R +++ b/tests/testthat/test-walk.R @@ -268,7 +268,7 @@ test_that("can update parameters for grouped models", { }) -test_that("can update parameters for grouped models", { +test_that("params must be same length to update", { pars1 <- lapply(1:4, function(sd) list(sd = sd, random_initial = TRUE)) pars2 <- lapply(1:5, function(sd) list(sd = 10 * sd)) obj <- dust2_cpu_walk_alloc(pars1, 0, 1, 10, 4, 42, FALSE)