Skip to content

Commit

Permalink
Merge pull request #8 from mrc-ide/mrc-5336
Browse files Browse the repository at this point in the history
Deterministic filter-like device
  • Loading branch information
richfitz authored May 20, 2024
2 parents b844a45 + ec3791c commit 762d9a0
Show file tree
Hide file tree
Showing 8 changed files with 393 additions and 0 deletions.
8 changes: 8 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ dust2_cpu_sir_compare_data <- function(ptr, r_data, grouped) {
.Call(`_dust2_dust2_cpu_sir_compare_data`, ptr, r_data, grouped)
}

dust2_cpu_sir_unfilter_alloc <- function(r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups) {
.Call(`_dust2_dust2_cpu_sir_unfilter_alloc`, r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups)
}

dust2_cpu_sir_unfilter_run <- function(ptr, r_pars, grouped) {
.Call(`_dust2_dust2_cpu_sir_unfilter_run`, ptr, r_pars, grouped)
}

dust2_cpu_walk_alloc <- function(r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic) {
.Call(`_dust2_dust2_cpu_walk_alloc`, r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic)
}
Expand Down
7 changes: 7 additions & 0 deletions inst/include/dust2/cpu.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#pragma once

#include <vector>

namespace dust2 {

template <typename T>
class dust_cpu {
public:
using model_type = T;
using real_type = typename T::real_type;
using rng_state_type = typename T::rng_state_type;
using shared_state = typename T::shared_state;
Expand Down Expand Up @@ -129,6 +132,10 @@ class dust_cpu {
return time_;
}

auto dt() const {
return dt_;
}

auto n_state() const {
return n_state_;
}
Expand Down
70 changes: 70 additions & 0 deletions inst/include/dust2/filter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#pragma once

#include <dust2/cpu.hpp>

namespace dust2 {

template <typename T>
class unfilter {
public:
using real_type = typename T::real_type;
using data_type = typename T::data_type;

// We need to provide direct access to the model, because the user
// will want to set parameters in, and pull out state, etc.
dust_cpu<T> model;

unfilter(dust_cpu<T> model_,
real_type time_start,
std::vector<real_type> time,
std::vector<data_type> data) :
model(model_),
time_start_(time_start),
time_(time),
data_(data),
n_particles_(model.n_particles()),
n_groups_(model.n_groups()),
ll_(n_particles_ * n_groups_, 0),
ll_step_(n_particles_ * n_groups_, 0) {
const auto dt = model_.dt();
for (size_t i = 0; i < time_.size(); i++) {
const auto t0 = i == 0 ? time_start_ : time_[i - 1];
const auto t1 = time_[i];
step_.push_back(static_cast<size_t>(std::round((t1 - t0) / dt)));
}
}

void run() {
const auto n_times = step_.size();

model.set_time(time_start_);
model.set_state_initial();
std::fill(ll_.begin(), ll_.end(), 0);

auto it_data = data_.begin();
for (size_t i = 0; i < n_times; ++i, it_data += n_groups_) {
model.run_steps(step_[i]); // just compute this at point of use?
model.compare_data(it_data, ll_step_.begin());
for (size_t j = 0; j < ll_.size(); ++j) {
ll_[j] += ll_step_[j];
}
}
}

template <typename Iter>
void last_log_likelihood(Iter iter) {
std::copy(ll_.begin(), ll_.end(), iter);
}

private:
real_type time_start_;
std::vector<real_type> time_;
std::vector<size_t> step_;
std::vector<data_type> data_;
size_t n_particles_;
size_t n_groups_;
std::vector<real_type> ll_;
std::vector<real_type> ll_step_;
};

}
79 changes: 79 additions & 0 deletions inst/include/dust2/r/filter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#pragma once

#include <dust2/r/helpers.hpp>
#include <dust2/filter.hpp>

namespace dust2 {
namespace r {

// TODO: this name must be changed!
template <typename T>
cpp11::sexp dust2_cpu_unfilter_alloc(cpp11::list r_pars,
cpp11::sexp r_time_start,
cpp11::sexp r_time,
cpp11::sexp r_dt,
cpp11::list r_data,
cpp11::sexp r_n_particles,
cpp11::sexp r_n_groups) {
using real_type = typename T::real_type;
using rng_state_type = typename T::rng_state_type;

auto n_particles = to_size(r_n_particles, "n_particles");
auto n_groups = to_size(r_n_groups, "n_groups");
const bool grouped = n_groups > 0;
const auto time_start = check_time(r_time_start, "time_start");
const auto time = check_time_sequence<real_type>(time_start, r_time, "time");
const auto dt = check_dt(r_dt);
const auto shared = build_shared<T>(r_pars, n_groups);
const auto internal = build_internal<T>(shared);
const auto data = check_data<T>(r_data, time.size(), n_groups, "data");

// It's possible that we don't want to always really be
// deterministic here? Though nooone can think of a case where
// that's actually the behaviour wanted. For now let's go fully
// deterministic.
auto seed = mcstate::random::r::as_rng_seed<rng_state_type>(R_NilValue);
const auto deterministic = true;

// Then allocate the model; this pulls together almost all the data
// we need. At this point we could have constructed the model out
// of one that exists already on the R side, but I think that's
// going to feel weirder overall.
const auto model = dust2::dust_cpu<T>(shared, internal, time_start, dt, n_particles,
seed, deterministic);

auto obj = new unfilter<T>(model, time_start, time, data);
cpp11::external_pointer<unfilter<T>> ptr(obj, true, false);

cpp11::sexp r_n_state = cpp11::as_sexp(obj->model.n_state());
cpp11::sexp r_group_names = R_NilValue;
if (grouped) {
r_group_names = r_pars.attr("names");
}
cpp11::sexp r_grouped = cpp11::as_sexp(grouped);

return cpp11::writable::list{ptr, r_n_state, r_grouped, r_group_names};
}

template <typename T>
cpp11::sexp dust2_cpu_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars,
bool grouped) {
auto *obj =
cpp11::as_cpp<cpp11::external_pointer<unfilter<T>>>(ptr).get();
if (r_pars != R_NilValue) {
update_pars(obj->model, cpp11::as_cpp<cpp11::list>(r_pars), grouped);
}
obj->run();

const auto n_groups = obj->model.n_groups();
const auto n_particles = obj->model.n_particles();
cpp11::writable::doubles ret(n_groups * n_particles);
obj->last_log_likelihood(REAL(ret));
if (grouped && n_particles > 1) {
set_array_dims(ret, {n_particles, n_groups});
}
return ret;
}

}
}
70 changes: 70 additions & 0 deletions inst/include/dust2/r/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,23 @@ inline double to_double(cpp11::sexp x, const char * name) {
cpp11::stop("'%s' must be scalar numeric", name);
}

template <typename real_type>
inline std::vector<real_type> to_vector_real(cpp11::sexp x, const char * name) {
if (TYPEOF(x) == REALSXP) {
auto x_dbl = cpp11::as_cpp<cpp11::doubles>(x);
std::vector<real_type> ret(x_dbl.size());
std::copy(x_dbl.begin(), x_dbl.end(), ret.begin());
return ret;
}
if (TYPEOF(x) == INTSXP) {
auto x_int = cpp11::as_cpp<cpp11::integers>(x);
std::vector<real_type> ret(x_int.size());
std::copy(x_int.begin(), x_int.end(), ret.begin());
return ret;
}
cpp11::stop("'%s' must be a numeric vector", name);
}

inline int to_int(cpp11::sexp x, const char * name) {
check_scalar(x, name);
if (TYPEOF(x) == INTSXP) {
Expand Down Expand Up @@ -94,6 +111,27 @@ inline double check_dt(cpp11::sexp r_dt) {
return dt;
}

template <typename real_type>
std::vector<real_type> check_time_sequence(real_type time_start,
cpp11::sexp r_time,
const char * name) {
auto time = to_vector_real<real_type>(r_time, name);
auto prev = time_start;
const auto eps = 1e-8;
for (size_t i = 0; i < time.size(); ++i) {
const auto t = time[i];
if (!is_integer_like(t, eps)) {
cpp11::stop("Expected 'time[%d]' to be integer-like", i + 1);
}
if (t <= prev) {
cpp11::stop("Expected 'time[%d]' (%d) to be larger than the previous value (%d)",
i + 1, static_cast<int>(prev), static_cast<int>(t));
}
prev = t;
}
return time;
}

// The initializer_list is a type-safe variadic-like approach.
inline void set_array_dims(cpp11::sexp data,
std::initializer_list<size_t> dims) {
Expand Down Expand Up @@ -174,5 +212,37 @@ void update_pars(dust_cpu<T>& obj, cpp11::list r_pars, bool grouped) {
}
}

template <typename T>
std::vector<typename T::data_type> check_data(cpp11::list r_data,
size_t n_time,
size_t n_groups,
const char * name) {
const bool grouped = n_groups > 0;
std::vector<typename T::data_type> data;

check_length(r_data, n_time, name);

if (grouped) {
// There are two ways we might recieve things; as a list-of-lists
// or as a list matrix. We might also want to cope with a
// data.frame but we can probably do that on the R side, and might
// provide helpers there that throw much nicer errors than we can
// throw here, really.
for (size_t i = 0; i < n_time; ++i) {
auto r_data_i = cpp11::as_cpp<cpp11::list>(r_data[i]);
check_length(r_data_i, n_groups, "data[i]"); // can do better with sstream
for (size_t j = 0; j < n_groups; ++j) {
data.push_back(T::build_data(r_data_i[j]));
}
}
} else {
for (size_t i = 0; i < n_time; ++i) {
data.push_back(T::build_data(r_data[i]));
}
}

return data;
}

}
}
16 changes: 16 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ extern "C" SEXP _dust2_dust2_cpu_sir_compare_data(SEXP ptr, SEXP r_data, SEXP gr
return cpp11::as_sexp(dust2_cpu_sir_compare_data(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_data), cpp11::as_cpp<cpp11::decay_t<bool>>(grouped)));
END_CPP11
}
// sir.cpp
SEXP dust2_cpu_sir_unfilter_alloc(cpp11::list r_pars, cpp11::sexp r_time_start, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::list r_data, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups);
extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_alloc(SEXP r_pars, SEXP r_time_start, SEXP r_time, SEXP r_dt, SEXP r_data, SEXP r_n_particles, SEXP r_n_groups) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_sir_unfilter_alloc(cpp11::as_cpp<cpp11::decay_t<cpp11::list>>(r_pars), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_time_start), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_time), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_dt), cpp11::as_cpp<cpp11::decay_t<cpp11::list>>(r_data), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_n_particles), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_n_groups)));
END_CPP11
}
// sir.cpp
SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, bool grouped);
extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_run(SEXP ptr, SEXP r_pars, SEXP grouped) {
BEGIN_CPP11
return cpp11::as_sexp(dust2_cpu_sir_unfilter_run(cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(ptr), cpp11::as_cpp<cpp11::decay_t<cpp11::sexp>>(r_pars), cpp11::as_cpp<cpp11::decay_t<bool>>(grouped)));
END_CPP11
}
// walk.cpp
SEXP dust2_cpu_walk_alloc(cpp11::list r_pars, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups, cpp11::sexp r_seed, cpp11::sexp r_deterministic);
extern "C" SEXP _dust2_dust2_cpu_walk_alloc(SEXP r_pars, SEXP r_time, SEXP r_dt, SEXP r_n_particles, SEXP r_n_groups, SEXP r_seed, SEXP r_deterministic) {
Expand Down Expand Up @@ -149,6 +163,8 @@ static const R_CallMethodDef CallEntries[] = {
{"_dust2_dust2_cpu_sir_set_state_initial", (DL_FUNC) &_dust2_dust2_cpu_sir_set_state_initial, 1},
{"_dust2_dust2_cpu_sir_state", (DL_FUNC) &_dust2_dust2_cpu_sir_state, 2},
{"_dust2_dust2_cpu_sir_time", (DL_FUNC) &_dust2_dust2_cpu_sir_time, 1},
{"_dust2_dust2_cpu_sir_unfilter_alloc", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_alloc, 7},
{"_dust2_dust2_cpu_sir_unfilter_run", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_run, 3},
{"_dust2_dust2_cpu_sir_update_pars", (DL_FUNC) &_dust2_dust2_cpu_sir_update_pars, 3},
{"_dust2_dust2_cpu_walk_alloc", (DL_FUNC) &_dust2_dust2_cpu_walk_alloc, 7},
{"_dust2_dust2_cpu_walk_reorder", (DL_FUNC) &_dust2_dust2_cpu_walk_reorder, 2},
Expand Down
20 changes: 20 additions & 0 deletions src/sir.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Generated by dust2 (version 0.1.0) - do not edit
#include <cpp11.hpp>
#include <dust2/r/cpu.hpp>
#include <dust2/r/filter.hpp>

// first declarations all go here, with their decorators, once we get
// this bit sorted.
Expand Down Expand Up @@ -62,3 +63,22 @@ SEXP dust2_cpu_sir_compare_data(cpp11::sexp ptr,
bool grouped) {
return dust2::r::dust2_cpu_compare_data<sir>(ptr, r_data, grouped);
}

[[cpp11::register]]
SEXP dust2_cpu_sir_unfilter_alloc(cpp11::list r_pars,
cpp11::sexp r_time_start,
cpp11::sexp r_time,
cpp11::sexp r_dt,
cpp11::list r_data,
cpp11::sexp r_n_particles,
cpp11::sexp r_n_groups) {
return dust2::r::dust2_cpu_unfilter_alloc<sir>(r_pars, r_time_start, r_time,
r_dt, r_data, r_n_particles,
r_n_groups);
}

[[cpp11::register]]
SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars,
bool grouped) {
return dust2::r::dust2_cpu_unfilter_run<sir>(ptr, r_pars, grouped);
}
Loading

0 comments on commit 762d9a0

Please sign in to comment.