Skip to content

Commit

Permalink
Merge pull request #6 from mrc-ide/mrc-5333
Browse files Browse the repository at this point in the history
Export state with dimension attribute
  • Loading branch information
weshinsley authored May 13, 2024
2 parents 3e5a109 + 2b8f32c commit abe450f
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 56 deletions.
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@
^\.covrignore$
^\.github$
\.*gcov$
^.*\.Rproj$
^\.Rproj\.user$
4 changes: 2 additions & 2 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ dust2_cpu_walk_run_steps <- function(ptr, r_n_steps) {
.Call(`_dust2_dust2_cpu_walk_run_steps`, ptr, r_n_steps)
}

dust2_cpu_walk_state <- function(ptr) {
.Call(`_dust2_dust2_cpu_walk_state`, ptr)
dust2_cpu_walk_state <- function(ptr, grouped) {
.Call(`_dust2_dust2_cpu_walk_state`, ptr, grouped)
}

dust2_cpu_walk_time <- function(ptr) {
Expand Down
18 changes: 13 additions & 5 deletions inst/include/dust2/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class dust_cpu {
std::copy_n(it, state_.size(), state_.begin());
}

auto state() const {
auto& state() const {
return state_;
}

Expand All @@ -101,6 +101,18 @@ class dust_cpu {
return time_;
}

auto n_state() const {
return n_state_;
}

auto n_particles() const {
return n_particles_;
}

auto n_groups() const {
return n_groups_;
}

void set_time(real_type time) {
time_ = time;
}
Expand All @@ -120,10 +132,6 @@ class dust_cpu {
fn(shared_[i]);
}

auto n_groups() const {
return n_groups_;
}

private:
size_t n_state_;
size_t n_particles_;
Expand Down
26 changes: 23 additions & 3 deletions inst/include/dust2/r/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,18 @@ SEXP dust2_cpu_run_steps(cpp11::sexp ptr, cpp11::sexp r_n_steps) {
}

template <typename T>
SEXP dust2_cpu_state(cpp11::sexp ptr) {
SEXP dust2_cpu_state(cpp11::sexp ptr, bool grouped) {
auto *obj = cpp11::as_cpp<cpp11::external_pointer<dust_cpu<T>>>(ptr).get();
// return to_matrix(obj->state(), obj->n_state, obj->n_particles);
return cpp11::as_sexp(obj->state());
cpp11::sexp ret = R_NilValue;
const auto it = obj->state().begin();
if (grouped) {
ret = export_array_n(it,
{obj->n_state(), obj->n_particles(), obj->n_groups()});
} else {
ret = export_array_n(it,
{obj->n_state(), obj->n_particles() * obj->n_groups()});
}
return ret;
}

template <typename T>
Expand All @@ -93,6 +101,18 @@ SEXP dust2_cpu_time(cpp11::sexp ptr) {
return cpp11::as_sexp(obj->time());
}

// If this is a grouped model then we will return a matrix with
// dimensions )(state x particle x group) and if we are ungrouped
// (state x particle); the difference is really only apparent in the
// case where we have a single group to drop. In dust1 we had the
// option for (state x group) for simulations too.
//
// For now perhaps lets just ignore this detail and always return the
// 3d version as the code to do grouped/ungrouped switching is deep
// within one of the many in-progress branches and I can't find it
// yet.
//
// In the case where we have time, that goes in the last position
template <typename T>
SEXP dust2_cpu_set_state_initial(cpp11::sexp ptr) {
auto *obj = cpp11::as_cpp<cpp11::external_pointer<dust_cpu<T>>>(ptr).get();
Expand Down
25 changes: 17 additions & 8 deletions inst/include/dust2/r/helpers.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <numeric>

namespace dust2 {
namespace r {

Expand Down Expand Up @@ -83,14 +85,21 @@ inline double check_dt(cpp11::sexp r_dt) {
return dt;
}

// template <typename real_type>
// inline cpp11::sexp to_matrix(std::vector<real_type> x, size_t nr, size_t nc) {
// cpp11::writable::integers dim{static_cast<int>(nr), static_cast<int>(nc)};
// cpp11::writable::doubles ret(x.size());
// std::copy(x.begin(), x.end(), REAL(ret));
// ret.attr("dim") = dim;
// return ret;
// }
// The initializer_list is a type-safe variadic-like approach.
template <typename It>
cpp11::sexp export_array_n(It it, std::initializer_list<size_t> dims) {
const auto len =
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>{});
cpp11::writable::integers r_dim(dims.size());
auto dim_i = dims.begin();
for (size_t i = 0; i < dims.size(); ++i, ++dim_i) {
r_dim[i] = *dim_i;
}
cpp11::writable::doubles ret(len);
std::copy_n(it, len, ret.begin());
ret.attr("dim") = r_dim;
return ret;
}

}
}
8 changes: 4 additions & 4 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ extern "C" SEXP _dust2_dust2_cpu_walk_run_steps(SEXP ptr, SEXP r_n_steps) {
END_CPP11
}
// walk.cpp
SEXP dust2_cpu_walk_state(cpp11::sexp ptr);
extern "C" SEXP _dust2_dust2_cpu_walk_state(SEXP ptr) {
SEXP dust2_cpu_walk_state(cpp11::sexp ptr, bool grouped);
extern "C" SEXP _dust2_dust2_cpu_walk_state(SEXP ptr, SEXP grouped) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_walk_state(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr)));
return cpp11::as_sexp(dust2_cpu_walk_state(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr), cpp11::as_cpp<cpp11::decay_t<bool>>(grouped)));
END_CPP11
}
// walk.cpp
Expand Down Expand Up @@ -77,7 +77,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_dust2_dust2_cpu_walk_set_state", (DL_FUNC) &_dust2_dust2_cpu_walk_set_state, 2},
{"_dust2_dust2_cpu_walk_set_state_initial", (DL_FUNC) &_dust2_dust2_cpu_walk_set_state_initial, 1},
{"_dust2_dust2_cpu_walk_set_time", (DL_FUNC) &_dust2_dust2_cpu_walk_set_time, 2},
{"_dust2_dust2_cpu_walk_state", (DL_FUNC) &_dust2_dust2_cpu_walk_state, 1},
{"_dust2_dust2_cpu_walk_state", (DL_FUNC) &_dust2_dust2_cpu_walk_state, 2},
{"_dust2_dust2_cpu_walk_time", (DL_FUNC) &_dust2_dust2_cpu_walk_time, 1},
{"_dust2_dust2_cpu_walk_update_pars", (DL_FUNC) &_dust2_dust2_cpu_walk_update_pars, 3},
{NULL, NULL, 0}
Expand Down
4 changes: 2 additions & 2 deletions src/walk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ SEXP dust2_cpu_walk_run_steps(cpp11::sexp ptr, cpp11::sexp r_n_steps) {
}

[[cpp11::register]]
SEXP dust2_cpu_walk_state(cpp11::sexp ptr) {
return dust2::r::dust2_cpu_state<walk>(ptr);
SEXP dust2_cpu_walk_state(cpp11::sexp ptr, bool grouped) {
return dust2::r::dust2_cpu_state<walk>(ptr, grouped);
}

[[cpp11::register]]
Expand Down
68 changes: 36 additions & 32 deletions tests/testthat/test-walk.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ test_that("can run simple walk model", {
expect_type(dust2_cpu_walk_rng_state(ptr), "raw")
expect_length(dust2_cpu_walk_rng_state(ptr), 32 * 10)

expect_equal(dust2_cpu_walk_state(ptr), rep(0, 10))
expect_equal(dust2_cpu_walk_state(ptr, FALSE), matrix(0, 1, 10))
expect_equal(dust2_cpu_walk_time(ptr), 0)

expect_null(dust2_cpu_walk_run_steps(ptr, 3))
s <- dust2_cpu_walk_state(ptr)
s <- dust2_cpu_walk_state(ptr, FALSE)

r <- mcstate2::mcstate_rng$new(seed = 42, n_streams = 10)
expect_equal(s, colSums(r$normal(3, 0, 1)))
expect_equal(s, rbind(colSums(r$normal(3, 0, 1))))
expect_equal(dust2_cpu_walk_time(ptr), 3)
})

Expand All @@ -29,13 +29,14 @@ test_that("can set model state from a vector", {
ptr <- obj[[1]]
s <- runif(10)
expect_null(dust2_cpu_walk_set_state(ptr, s))
expect_equal(dust2_cpu_walk_state(ptr), s)
expect_equal(dust2_cpu_walk_state(ptr, FALSE),
rbind(s, deparse.level = 0))

expect_null(dust2_cpu_walk_run_steps(ptr, 3))

r <- mcstate2::mcstate_rng$new(seed = 42, n_streams = 10)
expect_equal(dust2_cpu_walk_state(ptr),
colSums(r$normal(3, 0, 1)) + s)
expect_equal(dust2_cpu_walk_state(ptr, FALSE),
rbind(colSums(r$normal(3, 0, 1)) + s))
})


Expand All @@ -45,8 +46,8 @@ test_that("can set model state from initial conditions", {
ptr <- obj[[1]]
expect_null(dust2_cpu_walk_set_state_initial(ptr))
r <- mcstate2::mcstate_rng$new(seed = 42, n_streams = 10)
expect_equal(dust2_cpu_walk_state(ptr),
drop(r$normal(1, 0, 1)))
expect_equal(dust2_cpu_walk_state(ptr, FALSE),
r$normal(1, 0, 1))
})


Expand All @@ -57,8 +58,8 @@ test_that("can set model state from initial conditions with empty version", {
obj <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 0, 42, FALSE)
ptr <- obj[[1]]
expect_null(dust2_cpu_walk_set_state_initial(ptr))
expect_equal(dust2_cpu_walk_state(ptr),
rep(0, 10))
expect_equal(dust2_cpu_walk_state(ptr, FALSE),
matrix(0, 1, 10))
})


Expand All @@ -67,8 +68,8 @@ test_that("can run deterministically", {
obj <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 0, 42, TRUE)
ptr <- obj[[1]]
expect_null(dust2_cpu_walk_run_steps(ptr, 3))
expect_equal(dust2_cpu_walk_state(ptr),
rep(0, 10))
expect_equal(dust2_cpu_walk_state(ptr, FALSE),
matrix(0, 1, 10))
})


Expand Down Expand Up @@ -131,12 +132,12 @@ test_that("validate inputs", {

expect_identical(
dust2_cpu_walk_state(
dust2_cpu_walk_alloc(pars, 5, 1, 10, 0, 42, FALSE)[[1]]),
rep(0, 10))
dust2_cpu_walk_alloc(pars, 5, 1, 10, 0, 42, FALSE)[[1]], FALSE),
matrix(0, 1, 10))
expect_identical(
dust2_cpu_walk_state(
dust2_cpu_walk_alloc(pars, 5, 1, 10L, 0, 42, FALSE)[[1]]),
rep(0, 10))
dust2_cpu_walk_alloc(pars, 5, 1, 10L, 0, 42, FALSE)[[1]], FALSE),
matrix(0, 1, 10))
expect_error(
dust2_cpu_walk_alloc(pars, 5, 1, 9.5, 0, 42, FALSE),
"'n_particles' must be integer-like")
Expand All @@ -157,13 +158,15 @@ test_that("can initialise multiple groups with different parameter sets", {
pars <- lapply(1:4, function(sd) list(sd = sd, random_initial = TRUE))
obj <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 4, 42, FALSE)
ptr <- obj[[1]]
expect_equal(dust2_cpu_walk_state(ptr), rep(0, 40))
expect_equal(dust2_cpu_walk_state(ptr, TRUE), array(0, c(1, 10, 4)))

expect_null(dust2_cpu_walk_run_steps(ptr, 3))
s <- dust2_cpu_walk_state(ptr)
s <- dust2_cpu_walk_state(ptr, TRUE)

r <- mcstate2::mcstate_rng$new(seed = 42, n_streams = 40)
expect_equal(s, colSums(r$normal(3, 0, 1)) * rep(1:4, each = 10))
expect_equal(
s,
array(colSums(r$normal(3, 0, 1)) * rep(1:4, each = 10), c(1, 10, 4)))
expect_equal(dust2_cpu_walk_time(ptr), 3)
})

Expand All @@ -182,18 +185,19 @@ test_that("can create multi-state walk model", {
obj <- dust2_cpu_walk_alloc(pars, 0, 1, 10, 0, 42, FALSE)
expect_equal(obj[[2]], 3)
ptr <- obj[[1]]
expect_equal(dust2_cpu_walk_state(ptr), rep(0, 30))
expect_equal(dust2_cpu_walk_state(ptr, FALSE), matrix(0, 3, 10))
expect_equal(dust2_cpu_walk_state(ptr, TRUE), array(0, c(3, 10, 1)))
expect_null(dust2_cpu_walk_set_state_initial(ptr))

r <- mcstate2::mcstate_rng$new(seed = 42, n_streams = 10)
s0 <- dust2_cpu_walk_state(ptr)
expect_equal(s0, c(r$normal(3, 0, 1)))
s0 <- dust2_cpu_walk_state(ptr, FALSE)
expect_equal(s0, r$normal(3, 0, 1))

expect_null(dust2_cpu_walk_run_steps(ptr, 5))
s1 <- dust2_cpu_walk_state(ptr)
s1 <- dust2_cpu_walk_state(ptr, FALSE)

cmp <- r$normal(3 * 5, 0, 1)
expect_equal(s1, s0 + c(apply(array(cmp, c(3, 5, 10)), c(1, 3), sum)))
expect_equal(s1, s0 + apply(array(cmp, c(3, 5, 10)), c(1, 3), sum))
})


Expand Down Expand Up @@ -236,15 +240,15 @@ test_that("can update parameters", {
ptr <- obj[[1]]

expect_null(dust2_cpu_walk_run_steps(ptr, 1))
s1 <- dust2_cpu_walk_state(ptr)
s1 <- dust2_cpu_walk_state(ptr, FALSE)

expect_null(dust2_cpu_walk_update_pars(ptr, pars2, FALSE))
expect_null(dust2_cpu_walk_run_steps(ptr, 1))
s2 <- dust2_cpu_walk_state(ptr)
s2 <- dust2_cpu_walk_state(ptr, FALSE)

r <- mcstate2::mcstate_rng$new(seed = 42, n_streams = 10)
expect_equal(s1, drop(r$normal(1, 0, 1)))
expect_equal(s2, s1 + drop(r$normal(1, 0, 10)))
expect_equal(s1, r$normal(1, 0, 1))
expect_equal(s2, s1 + r$normal(1, 0, 10))
})


Expand All @@ -256,15 +260,15 @@ test_that("can update parameters for grouped models", {
ptr <- obj[[1]]

expect_null(dust2_cpu_walk_run_steps(ptr, 1))
s1 <- dust2_cpu_walk_state(ptr)
s1 <- dust2_cpu_walk_state(ptr, FALSE)

expect_null(dust2_cpu_walk_update_pars(ptr, pars2, TRUE))
expect_null(dust2_cpu_walk_run_steps(ptr, 1))
s2 <- dust2_cpu_walk_state(ptr)
s2 <- dust2_cpu_walk_state(ptr, FALSE)

r <- mcstate2::mcstate_rng$new(seed = 42, n_streams = 40)
expect_equal(s1, drop(r$normal(1, 0, 1)) * rep(1:4, each = 10))
expect_equal(s2, s1 + drop(r$normal(1, 0, 10)) * rep(1:4, each = 10))
expect_equal(s1, r$normal(1, 0, 1) * rep(1:4, each = 10))
expect_equal(s2, s1 + r$normal(1, 0, 10) * rep(1:4, each = 10))
})


Expand Down

0 comments on commit abe450f

Please sign in to comment.