diff --git a/R/cpp11.R b/R/cpp11.R index 2c9a5731..f082dc1c 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -36,6 +36,14 @@ dust2_cpu_sir_compare_data <- function(ptr, r_data, grouped) { .Call(`_dust2_dust2_cpu_sir_compare_data`, ptr, r_data, grouped) } +dust2_cpu_sir_unfilter_alloc <- function(r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups) { + .Call(`_dust2_dust2_cpu_sir_unfilter_alloc`, r_pars, r_time_start, r_time, r_dt, r_data, r_n_particles, r_n_groups) +} + +dust2_cpu_sir_unfilter_run <- function(ptr, r_pars, grouped) { + .Call(`_dust2_dust2_cpu_sir_unfilter_run`, ptr, r_pars, grouped) +} + dust2_cpu_walk_alloc <- function(r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic) { .Call(`_dust2_dust2_cpu_walk_alloc`, r_pars, r_time, r_dt, r_n_particles, r_n_groups, r_seed, r_deterministic) } diff --git a/inst/include/dust2/cpu.hpp b/inst/include/dust2/cpu.hpp index 7c81b1d7..3ec65e0e 100644 --- a/inst/include/dust2/cpu.hpp +++ b/inst/include/dust2/cpu.hpp @@ -1,10 +1,13 @@ #pragma once +#include + namespace dust2 { template class dust_cpu { public: + using model_type = T; using real_type = typename T::real_type; using rng_state_type = typename T::rng_state_type; using shared_state = typename T::shared_state; @@ -129,6 +132,10 @@ class dust_cpu { return time_; } + auto dt() const { + return dt_; + } + auto n_state() const { return n_state_; } diff --git a/inst/include/dust2/filter.hpp b/inst/include/dust2/filter.hpp new file mode 100644 index 00000000..2b6aea7d --- /dev/null +++ b/inst/include/dust2/filter.hpp
@@ -0,0 +1,91 @@
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <utility>
+#include <vector>
+
+#include <dust2/cpu.hpp>
+
+namespace dust2 {
+
+// Runs a model over a fixed series of data times and accumulates,
+// for every particle in every group, the sum of the per-time values
+// written by the model's compare_data().
+template <typename T>
+class unfilter {
+public:
+  using real_type = typename T::real_type;
+  using data_type = typename T::data_type;
+
+  // We need to provide direct access to the model, because the user
+  // will want to set parameters in, and pull out state, etc.
+  dust_cpu<T> model;
+
+  // model_: model to drive; taken by value and moved into 'model'.
+  // time_start: time the model is reset to at the start of run().
+  // time: times at which the model is compared with data.
+  // data: consumed n_groups entries at a time, one batch per time.
+  unfilter(dust_cpu<T> model_,
+           real_type time_start,
+           std::vector<real_type> time,
+           std::vector<data_type> data) :
+    model(std::move(model_)),
+    time_start_(time_start),
+    time_(std::move(time)),
+    data_(std::move(data)),
+    n_particles_(model.n_particles()),
+    n_groups_(model.n_groups()),
+    ll_(n_particles_ * n_groups_, 0),
+    ll_step_(n_particles_ * n_groups_, 0) {
+    // Read dt from the member: the argument has been moved from.
+    const auto dt = model.dt();
+    // One entry per data time: the number of steps of size dt that
+    // separates it from the previous time (or from time_start).
+    step_.reserve(time_.size());
+    for (size_t i = 0; i < time_.size(); ++i) {
+      const auto t0 = i == 0 ? time_start_ : time_[i - 1];
+      const auto t1 = time_[i];
+      step_.push_back(static_cast<size_t>(std::round((t1 - t0) / dt)));
+    }
+  }
+
+  // Reset the model to time_start and its initial state, then walk
+  // through the data times adding each comparison to the running total.
+  void run() {
+    const auto n_times = step_.size();
+
+    model.set_time(time_start_);
+    model.set_state_initial();
+    std::fill(ll_.begin(), ll_.end(), 0);
+
+    auto it_data = data_.begin();
+    for (size_t i = 0; i < n_times; ++i, it_data += n_groups_) {
+      model.run_steps(step_[i]); // just compute this at point of use?
+      model.compare_data(it_data, ll_step_.begin());
+      for (size_t j = 0; j < ll_.size(); ++j) {
+        ll_[j] += ll_step_[j];
+      }
+    }
+  }
+
+  // Copy the totals from the most recent run() to 'iter'
+  // (n_particles * n_groups values).
+  template <typename Iter>
+  void last_log_likelihood(Iter iter) const {
+    std::copy(ll_.begin(), ll_.end(), iter);
+  }
+
+private:
+  real_type time_start_;
+  std::vector<real_type> time_;
+  std::vector<size_t> step_;
+  std::vector<data_type> data_;
+  size_t n_particles_;
+  size_t n_groups_;
+  std::vector<real_type> ll_;
+  std::vector<real_type> ll_step_;
+};
+
+}
diff --git a/inst/include/dust2/r/filter.hpp b/inst/include/dust2/r/filter.hpp new file mode 100644 index 00000000..cd0139e2 --- /dev/null +++ b/inst/include/dust2/r/filter.hpp
@@ -0,0 +1,93 @@
+#pragma once
+
+#include <utility>
+
+#include <dust2/r/helpers.hpp>
+#include <dust2/filter.hpp>
+
+namespace dust2 {
+namespace r {
+
+// TODO: this name must be changed!
+//
+// Allocate an unfilter for model T from R inputs. Returns a list
+// holding: an external pointer to the unfilter, the model's n_state,
+// whether the model is grouped (n_groups > 0) and, when grouped, the
+// names attribute of r_pars (NULL otherwise).
+template <typename T>
+cpp11::sexp dust2_cpu_unfilter_alloc(cpp11::list r_pars,
+                                     cpp11::sexp r_time_start,
+                                     cpp11::sexp r_time,
+                                     cpp11::sexp r_dt,
+                                     cpp11::list r_data,
+                                     cpp11::sexp r_n_particles,
+                                     cpp11::sexp r_n_groups) {
+  using real_type = typename T::real_type;
+  using rng_state_type = typename T::rng_state_type;
+
+  auto n_particles = to_size(r_n_particles, "n_particles");
+  auto n_groups = to_size(r_n_groups, "n_groups");
+  const bool grouped = n_groups > 0;
+  const auto time_start = check_time(r_time_start, "time_start");
+  // 'time' and 'data' are left non-const so that they can be moved
+  // into the unfilter below rather than copied.
+  auto time = check_time_sequence<real_type>(time_start, r_time, "time");
+  const auto dt = check_dt(r_dt);
+  const auto shared = build_shared<T>(r_pars, n_groups);
+  const auto internal = build_internal<T>(shared);
+  auto data = check_data<T>(r_data, time.size(), n_groups, "data");
+
+  // It's possible that we don't want to always really be
+  // deterministic here?  Though no one can think of a case where
+  // that's actually the behaviour wanted.  For now let's go fully
+  // deterministic.
+  auto seed = mcstate::random::r::as_rng_seed<rng_state_type>(R_NilValue);
+  const auto deterministic = true;
+
+  // Then allocate the model; this pulls together almost all the data
+  // we need.  At this point we could have constructed the model out
+  // of one that exists already on the R side, but I think that's
+  // going to feel weirder overall.
+  auto model = dust2::dust_cpu<T>(shared, internal, time_start, dt, n_particles,
+                                  seed, deterministic);
+
+  auto obj = new unfilter<T>(std::move(model), time_start,
+                             std::move(time), std::move(data));
+  cpp11::external_pointer<unfilter<T>> ptr(obj, true, false);
+
+  cpp11::sexp r_n_state = cpp11::as_sexp(obj->model.n_state());
+  cpp11::sexp r_group_names = R_NilValue;
+  if (grouped) {
+    r_group_names = r_pars.attr("names");
+  }
+  cpp11::sexp r_grouped = cpp11::as_sexp(grouped);
+
+  return cpp11::writable::list{ptr, r_n_state, r_grouped, r_group_names};
+}
+
+// Run the unfilter held in 'ptr', first updating the model's
+// parameters when r_pars is not NULL. Returns the accumulated values
+// as a numeric vector of length n_particles * n_groups, given
+// dimensions {n_particles, n_groups} when grouped with > 1 particle.
+template <typename T>
+cpp11::sexp dust2_cpu_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars,
+                                   bool grouped) {
+  auto *obj =
+    cpp11::as_cpp<cpp11::external_pointer<unfilter<T>>>(ptr).get();
+  if (r_pars != R_NilValue) {
+    update_pars(obj->model, cpp11::as_cpp<cpp11::list>(r_pars), grouped);
+  }
+  obj->run();
+
+  const auto n_groups = obj->model.n_groups();
+  const auto n_particles = obj->model.n_particles();
+  cpp11::writable::doubles ret(n_groups * n_particles);
+  obj->last_log_likelihood(REAL(ret));
+  if (grouped && n_particles > 1) {
+    set_array_dims(ret, {n_particles, n_groups});
+  }
+  return ret;
+}
+
+}
+}
diff --git a/inst/include/dust2/r/helpers.hpp b/inst/include/dust2/r/helpers.hpp index cd017635..d1868350 100644 --- a/inst/include/dust2/r/helpers.hpp +++ b/inst/include/dust2/r/helpers.hpp @@ -31,6 +31,23 @@ inline double to_double(cpp11::sexp x, const char * name) { cpp11::stop("'%s' must be scalar numeric", name); }
+// Copy a numeric R vector into a std::vector<real_type>. Double and
+// integer vectors are accepted and converted element by element; any
+// other type is an error reported against 'name'.
+template <typename real_type>
+inline std::vector<real_type> to_vector_real(cpp11::sexp x, const char * name) {
+  // Build each result straight from the iterator range.
+  if (TYPEOF(x) == REALSXP) {
+    const auto x_dbl = cpp11::as_cpp<cpp11::doubles>(x);
+    return std::vector<real_type>(x_dbl.begin(), x_dbl.end());
+  }
+  if (TYPEOF(x) == INTSXP) {
+    const auto x_int = cpp11::as_cpp<cpp11::integers>(x);
+    return std::vector<real_type>(x_int.begin(), x_int.end());
+  }
+  cpp11::stop("'%s' must be a numeric vector", name);
+}
+
 inline int to_int(cpp11::sexp x, const char * name) { check_scalar(x, name); if (TYPEOF(x) == INTSXP) { @@ -94,6 +111,27 
@@ inline double check_dt(cpp11::sexp r_dt) { return dt; }
+// Convert 'r_time' to a vector and require every entry to be integer-like
+// and strictly greater than the one before it, starting at 'time_start'.
+template <typename real_type>
+std::vector<real_type> check_time_sequence(real_type time_start,
+                                           cpp11::sexp r_time, const char * name) {
+  auto time = to_vector_real<real_type>(r_time, name);
+  auto prev = time_start;
+  for (size_t i = 0; i < time.size(); ++i) {
+    const auto t = time[i];
+    if (!is_integer_like(t, 1e-8)) { // '%d' needs an int, so cast the index
+      cpp11::stop("Expected 'time[%d]' to be integer-like", static_cast<int>(i + 1));
+    }
+    if (t <= prev) { // report the offending value first, then its predecessor
+      cpp11::stop("Expected 'time[%d]' (%d) to be larger than the previous value (%d)",
+                  static_cast<int>(i + 1), static_cast<int>(t), static_cast<int>(prev));
+    }
+    prev = t;
+  }
+  return time;
+}
+
 // The initializer_list is a type-safe variadic-like approach. inline void set_array_dims(cpp11::sexp data, std::initializer_list dims) { @@ -174,5 +212,42 @@ void update_pars(dust_cpu& obj, cpp11::list r_pars, bool grouped) { } }
+// Convert the R list 'r_data' into a flat vector of model data: one
+// entry per time when ungrouped, or n_groups consecutive entries per
+// time when grouped (n_groups > 0), each built by T::build_data().
+template <typename T>
+std::vector<typename T::data_type> check_data(cpp11::list r_data,
+                                              size_t n_time,
+                                              size_t n_groups,
+                                              const char * name) {
+  const bool grouped = n_groups > 0;
+  std::vector<typename T::data_type> data;
+
+  check_length(r_data, n_time, name);
+  // The final size is known once the length check has passed.
+  data.reserve(grouped ? n_time * n_groups : n_time);
+
+  if (grouped) {
+    // There are two ways we might receive things; as a list-of-lists
+    // or as a list matrix.  We might also want to cope with a
+    // data.frame but we can probably do that on the R side, and might
+    // provide helpers there that throw much nicer errors than we can
+    // throw here, really.
+    for (size_t i = 0; i < n_time; ++i) {
+      auto r_data_i = cpp11::as_cpp<cpp11::list>(r_data[i]);
+      check_length(r_data_i, n_groups, "data[i]"); // can do better with sstream
+      for (size_t j = 0; j < n_groups; ++j) {
+        data.push_back(T::build_data(r_data_i[j]));
+      }
+    }
+  } else {
+    for (size_t i = 0; i < n_time; ++i) {
+      data.push_back(T::build_data(r_data[i]));
+    }
+  }
+
+  return data;
+}
+
 } } diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 69ac0923..8461fd30 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -68,6 +68,20 @@ extern "C" SEXP _dust2_dust2_cpu_sir_compare_data(SEXP ptr, SEXP r_data, SEXP gr return cpp11::as_sexp(dust2_cpu_sir_compare_data(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_data), cpp11::as_cpp>(grouped))); END_CPP11 } +// sir.cpp +SEXP dust2_cpu_sir_unfilter_alloc(cpp11::list r_pars, cpp11::sexp r_time_start, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::list r_data, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups); +extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_alloc(SEXP r_pars, SEXP r_time_start, SEXP r_time, SEXP r_dt, SEXP r_data, SEXP r_n_particles, SEXP r_n_groups) { + BEGIN_CPP11 + return cpp11::as_sexp(dust2_cpu_sir_unfilter_alloc(cpp11::as_cpp>(r_pars), cpp11::as_cpp>(r_time_start), cpp11::as_cpp>(r_time), cpp11::as_cpp>(r_dt), cpp11::as_cpp>(r_data), cpp11::as_cpp>(r_n_particles), cpp11::as_cpp>(r_n_groups))); + END_CPP11 +} +// sir.cpp +SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, bool grouped); +extern "C" SEXP _dust2_dust2_cpu_sir_unfilter_run(SEXP ptr, SEXP r_pars, SEXP grouped) { + BEGIN_CPP11 + return cpp11::as_sexp(dust2_cpu_sir_unfilter_run(cpp11::as_cpp>(ptr), cpp11::as_cpp>(r_pars), cpp11::as_cpp>(grouped))); + END_CPP11 +} // walk.cpp SEXP dust2_cpu_walk_alloc(cpp11::list r_pars, cpp11::sexp r_time, cpp11::sexp r_dt, cpp11::sexp r_n_particles, cpp11::sexp r_n_groups, cpp11::sexp r_seed, cpp11::sexp r_deterministic); extern "C" SEXP _dust2_dust2_cpu_walk_alloc(SEXP r_pars, SEXP r_time, SEXP r_dt, SEXP 
r_n_particles, SEXP r_n_groups, SEXP r_seed, SEXP r_deterministic) { @@ -149,6 +163,8 @@ static const R_CallMethodDef CallEntries[] = { {"_dust2_dust2_cpu_sir_set_state_initial", (DL_FUNC) &_dust2_dust2_cpu_sir_set_state_initial, 1}, {"_dust2_dust2_cpu_sir_state", (DL_FUNC) &_dust2_dust2_cpu_sir_state, 2}, {"_dust2_dust2_cpu_sir_time", (DL_FUNC) &_dust2_dust2_cpu_sir_time, 1}, + {"_dust2_dust2_cpu_sir_unfilter_alloc", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_alloc, 7}, + {"_dust2_dust2_cpu_sir_unfilter_run", (DL_FUNC) &_dust2_dust2_cpu_sir_unfilter_run, 3}, {"_dust2_dust2_cpu_sir_update_pars", (DL_FUNC) &_dust2_dust2_cpu_sir_update_pars, 3}, {"_dust2_dust2_cpu_walk_alloc", (DL_FUNC) &_dust2_dust2_cpu_walk_alloc, 7}, {"_dust2_dust2_cpu_walk_reorder", (DL_FUNC) &_dust2_dust2_cpu_walk_reorder, 2}, diff --git a/src/sir.cpp b/src/sir.cpp index 00ff518b..531975e6 100644 --- a/src/sir.cpp +++ b/src/sir.cpp @@ -1,6 +1,7 @@ // Generated by dust2 (version 0.1.0) - do not edit #include #include +#include // first declarations all go here, with their decorators, once we get // this bit sorted. 
@@ -62,3 +63,22 @@ SEXP dust2_cpu_sir_compare_data(cpp11::sexp ptr, bool grouped) { return dust2::r::dust2_cpu_compare_data(ptr, r_data, grouped); } + +[[cpp11::register]] +SEXP dust2_cpu_sir_unfilter_alloc(cpp11::list r_pars, + cpp11::sexp r_time_start, + cpp11::sexp r_time, + cpp11::sexp r_dt, + cpp11::list r_data, + cpp11::sexp r_n_particles, + cpp11::sexp r_n_groups) { + return dust2::r::dust2_cpu_unfilter_alloc(r_pars, r_time_start, r_time, + r_dt, r_data, r_n_particles, + r_n_groups); +} + +[[cpp11::register]] +SEXP dust2_cpu_sir_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars, + bool grouped) { + return dust2::r::dust2_cpu_unfilter_run(ptr, r_pars, grouped); +} diff --git a/tests/testthat/test-filter.R b/tests/testthat/test-filter.R new file mode 100644 index 00000000..366d2d69 --- /dev/null +++ b/tests/testthat/test-filter.R @@ -0,0 +1,123 @@ +test_that("can run an unfilter", { + base <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) + pars1 <- list(beta = 0.1, gamma = 0.2, I0 = 10) + pars2 <- list(beta = 0.2, gamma = 0.2, I0 = 10) + + time_start <- 0 + time <- c(4, 8, 12, 16) + data <- lapply(1:4, function(i) list(incidence = i)) + dt <- 1 + + ## Manually compute likelihood: + f <- function(pars) { + base[names(pars)] <- pars + obj <- dust2_cpu_sir_alloc(base, time_start, dt, 1, 0, NULL, TRUE) + ptr <- obj[[1]] + dust2_cpu_sir_set_state_initial(ptr) + incidence <- numeric(length(time)) + time0 <- c(time_start, time) + for (i in seq_along(time)) { + dust2_cpu_sir_run_steps(ptr, round((time[i] - time0[i]) / dt)) + incidence[i] <- dust2_cpu_sir_state(ptr, FALSE)[5, , drop = TRUE] + } + sum(dpois(1:4, incidence + 1e-6, log = TRUE)) + } + + obj <- dust2_cpu_sir_unfilter_alloc(base, time_start, time, dt, data, 1, 0) + ptr <- obj[[1]] + expect_equal(dust2_cpu_sir_unfilter_run(ptr, NULL, FALSE), f(pars1)) + + expect_equal(dust2_cpu_sir_unfilter_run(ptr, pars1, FALSE), f(pars1)) + expect_equal(dust2_cpu_sir_unfilter_run(ptr, pars2, FALSE), 
f(pars2)) +}) + + +test_that("can run unfilter on structured model", { + base <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) + n_groups <- 3 + pars <- lapply(seq_len(n_groups), + function(i) modifyList(base, list(beta = i * 0.1))) + + time_start <- 0 + time <- c(4, 8, 12, 16) + data <- lapply(1:4, function(i) { + lapply(seq_len(n_groups), function(j) list(incidence = 2 * (i - 1) + j)) + }) + dt <- 1 + + ## Manually compute likelihood: + f <- function(pars) { + obj <- dust2_cpu_sir_alloc(pars, time_start, dt, 1, n_groups, NULL, TRUE) + ptr <- obj[[1]] + dust2_cpu_sir_set_state_initial(ptr) + incidence <- matrix(0, n_groups, length(time)) + time0 <- c(time_start, time) + for (i in seq_along(time)) { + dust2_cpu_sir_run_steps(ptr, round((time[i] - time0[i]) / dt)) + incidence[, i] <- dust2_cpu_sir_state(ptr, FALSE)[5, , drop = TRUE] + } + observed <- matrix(unlist(data, use.names = FALSE), n_groups) + rowSums(dpois(observed, incidence + 1e-6, log = TRUE)) + } + + obj <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 1, 3) + ptr <- obj[[1]] + expect_equal(dust2_cpu_sir_unfilter_run(ptr, NULL, TRUE), f(pars)) +}) + + +test_that("validate time for filter", { + pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6) + time_start <- 0 + time <- as.integer(c(4, 8, 12, 16)) + data <- lapply(1:4, function(i) list(incidence = i)) + dt <- 1 + + expect_error( + dust2_cpu_sir_unfilter_alloc(pars, 5, time, dt, data, 1, 0), + "Expected 'time[1]' (4) to be larger than the previous value (5)", + fixed = TRUE) + time2 <- time + c(0, 0, .1, 0) + expect_error( + dust2_cpu_sir_unfilter_alloc(pars, 0, time2, dt, data, 1, 0), + "Expected 'time[3]' to be integer-like", + fixed = TRUE) + expect_error( + dust2_cpu_sir_unfilter_alloc(pars, 0, as.character(time), dt, data, 1, 0), + "'time' must be a numeric vector", + fixed = TRUE) +}) + + +test_that("can run replicated unfilter", { + pars <- list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 
10, exp_noise = 1e6) + time_start <- 0 + time <- c(4, 8, 12, 16) + data <- lapply(1:4, function(i) list(incidence = i)) + dt <- 1 + obj1 <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 5, 0) + obj2 <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 1, 0) + + expect_equal( + dust2_cpu_sir_unfilter_run(obj1[[1]], NULL, FALSE), + rep(dust2_cpu_sir_unfilter_run(obj2[[1]], NULL, FALSE), 5)) +}) + + +test_that("can run replicated structured unfilter", { + pars <- list( + list(beta = 0.1, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6), + list(beta = 0.2, gamma = 0.2, N = 1000, I0 = 10, exp_noise = 1e6)) + time_start <- 0 + time <- c(4, 8, 12, 16) + data <- lapply(1:4, function(i) { + lapply(seq_len(2), function(j) list(incidence = 2 * (i - 1) + j)) + }) + dt <- 1 + obj1 <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 5, 2) + obj2 <- dust2_cpu_sir_unfilter_alloc(pars, time_start, time, dt, data, 1, 2) + + expect_equal( + dust2_cpu_sir_unfilter_run(obj1[[1]], NULL, TRUE), + matrix(rep(dust2_cpu_sir_unfilter_run(obj2[[1]], NULL, TRUE), each = 5), 5)) +})