Skip to content

Commit

Permalink
Merge pull request #81 from mrc-ide/mrc-5775
Browse files Browse the repository at this point in the history
Basic state unpacking support for dust systems
  • Loading branch information
weshinsley authored Sep 27, 2024
2 parents ac759a5 + 533b81a commit 28ea20d
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ export(dust_system_state)
export(dust_system_time)
export(dust_system_update_pars)
export(dust_unfilter_create)
export(dust_unpack_index)
export(dust_unpack_state)
useDynLib(dust2, .registration = TRUE)
77 changes: 77 additions & 0 deletions R/tools.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
##' Unpack state. You will see state come out of dust2 systems in
##' several places, for example [dust_system_state], but it will
##' typically be an unstructed vector with no names; this is not very
##' useful! However, your model knows what each element, or group of
##' elements "means". You can unpack your state from this
##' unstructured array into a named list using this function. The
##' same idea applies to the higher-dimensional arrays that you might
##' get if your system has multiple particles, multiple parameter
##' groups or has been run for multiple time steps.
##'
##' @title Unpack state
##'
##' @param obj A `dust_system` object (from [dust_system_create]) or
##' `dust_likelihood` object (from [dust_filter_create] or
##' [dust_unfilter_create]).
##'
##' @param state A state vector, matrix or array. This might have
##' come from [dust_system_state], [dust_likelihood_last_history], or
##' [dust_likelihood_last_state].
##'
##' @return A named list, where each element corresponds to a logical
##' compartment.
##'
##' @seealso [monty::monty_packer], and within that especially
##' documentation for `$unpack()`, which powers this function.
##'
##' @rdname dust_unpack
##' @export
##' @examples
##'
##' sys <- dust_system_create(dust2:::sir(), list(), n_particles = 10, dt = 0.25)
##' dust_system_set_state_initial(sys)
##' t <- seq(0, 100, by = 5)
##' y <- dust_system_simulate(sys, t)
##' # The result here is a 5 x 10 x 21 matrix: 5 states by 10 particles by
##' # 21 times.
##' dim(y)
##'
##' # The 10 particles and 21 times (following t) are simple enough, but
##' # what are our 5 compartments?
##'
##' # You can use dust_unpack_state() to reshape your output as a
##' # list:
##' dust_unpack_state(sys, y)
##'
##' # Here, the list is named following the compartments (S, I, R,
##' # etc) and is a 10 x 21 matrix (i.e., the remaining dimensions
##' # from y)
##'
##' # We could apply this to the final state, which converts a 5 x 10
##' # matrix of state into a 5 element list of vectors, each with
##' # length 10:
##' s <- dust_system_state(sys)
##' dim(s)
##' dust_unpack_state(sys, s)
##'
##' # If you need more control, you can use 'dust_unpack_index' to map
##' # names to positions within the state dimension of this array
##' dust_unpack_index(sys)
dust_unpack_state <- function(obj, state) {
get_unpacker(obj)$unpack(state)
}


##' @rdname dust_unpack
##' @export
dust_unpack_index <- function(obj) {
get_unpacker(obj)$index()
}


get_unpacker <- function(obj, call = parent.frame()) {
## Once mrc-5806 is merged, we can add this into the filter/unfilter
## too and do the same basic idea.
assert_is(obj, "dust_system", call = call)
obj$packer_state
}
4 changes: 4 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ reference:
- dust_likelihood_copy
- dust_likelihood_rng_state
- dust_likelihood_set_rng_state
- title: Tools
contents:
- dust_unpack_state
- dust_unpack_index
- title: Creation
contents:
- dust_compile
Expand Down
71 changes: 71 additions & 0 deletions man/dust_unpack.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions tests/testthat/test-tools.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
test_that("can unpack state from systems with a single particle", {
sys <- dust_system_create(sir(), list(), n_particles = 1)
dust_system_set_state_initial(sys)
s <- dust_system_state(sys)
s2 <- dust_unpack_state(sys, s)
expect_equal(s2, sys$packer_state$unpack(s))
expect_equal(s2, list(S = 990, I = 10, R = 0, cases_cumul = 0, cases_inc = 0))
})


test_that("can unpack state from systems with several particles", {
sys <- dust_system_create(sir(), list(), n_particles = 10)
dust_system_set_state_initial(sys)
s <- dust_system_state(sys)
s2 <- dust_unpack_state(sys, s)
expect_equal(s2, sys$packer_state$unpack(s))
expect_equal(lengths(s2, FALSE), rep(10, 5))
})

0 comments on commit 28ea20d

Please sign in to comment.