Skip to content

Commit

Permalink
Merge pull request #5 from mrc-ide/mrc-5339
Browse files Browse the repository at this point in the history
Allow updating parameters
  • Loading branch information
richfitz authored May 10, 2024
2 parents 8db8a6a + 776c640 commit 3e5a109
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 7 deletions.
4 changes: 4 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@ dust2_cpu_walk_rng_state <- function(ptr) {
dust2_cpu_walk_set_time <- function(ptr, r_time) {
.Call(`_dust2_dust2_cpu_walk_set_time`, ptr, r_time)
}

dust2_cpu_walk_update_pars <- function(ptr, pars, grouped) {
.Call(`_dust2_dust2_cpu_walk_update_pars`, ptr, pars, grouped)
}
11 changes: 11 additions & 0 deletions inst/include/dust2/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,17 @@ class dust_cpu {
rng_.import_state();
}

template<typename Fn>
void update_shared(size_t i, Fn fn) {
// TODO: check that size was not modified, error if so (quite a
// bit later).
fn(shared_[i]);
}

auto n_groups() const {
return n_groups_;
}

private:
size_t n_state_;
size_t n_particles_;
Expand Down
35 changes: 31 additions & 4 deletions inst/include/dust2/r/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ SEXP dust2_cpu_alloc(cpp11::list r_pars,

size_t size = 0;
cpp11::sexp group_names = R_NilValue;
if (n_groups == 0) {
const auto grouped = n_groups > 0;
if (!grouped) {
shared.push_back(T::build_shared(r_pars));
internal.push_back(T::build_internal(r_pars));
size = T::size(shared[0]);
Expand Down Expand Up @@ -62,10 +63,13 @@ SEXP dust2_cpu_alloc(cpp11::list r_pars,
cpp11::external_pointer<dust_cpu<T>> ptr(obj, true, false);

// Later, we'll export a bit more back from the model (in particular
// models need to provide information about how they organise
// variables, ode models report computed control, etc.
// models need to provide information about ~how they organise
// variables~, ode models report computed control, etc.

return cpp11::writable::list{ptr, cpp11::as_sexp(size), group_names};
cpp11::sexp r_size = cpp11::as_sexp(size);
cpp11::sexp r_grouped = cpp11::as_sexp(grouped);

return cpp11::writable::list{ptr, r_size, r_grouped, group_names};
}

template <typename T>
Expand Down Expand Up @@ -123,5 +127,28 @@ SEXP dust2_cpu_set_time(cpp11::sexp ptr, cpp11::sexp r_time) {
return R_NilValue;
}

template <typename T>
SEXP dust2_cpu_update_pars(cpp11::sexp ptr, cpp11::list r_pars,
bool grouped) {
auto *obj = cpp11::as_cpp<cpp11::external_pointer<dust_cpu<T>>>(ptr).get();
if (grouped) {
const auto n_groups = obj->n_groups();
if (r_pars.size() != static_cast<int>(n_groups)) {
cpp11::stop("Expected 'pars' to have length %d to match 'n_groups'",
static_cast<int>(n_groups));
}
for (size_t i = 0; i < n_groups; ++i) {
obj->update_shared(i, [&] (auto& shared) {
T::update_shared(r_pars[i], shared);
});
}
} else {
obj->update_shared(0, [&] (auto& shared) {
T::update_shared(r_pars, shared);
});
}
return R_NilValue;
}

}
}
8 changes: 8 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ extern "C" SEXP _dust2_dust2_cpu_walk_set_time(SEXP ptr, SEXP r_time) {
return cpp11::as_sexp(dust2_cpu_walk_set_time(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_time)));
END_CPP11
}
// walk.cpp
SEXP dust2_cpu_walk_update_pars(cpp11::sexp ptr, cpp11::list pars, bool grouped);
extern "C" SEXP _dust2_dust2_cpu_walk_update_pars(SEXP ptr, SEXP pars, SEXP grouped) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_walk_update_pars(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::list>>(pars), cpp11::as_cpp<cpp11::decay_t<bool>>(grouped)));
END_CPP11
}

extern "C" {
static const R_CallMethodDef CallEntries[] = {
Expand All @@ -72,6 +79,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_dust2_dust2_cpu_walk_set_time", (DL_FUNC) &_dust2_dust2_cpu_walk_set_time, 2},
{"_dust2_dust2_cpu_walk_state", (DL_FUNC) &_dust2_dust2_cpu_walk_state, 1},
{"_dust2_dust2_cpu_walk_time", (DL_FUNC) &_dust2_dust2_cpu_walk_time, 1},
{"_dust2_dust2_cpu_walk_update_pars", (DL_FUNC) &_dust2_dust2_cpu_walk_update_pars, 3},
{NULL, NULL, 0}
};
}
Expand Down
6 changes: 6 additions & 0 deletions src/walk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,9 @@ SEXP dust2_cpu_walk_rng_state(cpp11::sexp ptr) {
SEXP dust2_cpu_walk_set_time(cpp11::sexp ptr, cpp11::sexp r_time) {
return dust2::r::dust2_cpu_set_time<walk>(ptr, r_time);
}

[[cpp11::register]]
SEXP dust2_cpu_walk_update_pars(cpp11::sexp ptr, cpp11::list pars,
bool grouped) {
return dust2::r::dust2_cpu_update_pars<walk>(ptr, pars, grouped);
}
57 changes: 54 additions & 3 deletions tests/testthat/test-walk.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
test_that("can run simple walk model", {
pars <- list(sd = 1, random_initial = TRUE)
obj <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 0, 42, FALSE)
expect_length(obj, 3)
expect_length(obj, 4)
expect_type(obj[[1]], "externalptr")
expect_equal(obj[[2]], 1)
expect_null(obj[[3]])
expect_false(obj[[3]])
expect_null(obj[[4]])

ptr <- obj[[1]]
expect_type(dust2_cpu_walk_rng_state(ptr), "raw")
Expand Down Expand Up @@ -171,7 +172,8 @@ test_that("return names passed in with groups", {
pars <- lapply(1:4, function(sd) list(sd = sd, random_initial = TRUE))
names(pars) <- letters[1:4]
obj <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 4, 42, FALSE)
expect_equal(obj[[3]], letters[1:4])
expect_true(obj[[3]])
expect_equal(obj[[4]], letters[1:4])
})


Expand Down Expand Up @@ -225,3 +227,52 @@ test_that("can set time", {
expect_error(dust2_cpu_walk_set_time(ptr, 0.5),
"Expected 'time' to be integer-like")
})


test_that("can update parameters", {
pars1 <- list(sd = 1, random_initial = TRUE)
pars2 <- list(sd = 10)
obj <- dust2_cpu_walk_alloc(pars1, 0, 1, 10, 0, 42, FALSE)
ptr <- obj[[1]]

expect_null(dust2_cpu_walk_run_steps(ptr, 1))
s1 <- dust2_cpu_walk_state(ptr)

expect_null(dust2_cpu_walk_update_pars(ptr, pars2, FALSE))
expect_null(dust2_cpu_walk_run_steps(ptr, 1))
s2 <- dust2_cpu_walk_state(ptr)

r <- mcstate2::mcstate_rng$new(seed = 42, n_streams = 10)
expect_equal(s1, drop(r$normal(1, 0, 1)))
expect_equal(s2, s1 + drop(r$normal(1, 0, 10)))
})


test_that("can update parameters for grouped models", {
pars1 <- lapply(1:4, function(sd) list(sd = sd, random_initial = TRUE))
pars2 <- lapply(1:4, function(sd) list(sd = 10 * sd))

obj <- dust2_cpu_walk_alloc(pars1, 0, 1, 10, 4, 42, FALSE)
ptr <- obj[[1]]

expect_null(dust2_cpu_walk_run_steps(ptr, 1))
s1 <- dust2_cpu_walk_state(ptr)

expect_null(dust2_cpu_walk_update_pars(ptr, pars2, TRUE))
expect_null(dust2_cpu_walk_run_steps(ptr, 1))
s2 <- dust2_cpu_walk_state(ptr)

r <- mcstate2::mcstate_rng$new(seed = 42, n_streams = 40)
expect_equal(s1, drop(r$normal(1, 0, 1)) * rep(1:4, each = 10))
expect_equal(s2, s1 + drop(r$normal(1, 0, 10)) * rep(1:4, each = 10))
})


test_that("params must be same length to update", {
pars1 <- lapply(1:4, function(sd) list(sd = sd, random_initial = TRUE))
pars2 <- lapply(1:5, function(sd) list(sd = 10 * sd))
obj <- dust2_cpu_walk_alloc(pars1, 0, 1, 10, 4, 42, FALSE)
ptr <- obj[[1]]
expect_error(dust2_cpu_walk_update_pars(ptr, pars2, TRUE),
"Expected 'pars' to have length 4 to match 'n_groups'");
})

0 comments on commit 3e5a109

Please sign in to comment.