-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #81 from mrc-ide/mrc-5775
Basic state unpacking support for dust systems
- Loading branch information
Showing
5 changed files
with
172 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
##' Unpack state. You will see state come out of dust2 systems in | ||
##' several places, for example [dust_system_state], but it will | ||
##' typically be an unstructed vector with no names; this is not very | ||
##' useful! However, your model knows what each element, or group of | ||
##' elements "means". You can unpack your state from this | ||
##' unstructured array into a named list using this function. The | ||
##' same idea applies to the higher-dimensional arrays that you might | ||
##' get if your system has multiple particles, multiple parameter | ||
##' groups or has been run for multiple time steps. | ||
##' | ||
##' @title Unpack state | ||
##' | ||
##' @param obj A `dust_system` object (from [dust_system_create]) or | ||
##' `dust_likelihood` object (from [dust_filter_create] or | ||
##' [dust_unfilter_create]). | ||
##' | ||
##' @param state A state vector, matrix or array. This might have | ||
##' come from [dust_system_state], [dust_likelihood_last_history], or | ||
##' [dust_likelihood_last_state]. | ||
##' | ||
##' @return A named list, where each element corresponds to a logical | ||
##' compartment. | ||
##' | ||
##' @seealso [monty::monty_packer], and within that especially | ||
##' documentation for `$unpack()`, which powers this function. | ||
##' | ||
##' @rdname dust_unpack | ||
##' @export | ||
##' @examples | ||
##' | ||
##' sys <- dust_system_create(dust2:::sir(), list(), n_particles = 10, dt = 0.25) | ||
##' dust_system_set_state_initial(sys) | ||
##' t <- seq(0, 100, by = 5) | ||
##' y <- dust_system_simulate(sys, t) | ||
##' # The result here is a 5 x 10 x 21 matrix: 5 states by 10 particles by | ||
##' # 21 times. | ||
##' dim(y) | ||
##' | ||
##' # The 10 particles and 21 times (following t) are simple enough, but | ||
##' # what are our 5 compartments? | ||
##' | ||
##' # You can use dust_unpack_state() to reshape your output as a | ||
##' # list: | ||
##' dust_unpack_state(sys, y) | ||
##' | ||
##' # Here, the list is named following the compartments (S, I, R, | ||
##' # etc) and is a 10 x 21 matrix (i.e., the remaining dimensions | ||
##' # from y) | ||
##' | ||
##' # We could apply this to the final state, which converts a 5 x 10 | ||
##' # matrix of state into a 5 element list of vectors, each with | ||
##' # length 10: | ||
##' s <- dust_system_state(sys) | ||
##' dim(s) | ||
##' dust_unpack_state(sys, s) | ||
##' | ||
##' # If you need more control, you can use 'dust_unpack_index' to map | ||
##' # names to positions within the state dimension of this array | ||
##' dust_unpack_index(sys) | ||
dust_unpack_state <- function(obj, state) { | ||
get_unpacker(obj)$unpack(state) | ||
} | ||
|
||
|
||
##' @rdname dust_unpack | ||
##' @export | ||
dust_unpack_index <- function(obj) { | ||
get_unpacker(obj)$index() | ||
} | ||
|
||
|
||
get_unpacker <- function(obj, call = parent.frame()) { | ||
## Once mrc-5806 is merged, we can add this into the filter/unfilter | ||
## too and do the same basic idea. | ||
assert_is(obj, "dust_system", call = call) | ||
obj$packer_state | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
test_that("can unpack state from systems with a single particle", { | ||
sys <- dust_system_create(sir(), list(), n_particles = 1) | ||
dust_system_set_state_initial(sys) | ||
s <- dust_system_state(sys) | ||
s2 <- dust_unpack_state(sys, s) | ||
expect_equal(s2, sys$packer_state$unpack(s)) | ||
expect_equal(s2, list(S = 990, I = 10, R = 0, cases_cumul = 0, cases_inc = 0)) | ||
}) | ||
|
||
|
||
test_that("can unpack state from systems with several particles", { | ||
sys <- dust_system_create(sir(), list(), n_particles = 10) | ||
dust_system_set_state_initial(sys) | ||
s <- dust_system_state(sys) | ||
s2 <- dust_unpack_state(sys, s) | ||
expect_equal(s2, sys$packer_state$unpack(s)) | ||
expect_equal(lengths(s2, FALSE), rep(10, 5)) | ||
}) |