Skip to content

Commit

Permalink
Merge pull request #76 from mrc-ide/mrc-5799
Browse files Browse the repository at this point in the history
Allow systems to have concept of default dt
  • Loading branch information
weshinsley authored Sep 27, 2024
2 parents 51c94e6 + a47bdd9 commit 146cfd4
Show file tree
Hide file tree
Showing 15 changed files with 176 additions and 68 deletions.
4 changes: 3 additions & 1 deletion R/compile.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ dust_generate <- function(config, filename, linking_to, cpp_std,
optimisation_level, compiler_options, mangle) {
system <- read_lines(filename)
data <- dust_template_data(config$name, config$class, config$time_type,
linking_to, cpp_std,
config$default_dt, linking_to, cpp_std,
optimisation_level, compiler_options, mangle)
data$system_requirements <- data$cpp_std %||% "R (>= 4.0.0)"

Expand Down Expand Up @@ -171,6 +171,7 @@ dust_generate_cpp <- function(system, config, data) {
dust_template_data <- function(name,
class,
time_type,
default_dt,
linking_to = NULL,
cpp_std = NULL,
optimisation_level = NULL,
Expand All @@ -188,6 +189,7 @@ dust_template_data <- function(name,
list(name = name,
class = class,
time_type = time_type,
default_dt = deparse1(default_dt),
package = paste0(name, mangle %||% ""),
linking_to = linking_to,
cpp_std = cpp_std,
Expand Down
8 changes: 4 additions & 4 deletions R/dust.R

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 34 additions & 4 deletions R/filter-support.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,46 @@ check_generator_for_filter <- function(generator, what, call = NULL) {
}


check_dt <- function(dt, call = NULL) {
check_system_dt <- function(dt, generator, name = "dt", call = NULL) {
is_discrete <- generator$properties$time_type == "discrete"
if (is_discrete) {
check_dt(dt %||% generator$default_dt, name, call = call)
} else {
if (!is.null(dt)) {
cli::cli_abort("Can't use '{name}' with continuous-time systems",
call = call)
}
}
}


check_system_ode_control <- function(ode_control, generator,
name = "ode_control", call = NULL) {
is_discrete <- generator$properties$time_type == "discrete"
if (is_discrete) {
if (!is.null(ode_control)) {
cli::cli_abort("Can't use 'ode_control' with discrete-time systems")
}
} else {
if (is.null(ode_control)) {
ode_control <- dust_ode_control()
} else {
assert_is(ode_control, "dust_ode_control", call = environment())
}
}
}


check_dt <- function(dt, name = deparse(substitute(dt)), call = NULL) {
assert_scalar_numeric(dt, call = call)
if (dt <= 0) {
cli::cli_abort("Expected 'dt' to be greater than 0")
cli::cli_abort("Expected '{name}' to be greater than 0")
}
if (dt > 1) {
cli::cli_abort("Expected 'dt' to be at most 1")
cli::cli_abort("Expected '{name}' to be at most 1")
}
if (!rlang::is_integerish(1 / dt)) {
cli::cli_abort("Expected 'dt' to be the inverse of an integer",
cli::cli_abort("Expected '{name}' to be the inverse of an integer",
arg = "dt", call = call)
}
dt
Expand Down
4 changes: 2 additions & 2 deletions R/interface-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
##'
##' @export
dust_filter_create <- function(generator, time_start, data,
n_particles, n_groups = NULL, dt = 1,
n_particles, n_groups = NULL, dt = NULL,
index_state = NULL, n_threads = 1,
preserve_group_dimension = FALSE,
seed = NULL) {
Expand All @@ -43,7 +43,7 @@ dust_filter_create <- function(generator, time_start, data,

data <- prepare_data(data, n_groups, call = call)
time_start <- check_time_start(time_start, data$time, call = call)
dt <- check_dt(dt, call = call)
dt <- check_system_dt(dt, generator, call = call)

n_groups <- data$n_groups
preserve_group_dimension <- preserve_group_dimension || n_groups > 1
Expand Down
4 changes: 2 additions & 2 deletions R/interface-unfilter.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
##' @export
dust_unfilter_create <- function(generator, time_start, data,
n_particles = 1, n_groups = NULL,
dt = 1, n_threads = 1, index_state = NULL,
dt = NULL, n_threads = 1, index_state = NULL,
preserve_particle_dimension = FALSE,
preserve_group_dimension = FALSE) {
call <- environment()
Expand All @@ -29,7 +29,7 @@ dust_unfilter_create <- function(generator, time_start, data,

data <- prepare_data(data, n_groups, call = call)
time_start <- check_time_start(time_start, data$time, call = call)
dt <- check_dt(dt, call = call)
dt <- check_system_dt(dt, generator, call = call)

n_groups <- data$n_groups
preserve_group_dimension <- preserve_group_dimension || n_groups > 1
Expand Down
46 changes: 24 additions & 22 deletions R/interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
##' the wrong time here will lead to crashes or failure to create
##' the generator.
##'
##' @param default_dt The default value for `dt` on initialisation
##'
##' @param env The environment where the generator is defined.
##'
##' @return A `dust_system_generator` object
Expand All @@ -18,12 +20,12 @@
##' @keywords internal
##' @examples
##' # This creates the "sir" generator
##' dust_system_generator("sir", "discrete", asNamespace("dust2"))
##' dust_system_generator("sir", "discrete", 1, asNamespace("dust2"))
##'
##' # This is the same code as in "dust2:::sir", except there we find
##' # the correct environment automatically
##' dust2:::sir
dust_system_generator <- function(name, time_type,
dust_system_generator <- function(name, time_type, default_dt,
env = parent.env(parent.frame())) {
## I don't love that this requires running through sprintf() each
## time we create a generator, but using a function for the generator (see
Expand Down Expand Up @@ -68,6 +70,7 @@ dust_system_generator <- function(name, time_type,

ret <- list(name = name,
methods = methods,
default_dt = default_dt,
properties = properties)

class(ret) <- "dust_system_generator"
Expand Down Expand Up @@ -147,22 +150,9 @@ dust_system_create <- function(generator, pars, n_particles, n_groups = 1,
call <- environment()
check_is_dust_system_generator(generator, substitute(generator))

is_discrete <- generator$properties$time_type == "discrete"
if (is_discrete) {
dt <- check_dt(dt %||% 1, call = call)
if (!is.null(ode_control)) {
cli::cli_abort("Can't use 'ode_control' with discrete-time systems")
}
} else {
if (!is.null(dt)) {
cli::cli_abort("Can't use 'dt' with continuous-time systems")
}
if (is.null(ode_control)) {
ode_control <- dust_ode_control()
} else {
assert_is(ode_control, "dust_ode_control", call = environment())
}
}
dt <- check_system_dt(dt, generator, call = call)
ode_control <- check_system_ode_control(ode_control, generator, call = call)

check_time(time, dt, call = call)

assert_scalar_size(n_particles, allow_zero = FALSE, call = call)
Expand All @@ -175,6 +165,7 @@ dust_system_create <- function(generator, pars, n_particles, n_groups = 1,
preserve_group_dimension <- preserve_group_dimension || n_groups > 1

pars <- check_pars(pars, n_groups, NULL, preserve_group_dimension)
is_discrete <- generator$properties$time_type == "discrete"
if (is_discrete) {
res <- generator$methods$alloc(pars, time, dt, n_particles,
n_groups, seed, deterministic,
Expand Down Expand Up @@ -558,8 +549,13 @@ print.dust_system <- function(x, ...) {
cli::cli_bullets(c(
i = "This system has 'adjoint' support, and can compute gradients"))
}
cli::cli_bullets(c(
i = "This system runs in {x$properties$time_type} time"))
if (x$properties$time_type == "discrete") {
cli::cli_bullets(c(
i = "This system runs in discrete time with dt = {x$dt}"))
} else {
cli::cli_bullets(c(
i = "This system runs in continuous time"))
}
invisible(x)
}

Expand All @@ -573,8 +569,14 @@ print.dust_system_generator <- function(x, ...) {
cli::cli_bullets(c(
i = "This system has 'compare_data' support"))
}
cli::cli_bullets(c(
i = "This system runs in {x$properties$time_type} time"))
if (x$properties$time_type == "discrete") {
cli::cli_bullets(c(
i = paste("This system runs in discrete time",
"with a default dt of {x$default_dt}")))
} else {
cli::cli_bullets(c(
i = "This system runs in continuous time"))
}
invisible(x)
}

Expand Down
31 changes: 30 additions & 1 deletion R/metadata.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ parse_metadata <- function(filename, call = NULL) {
data <- decor::cpp_decorations(files = filename)

class <- parse_metadata_class(data, call)
time_type <- parse_metadata_time_type(data, call)
list(class = class,
name = parse_metadata_name(data, call) %||% class,
time_type = parse_metadata_time_type(data, call),
time_type = time_type,
default_dt = parse_metadata_default_dt(data, time_type, call),
has_compare = parse_metadata_has_compare(data, call),
has_adjoint = parse_metadata_has_adjoint(data, call),
parameters = parse_metadata_parameters(data, call))
Expand Down Expand Up @@ -95,6 +97,33 @@ parse_metadata_time_type <- function(data, call = NULL) {
}


parse_metadata_default_dt <- function(data, time_type, call = NULL) {
data <- find_attribute_value_single(data, "dust2::default_dt",
required = FALSE, call = call)
if (is.null(data)) {
return(if (time_type == "discrete") 1 else NULL)
}
if (time_type == "continuous") {
cli::cli_abort(
"Can't use '[[dust::default_dt()]]' with continuous-time systems",
call = call)
}
if (length(data) != 1 || nzchar(names(data))) {
cli::cli_abort(
"Expected a single unnamed argument to '[[dust2::default_dt()]]'",
call = call)
}
if (!is.numeric(data[[1]])) {
cli::cli_abort(
"Expected a numerical argument to '[[dust2::default_dt()]]'",
call = call)
}
value <- data[[1]]
check_dt(value, "[[dust2::default_dt()]]", call = call)
value
}


parse_metadata_has_compare <- function(data, call = NULL) {
parse_metadata_has_feature("compare", data, call)
}
Expand Down
3 changes: 2 additions & 1 deletion R/package.R
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ package_validate_makevars_openmp <- function(text, call) {
package_generate <- function(filename, call) {
config <- parse_metadata(filename, call = call)
system <- read_lines(filename)
data <- dust_template_data(config$name, config$class, config$time_type)
data <- dust_template_data(config$name, config$class, config$time_type,
config$default_dt)
list(r = substitute_dust_template(data, "dust.R"),
cpp = dust_generate_cpp(system, config, data))
}
Expand Down
2 changes: 1 addition & 1 deletion inst/template/dust.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{{name}} <- function() {
dust2::dust_system_generator("{{name}}", "{{time_type}}")
dust2::dust_system_generator("{{name}}", "{{time_type}}", {{default_dt}})
}
2 changes: 1 addition & 1 deletion man/dust_filter_create.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions man/dust_system_generator.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/dust_unfilter_create.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 146cfd4

Please sign in to comment.