Skip to content

Commit

Permalink
Merge pull request #2 from mrc-ide/mrc-5326
Browse files Browse the repository at this point in the history
Comparison with data
  • Loading branch information
weshinsley authored May 15, 2024
2 parents 002a8d5 + 18e3fcc commit 8aa6c38
Show file tree
Hide file tree
Showing 9 changed files with 424 additions and 5 deletions.
32 changes: 32 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,37 @@
# Generated by cpp11: do not edit by hand

dust2_cpu_sir_alloc <- function(r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic) {
.Call(`_dust2_dust2_cpu_sir_alloc`, r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic)
}

dust2_cpu_sir_run_steps <- function(ptr, r_n_steps) {
.Call(`_dust2_dust2_cpu_sir_run_steps`, ptr, r_n_steps)
}

dust2_cpu_sir_state <- function(ptr, grouped) {
.Call(`_dust2_dust2_cpu_sir_state`, ptr, grouped)
}

dust2_cpu_sir_time <- function(ptr) {
.Call(`_dust2_dust2_cpu_sir_time`, ptr)
}

dust2_cpu_sir_set_state_initial <- function(ptr) {
.Call(`_dust2_dust2_cpu_sir_set_state_initial`, ptr)
}

dust2_cpu_sir_set_state <- function(ptr, r_state) {
.Call(`_dust2_dust2_cpu_sir_set_state`, ptr, r_state)
}

dust2_cpu_sir_rng_state <- function(ptr) {
.Call(`_dust2_dust2_cpu_sir_rng_state`, ptr)
}

dust2_cpu_sir_compare_data <- function(ptr, r_data, grouped) {
.Call(`_dust2_dust2_cpu_sir_compare_data`, ptr, r_data, grouped)
}

dust2_cpu_walk_alloc <- function(r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic) {
.Call(`_dust2_dust2_cpu_walk_alloc`, r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic)
}
Expand Down
111 changes: 111 additions & 0 deletions inst/examples/sir.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#include <dust2/common.hpp>

namespace {
inline double with_default(double default_value, cpp11::sexp value) {
return value == R_NilValue ? default_value : cpp11::as_cpp<double>(value);
}
}

class sir {
public:
sir() = delete;

using real_type = double;

struct shared_state {
real_type N;
real_type I0;
real_type beta;
real_type gamma;
real_type exp_noise;
};

using internal_state = dust2::no_internal_state;

struct data_type {
real_type incidence;
};

using rng_state_type = mcstate::random::generator<real_type>;

static size_t size(const shared_state& shared) {
return 5;
}

static void initial(real_type time,
real_type dt,
const shared_state& shared,
internal_state& internal,
rng_state_type& rng_state,
real_type * state_next) {
state_next[0] = shared.N - shared.I0;
state_next[1] = shared.I0;
state_next[2] = 0;
state_next[3] = 0;
state_next[4] = 0;
}

// The main update function, converting state to state_next
static void update(real_type time,
real_type dt,
const real_type * state,
const shared_state& shared,
internal_state& internal,
rng_state_type& rng_state,
real_type * state_next) {
const auto S = state[0];
const auto I = state[1];
const auto R = state[2];
const auto cases_cumul = state[3];
// const auto cases_inc = state[4];
const auto p_SI = 1 - mcstate::math::exp(-shared.beta * I / shared.N * dt);
const auto p_IR = 1 - mcstate::math::exp(-shared.gamma * dt);
const auto n_SI = mcstate::random::binomial<real_type>(rng_state, S, p_SI);
const auto n_IR = mcstate::random::binomial<real_type>(rng_state, I, p_IR);
state_next[0] = S - n_SI;
state_next[1] = I + n_SI - n_IR;
state_next[2] = R + n_IR;
state_next[3] = cases_cumul + n_SI;
// state_next[4] = (time % shared.freq == 0) ? n_SI : (cases_inc + n_SI);
state_next[4] = n_SI;
}

static shared_state build_shared(cpp11::list pars) {
const real_type I0 = with_default(10, pars["I0"]);
const real_type N = with_default(1000, pars["N"]);

const real_type beta = with_default(0.2, pars["beta"]);
const real_type gamma = with_default(0.1, pars["gamma"]);

const real_type exp_noise = with_default(1e6, pars["exp_noise"]);

return shared_state{N, I0, beta, gamma, exp_noise};
}

static internal_state build_internal(cpp11::list pars) {
return sir::internal_state{};
}

static data_type build_data(cpp11::sexp r_data) {
auto data = static_cast<cpp11::list>(r_data);
return data_type{cpp11::as_cpp<real_type>(data["incidence"])};
}

static real_type compare_data(const real_type time,
const real_type dt,
const real_type * state,
const data_type& data,
const shared_state& shared,
internal_state& internal,
rng_state_type& rng_state) {
const auto incidence_observed = data.incidence;
if (std::isnan(data.incidence)) {
return 0;
}
const auto noise =
mcstate::random::exponential(rng_state, shared.exp_noise);
const auto incidence_modelled = state[4];
const auto lambda = incidence_modelled + noise;
return mcstate::density::poisson(incidence_observed, lambda, true);
}
};
1 change: 1 addition & 0 deletions inst/include/dust2/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// This will likely move around a bit. Also note that we have an odd
// choice of directory here.
#include <mcstate/random/random.hpp>
#include <mcstate/random/density.hpp>

namespace dust2 {

Expand Down
13 changes: 13 additions & 0 deletions inst/include/dust2/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,19 @@ class dust_cpu {
fn(shared_[i]);
}

template <typename It>
void compare_data(const std::vector<data_type>& data, It it) {
const real_type * state_data = state_.data();
for (size_t i = 0; i < n_groups_; ++i) {
for (size_t j = 0; j < n_particles_; ++j, ++it) {
const auto k = n_particles_ * i + j;
const auto offset = k * n_state_;
*it = T::compare_data(time_, dt_, state_data + offset, data[i],
shared_[i], internal_[i], rng_.state(k));
}
}
}

private:
size_t n_state_;
size_t n_particles_;
Expand Down
31 changes: 31 additions & 0 deletions inst/include/dust2/r/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,5 +197,36 @@ SEXP dust2_cpu_update_pars(cpp11::sexp ptr, cpp11::list r_pars,
return R_NilValue;
}

// This one exists to help push around the comparison part of things;
// it's not expected to be called often by users.
template <typename T>
SEXP dust2_cpu_compare_data(cpp11::sexp ptr,
cpp11::sexp r_data,
bool grouped) {
using data_type = typename T::data_type;
auto *obj = cpp11::as_cpp<cpp11::external_pointer<dust_cpu<T>>>(ptr).get();
const auto n_groups = obj->n_groups();
std::vector<data_type> data;
if (grouped) {
auto r_data_list = cpp11::as_cpp<cpp11::list>(r_data);
check_length(r_data_list, n_groups, "data");
for (size_t i = 0; i < n_groups; ++i) {
data.push_back(T::build_data(r_data_list[i]));
}
} else {
if (n_groups > 1) {
cpp11::stop("Can't compare with grouped = FALSE with more than one group");
}
data.push_back(T::build_data(r_data));
}

cpp11::writable::doubles ret(obj->n_particles() * obj->n_groups());
obj->compare_data(data, REAL(ret));
if (grouped) {
set_array_dims(ret, {obj->n_particles(), obj->n_groups()});
}
return ret;
}

}
}
21 changes: 16 additions & 5 deletions inst/include/dust2/r/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ inline void check_scalar(cpp11::sexp x, const char * name) {
}
}

inline void check_length(cpp11::sexp x, int len, const char * name) {
if (LENGTH(x) != len) {
cpp11::stop("'%s' must have length %d", name, len);
}
}

inline double to_double(cpp11::sexp x, const char * name) {
check_scalar(x, name);
if (TYPEOF(x) == REALSXP) {
Expand Down Expand Up @@ -86,18 +92,23 @@ inline double check_dt(cpp11::sexp r_dt) {
}

// The initializer_list is a type-safe variadic-like approach.
template <typename It>
cpp11::sexp export_array_n(It it, std::initializer_list<size_t> dims) {
const auto len =
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>{});
inline void set_array_dims(cpp11::sexp data,
std::initializer_list<size_t> dims) {
cpp11::writable::integers r_dim(dims.size());
auto dim_i = dims.begin();
for (size_t i = 0; i < dims.size(); ++i, ++dim_i) {
r_dim[i] = *dim_i;
}
data.attr("dim") = r_dim;
}

template <typename It>
cpp11::sexp export_array_n(It it, std::initializer_list<size_t> dims) {
const auto len =
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>{});
cpp11::writable::doubles ret(len);
std::copy_n(it, len, ret.begin());
ret.attr("dim") = r_dim;
set_array_dims(ret, dims);
return ret;
}

Expand Down
64 changes: 64 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,62 @@
#include "cpp11/declarations.hpp"
#include <R_ext/Visibility.h>

// sir.cpp
SEXP dust2_cpu_sir_alloc(cpp11::list r_pars, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups, cpp11::sexp r_seed, cpp11::sexp r_deterministic);
extern "C" SEXP _dust2_dust2_cpu_sir_alloc(SEXP r_pars, SEXP r_time, SEXP r_dt, SEXP r_n_particles, SEXP r_n_groups, SEXP r_seed, SEXP r_deterministic) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_sir_alloc(cpp11::as_cpp<cpp11::decay_t<cpp11::list>>(r_pars), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_time), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_dt), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_n_particles), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_n_groups), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_seed), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_deterministic)));
END_CPP11
}
// sir.cpp
SEXP dust2_cpu_sir_run_steps(cpp11::sexp ptr, cpp11::sexp r_n_steps);
extern "C" SEXP _dust2_dust2_cpu_sir_run_steps(SEXP ptr, SEXP r_n_steps) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_sir_run_steps(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_n_steps)));
END_CPP11
}
// sir.cpp
SEXP dust2_cpu_sir_state(cpp11::sexp ptr, bool grouped);
extern "C" SEXP _dust2_dust2_cpu_sir_state(SEXP ptr, SEXP grouped) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_sir_state(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr), cpp11::as_cpp<cpp11::decay_t<bool>>(grouped)));
END_CPP11
}
// sir.cpp
SEXP dust2_cpu_sir_time(cpp11::sexp ptr);
extern "C" SEXP _dust2_dust2_cpu_sir_time(SEXP ptr) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_sir_time(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr)));
END_CPP11
}
// sir.cpp
SEXP dust2_cpu_sir_set_state_initial(cpp11::sexp ptr);
extern "C" SEXP _dust2_dust2_cpu_sir_set_state_initial(SEXP ptr) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_sir_set_state_initial(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr)));
END_CPP11
}
// sir.cpp
SEXP dust2_cpu_sir_set_state(cpp11::sexp ptr, cpp11::sexp r_state);
extern "C" SEXP _dust2_dust2_cpu_sir_set_state(SEXP ptr, SEXP r_state) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_sir_set_state(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_state)));
END_CPP11
}
// sir.cpp
SEXP dust2_cpu_sir_rng_state(cpp11::sexp ptr);
extern "C" SEXP _dust2_dust2_cpu_sir_rng_state(SEXP ptr) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_sir_rng_state(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr)));
END_CPP11
}
// sir.cpp
SEXP dust2_cpu_sir_compare_data(cpp11::sexp ptr, cpp11::sexp r_data, bool grouped);
extern "C" SEXP _dust2_dust2_cpu_sir_compare_data(SEXP ptr, SEXP r_data, SEXP grouped) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_sir_compare_data(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_data), cpp11::as_cpp<cpp11::decay_t<bool>>(grouped)));
END_CPP11
}
// walk.cpp
SEXP dust2_cpu_walk_alloc(cpp11::list r_pars, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups, cpp11::sexp r_seed, cpp11::sexp r_deterministic);
extern "C" SEXP _dust2_dust2_cpu_walk_alloc(SEXP r_pars, SEXP r_time, SEXP r_dt, SEXP r_n_particles, SEXP r_n_groups, SEXP r_seed, SEXP r_deterministic) {
Expand Down Expand Up @@ -71,6 +127,14 @@ extern "C" SEXP _dust2_dust2_cpu_walk_update_pars(SEXP ptr, SEXP pars, SEXP grou

extern "C" {
static const R_CallMethodDef CallEntries[] = {
{"_dust2_dust2_cpu_sir_alloc", (DL_FUNC) &_dust2_dust2_cpu_sir_alloc, 7},
{"_dust2_dust2_cpu_sir_compare_data", (DL_FUNC) &_dust2_dust2_cpu_sir_compare_data, 3},
{"_dust2_dust2_cpu_sir_rng_state", (DL_FUNC) &_dust2_dust2_cpu_sir_rng_state, 1},
{"_dust2_dust2_cpu_sir_run_steps", (DL_FUNC) &_dust2_dust2_cpu_sir_run_steps, 2},
{"_dust2_dust2_cpu_sir_set_state", (DL_FUNC) &_dust2_dust2_cpu_sir_set_state, 2},
{"_dust2_dust2_cpu_sir_set_state_initial", (DL_FUNC) &_dust2_dust2_cpu_sir_set_state_initial, 1},
{"_dust2_dust2_cpu_sir_state", (DL_FUNC) &_dust2_dust2_cpu_sir_state, 2},
{"_dust2_dust2_cpu_sir_time", (DL_FUNC) &_dust2_dust2_cpu_sir_time, 1},
{"_dust2_dust2_cpu_walk_alloc", (DL_FUNC) &_dust2_dust2_cpu_walk_alloc, 7},
{"_dust2_dust2_cpu_walk_rng_state", (DL_FUNC) &_dust2_dust2_cpu_walk_rng_state, 1},
{"_dust2_dust2_cpu_walk_run_steps", (DL_FUNC) &_dust2_dust2_cpu_walk_run_steps, 2},
Expand Down
58 changes: 58 additions & 0 deletions src/sir.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Generated by dust2 (version 0.1.0) - do not edit
#include <cpp11.hpp>
#include <dust2/r/cpu.hpp>

// first declarations all go here, with their decorators, once we get
// this bit sorted.

#include "../inst/examples/sir.cpp"

[[cpp11::register]]
SEXP dust2_cpu_sir_alloc(cpp11::list r_pars,
cpp11::sexp r_time,
cpp11::sexp r_dt,
cpp11::sexp r_n_particles,
cpp11::sexp r_n_groups,
cpp11::sexp r_seed,
cpp11::sexp r_deterministic) {
return dust2::r::dust2_cpu_alloc<sir>(r_pars, r_time, r_dt,
r_n_particles, r_n_groups,
r_seed, r_deterministic);
}

[[cpp11::register]]
SEXP dust2_cpu_sir_run_steps(cpp11::sexp ptr, cpp11::sexp r_n_steps) {
return dust2::r::dust2_cpu_run_steps<sir>(ptr, r_n_steps);
}

[[cpp11::register]]
SEXP dust2_cpu_sir_state(cpp11::sexp ptr, bool grouped) {
return dust2::r::dust2_cpu_state<sir>(ptr, grouped);
}

[[cpp11::register]]
SEXP dust2_cpu_sir_time(cpp11::sexp ptr) {
return dust2::r::dust2_cpu_time<sir>(ptr);
}

[[cpp11::register]]
SEXP dust2_cpu_sir_set_state_initial(cpp11::sexp ptr) {
return dust2::r::dust2_cpu_set_state_initial<sir>(ptr);
}

[[cpp11::register]]
SEXP dust2_cpu_sir_set_state(cpp11::sexp ptr, cpp11::sexp r_state) {
return dust2::r::dust2_cpu_set_state<sir>(ptr, r_state);
}

[[cpp11::register]]
SEXP dust2_cpu_sir_rng_state(cpp11::sexp ptr) {
return dust2::r::dust2_cpu_rng_state<sir>(ptr);
}

[[cpp11::register]]
SEXP dust2_cpu_sir_compare_data(cpp11::sexp ptr,
cpp11::sexp r_data,
bool grouped) {
return dust2::r::dust2_cpu_compare_data<sir>(ptr, r_data, grouped);
}
Loading

0 comments on commit 8aa6c38

Please sign in to comment.