Skip to content

Commit

Permalink
Restructure code
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed May 13, 2024
1 parent bfa947e commit 2dee361
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 68 deletions.
4 changes: 4 additions & 0 deletions inst/include/dust2/cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ class dust_cpu {
return time_;
}

auto dt() const {
return dt_;
}

auto n_state() const {
return n_state_;
}
Expand Down
68 changes: 68 additions & 0 deletions inst/include/dust2/filter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#pragma once

#include <dust2/cpu.hpp>

namespace dust2 {

template <typename T>
class unfilter {
public:
using real_type = typename T::real_type;
using data_type = typename T::data_type;

// We need to provide direct access to the model, because the user
// will want to set parameters in, and pull out state, etc.
dust_cpu<T> model;

unfilter(dust_cpu<T> model_,
real_type time_start,
std::vector<real_type> time,
std::vector<data_type> data) :
model(model_),
time_start_(time_start),
time_(time),
data_(data),
n_groups_(model.n_groups()),
ll_(n_groups_, 0),
ll_step_(n_groups_, 0) {
const auto dt = model_.dt();
for (size_t i = 0; i < time_.size(); i++) {
const auto t0 = i == 0 ? time_start_ : time_[i - 1];
const auto t1 = time_[i];
step_.push_back(static_cast<size_t>(std::round((t1 - t0) / dt)));
}
}

void run() {
const auto n_times = step_.size();

model.set_time(time_start_);
model.set_state_initial();
std::fill(ll_.begin(), ll_.end(), 0);

auto it_data = data_.begin();
for (size_t i = 0; i < n_times; ++i, it_data += n_groups_) {
model.run_steps(step_[i]); // just compute this at point of use?
model.compare_data(it_data, ll_step_.begin());
for (size_t j = 0; j < n_groups_; ++j) {
ll_[j] += ll_step_[j];
}
}
}

template <typename It>
void last_log_likelihood(It it) {
std::copy_n(ll_.begin(), n_groups_, it);
}

private:
real_type time_start_;
std::vector<real_type> time_;
std::vector<size_t> step_;
std::vector<data_type> data_;
size_t n_groups_;
std::vector<real_type> ll_;
std::vector<real_type> ll_step_;
};

}
78 changes: 10 additions & 68 deletions inst/include/dust2/r/filter.hpp
Original file line number Diff line number Diff line change
@@ -1,36 +1,11 @@
#pragma once
#pragma once

#include <dust2/r/helpers.hpp>
#include <dust2/cpu.hpp>
#include <dust2/filter.hpp>

namespace dust2 {
namespace r {

// There's a big question here of if we take a model pointer or do the
// initialisation ourself. I think that we really need to do the
// latter because otherwise the rng is quite hard to think about
// because the user might end up holding two copies? Or we can offer
// both approaches (though we've never once needed to initialise the
// model first). The user will mostly interact with this as
//
// filter <- make_filter_sir(pars, ...)
// filter$run(pars) # update pars, verify size has not changed.
//
// The first bit of init
template <typename T>
struct unfilter_state {
using real_type = typename T::real_type;
using data_type = typename T::data_type;
dust_cpu<T> model;
real_type time_start;
std::vector<real_type> time;
std::vector<size_t> step;
std::vector<data_type> data;
std::vector<real_type> ll;
std::vector<real_type> ll_step;
size_t n_groups;
};

// TODO: this name must be changed!
template <typename T>
cpp11::sexp dust2_cpu_unfilter_alloc(cpp11::list r_pars,
Expand All @@ -51,17 +26,6 @@ cpp11::sexp dust2_cpu_unfilter_alloc(cpp11::list r_pars,
const auto internal = build_internal<T>(shared);
const auto data = check_data<T>(r_data, time.size(), n_groups, "data");

n_groups = shared.size();

// This probably gets reused a bit, too, easily pulled into a
// helper; it will be used in a simulate() method too
std::vector<size_t> step;
for (size_t i = 0; i < time.size(); i++) {
const auto t0 = i == 0 ? time_start : time[i - 1];
const auto t1 = time[i];
step.push_back(static_cast<size_t>(std::round((t1 - t0) / dt)));
}

// It's possible that we don't want to always really be
// deterministic here? Though nooone can think of a case where
// that's actually the behaviour wanted. For now let's go fully
Expand All @@ -74,14 +38,11 @@ cpp11::sexp dust2_cpu_unfilter_alloc(cpp11::list r_pars,
// we need. At this point we could have constructed the model out
// of one that exists already on the R side, but I think that's
// going to feel weirder overall.
const auto model = dust_cpu<T>(shared, internal, time_start, dt, n_particles,
const auto model = dust2::dust_cpu<T>(shared, internal, time_start, dt, n_particles,
seed, deterministic);

std::vector<real_type> ll(n_groups);
std::vector<real_type> ll_step(n_groups);
auto obj = new unfilter_state<T>{model, time_start, time, step, data,
ll, ll_step, n_groups};
cpp11::external_pointer<unfilter_state<T>> ptr(obj, true, false);
auto obj = new unfilter<T>(model, time_start, time, data);
cpp11::external_pointer<unfilter<T>> ptr(obj, true, false);

cpp11::sexp r_n_state = cpp11::as_sexp(obj->model.n_state());
cpp11::sexp r_group_names = R_NilValue;
Expand All @@ -98,34 +59,15 @@ template <typename T>
cpp11::sexp dust2_cpu_unfilter_run(cpp11::sexp ptr, cpp11::sexp r_pars,
bool grouped) {
auto *obj =
cpp11::as_cpp<cpp11::external_pointer<unfilter_state<T>>>(ptr).get();

cpp11::as_cpp<cpp11::external_pointer<unfilter<T>>>(ptr).get();
if (r_pars != R_NilValue) {
update_pars(obj->model, cpp11::as_cpp<cpp11::list>(r_pars), grouped);
}
obj->run();

// There's a question of how much we move this into the C++ part
// (rather than this interface layer) which we will want to do to
// make this more extendable.
const auto time_start = obj->time_start;
const auto& step = obj->step;
const auto n_groups = obj->n_groups;
const auto n_times = step.size();

obj->model.set_time(time_start);
obj->model.set_state_initial();
std::fill(obj->ll.begin(), obj->ll.end(), 0);

auto it_data = obj->data.begin();
for (size_t i = 0; i < n_times; ++i, it_data += n_groups) {
obj->model.run_steps(step[i]); // just compute this at point of use?
obj->model.compare_data(it_data, obj->ll_step.begin());
for (size_t j = 0; j < n_groups; ++j) {
obj->ll[j] += obj->ll_step[j];
}
}

return cpp11::writable::doubles(obj->ll);
cpp11::writable::doubles ret(obj->model.n_groups());
obj->last_log_likelihood(REAL(ret));
return ret;
}

}
Expand Down

0 comments on commit 2dee361

Please sign in to comment.