From 03ab28fb6391f6445ed09a7ad34c146ad38354bf Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 24 Sep 2024 12:16:14 +0100 Subject: [PATCH 1/2] Allow systems to have concept of default dt --- R/compile.R | 4 +- R/dust.R | 8 ++-- R/filter-support.R | 38 +++++++++++++++++-- R/interface-filter.R | 4 +- R/interface-unfilter.R | 4 +- R/interface.R | 44 +++++++++++---------- R/metadata.R | 31 ++++++++++++++- R/package.R | 3 +- inst/template/dust.R | 2 +- man/dust_filter_create.Rd | 2 +- man/dust_system_generator.Rd | 9 ++++- man/dust_unfilter_create.Rd | 2 +- tests/testthat/test-compile.R | 57 +++++++++++++++++----------- tests/testthat/test-filter-support.R | 6 +-- tests/testthat/test-metadata.R | 26 +++++++++++++ 15 files changed, 174 insertions(+), 66 deletions(-) diff --git a/R/compile.R b/R/compile.R index 9ffeaf8c..b763df2a 100644 --- a/R/compile.R +++ b/R/compile.R @@ -116,7 +116,7 @@ dust_generate <- function(config, filename, linking_to, cpp_std, optimisation_level, compiler_options, mangle) { system <- read_lines(filename) data <- dust_template_data(config$name, config$class, config$time_type, - linking_to, cpp_std, + config$default_dt, linking_to, cpp_std, optimisation_level, compiler_options, mangle) data$system_requirements <- data$cpp_std %||% "R (>= 4.0.0)" @@ -171,6 +171,7 @@ dust_generate_cpp <- function(system, config, data) { dust_template_data <- function(name, class, time_type, + default_dt, linking_to = NULL, cpp_std = NULL, optimisation_level = NULL, @@ -188,6 +189,7 @@ dust_template_data <- function(name, list(name = name, class = class, time_type = time_type, + default_dt = deparse1(default_dt), package = paste0(name, mangle %||% ""), linking_to = linking_to, cpp_std = cpp_std, diff --git a/R/dust.R b/R/dust.R index 020f67ff..847ab234 100644 --- a/R/dust.R +++ b/R/dust.R @@ -1,13 +1,13 @@ ## Generated by dust2 (version 0.1.8) - do not edit logistic <- function() { - dust_system_generator("logistic", "continuous") + dust_system_generator("logistic", "continuous", NULL) } sir <- function() { - dust_system_generator("sir", "discrete") + dust_system_generator("sir", "discrete", 1) } sirode <- function() { - dust_system_generator("sirode", "continuous") + dust_system_generator("sirode", "continuous", NULL) } walk <- function() { - dust_system_generator("walk", "discrete") + dust_system_generator("walk", "discrete", 1) } diff --git a/R/filter-support.R b/R/filter-support.R index e8b1e529..b72dac04 100644 --- a/R/filter-support.R +++ b/R/filter-support.R @@ -10,16 +10,46 @@ check_generator_for_filter <- function(generator, what, call = NULL) { } -check_dt <- function(dt, call = NULL) { +check_system_dt <- function(dt, generator, name = "dt", call = NULL) { + is_discrete <- generator$properties$time_type == "discrete" + if (is_discrete) { + check_dt(dt %||% generator$default_dt, name, call = call) + } else { + if (!is.null(dt)) { + cli::cli_abort("Can't use '{name}' with continuous-time systems", + call = call) + } + } +} + + +check_system_ode_control <- function(ode_control, generator, + name = "ode_control", call = NULL) { + is_discrete <- generator$properties$time_type == "discrete" + if (is_discrete) { + if (!is.null(ode_control)) { + cli::cli_abort("Can't use 'ode_control' with discrete-time systems") + } + } else { + if (is.null(ode_control)) { + ode_control <- dust_ode_control() + } else { + assert_is(ode_control, "dust_ode_control", call = environment()) + } + } +} + + +check_dt <- function(dt, name = deparse(substitute(dt)), call = NULL) { assert_scalar_numeric(dt, call = call) if (dt <= 0) { - cli::cli_abort("Expected 'dt' to be greater than 0") + cli::cli_abort("Expected '{name}' to be greater than 0") } if (dt > 1) { - cli::cli_abort("Expected 'dt' to be at most 1") + cli::cli_abort("Expected '{name}' to be at most 1") } if (!rlang::is_integerish(1 / dt)) { - cli::cli_abort("Expected 'dt' to be the inverse of an integer", + cli::cli_abort("Expected '{name}' to be the inverse of an integer", arg = "dt", call = call) } dt diff --git a/R/interface-filter.R b/R/interface-filter.R index 63caf9d1..e50fe28b 100644 --- a/R/interface-filter.R +++ b/R/interface-filter.R @@ -32,7 +32,7 @@ ##' ##' @export dust_filter_create <- function(generator, time_start, data, - n_particles, n_groups = NULL, dt = 1, + n_particles, n_groups = NULL, dt = NULL, index_state = NULL, n_threads = 1, preserve_group_dimension = FALSE, seed = NULL) { @@ -43,7 +43,7 @@ dust_filter_create <- function(generator, time_start, data, data <- prepare_data(data, n_groups, call = call) time_start <- check_time_start(time_start, data$time, call = call) - dt <- check_dt(dt, call = call) + dt <- check_system_dt(dt, generator, call = call) n_groups <- data$n_groups preserve_group_dimension <- preserve_group_dimension || n_groups > 1 diff --git a/R/interface-unfilter.R b/R/interface-unfilter.R index 23d82113..bd081ecf 100644 --- a/R/interface-unfilter.R +++ b/R/interface-unfilter.R @@ -19,7 +19,7 @@ ##' @export dust_unfilter_create <- function(generator, time_start, data, n_particles = 1, n_groups = NULL, - dt = 1, n_threads = 1, index_state = NULL, + dt = NULL, n_threads = 1, index_state = NULL, preserve_particle_dimension = FALSE, preserve_group_dimension = FALSE) { call <- environment() @@ -29,7 +29,7 @@ dust_unfilter_create <- function(generator, time_start, data, data <- prepare_data(data, n_groups, call = call) time_start <- check_time_start(time_start, data$time, call = call) - dt <- check_dt(dt, call = call) + dt <- check_system_dt(dt, generator, call = call) n_groups <- data$n_groups preserve_group_dimension <- preserve_group_dimension || n_groups > 1 diff --git a/R/interface.R b/R/interface.R index d5488a52..c799f5e8 100644 --- a/R/interface.R +++ b/R/interface.R @@ -10,6 +10,8 @@ ##' the wrong time here will lead to crashes or failure to create ##' the generator. ##' +##' @param default_dt The default value for `dt` on initialisation +##' ##' @param env The environment where the generator is defined. ##' ##' @return A `dust_system_generator` object @@ -23,7 +25,7 @@ ##' # This is the same code as in "dust2:::sir", except there we find ##' # the correct environment automatically ##' dust2:::sir -dust_system_generator <- function(name, time_type, +dust_system_generator <- function(name, time_type, default_dt, env = parent.env(parent.frame())) { ## I don't love that this requires running through sprintf() each ## time we create a generator, but using a function for the generator (see @@ -68,6 +70,7 @@ dust_system_generator <- function(name, time_type, ret <- list(name = name, methods = methods, + default_dt = default_dt, properties = properties) class(ret) <- "dust_system_generator" @@ -147,22 +150,9 @@ dust_system_create <- function(generator, pars, n_particles, n_groups = 1, call <- environment() check_is_dust_system_generator(generator, substitute(generator)) - is_discrete <- generator$properties$time_type == "discrete" - if (is_discrete) { - dt <- check_dt(dt %||% 1, call = call) - if (!is.null(ode_control)) { - cli::cli_abort("Can't use 'ode_control' with discrete-time systems") - } - } else { - if (!is.null(dt)) { - cli::cli_abort("Can't use 'dt' with continuous-time systems") - } - if (is.null(ode_control)) { - ode_control <- dust_ode_control() - } else { - assert_is(ode_control, "dust_ode_control", call = environment()) - } - } + dt <- check_system_dt(dt, generator, call = call) + ode_control <- check_system_ode_control(ode_control, generator, call = call) + check_time(time, dt, call = call) assert_scalar_size(n_particles, allow_zero = FALSE, call = call) @@ -175,6 +165,7 @@ dust_system_create <- function(generator, pars, n_particles, n_groups = 1, preserve_group_dimension <- preserve_group_dimension || n_groups > 1 pars <- check_pars(pars, n_groups, NULL, preserve_group_dimension) + is_discrete <- generator$properties$time_type == "discrete" if (is_discrete) { res <- generator$methods$alloc(pars, time, dt, n_particles, n_groups, seed, deterministic, @@ -558,8 +549,13 @@ print.dust_system <- function(x, ...) { cli::cli_bullets(c( i = "This system has 'adjoint' support, and can compute gradients")) } - cli::cli_bullets(c( - i = "This system runs in {x$properties$time_type} time")) + if (x$properties$time_type == "discrete") { + cli::cli_bullets(c( + i = "This system runs in discrete time with dt = {x$dt}")) + } else { + cli::cli_bullets(c( + i = "This system runs in continuous time")) + } invisible(x) } @@ -573,8 +569,14 @@ print.dust_system_generator <- function(x, ...) { cli::cli_bullets(c( i = "This system has 'compare_data' support")) } - cli::cli_bullets(c( - i = "This system runs in {x$properties$time_type} time")) + if (x$properties$time_type == "discrete") { + cli::cli_bullets(c( + i = paste("This system runs in discrete time", + "with a default dt of {x$default_dt}"))) + } else { + cli::cli_bullets(c( + i = "This system runs in continuous time")) + } invisible(x) } diff --git a/R/metadata.R b/R/metadata.R index 4b7c3fa0..aaac1bc7 100644 --- a/R/metadata.R +++ b/R/metadata.R @@ -22,9 +22,11 @@ parse_metadata <- function(filename, call = NULL) { data <- decor::cpp_decorations(files = filename) class <- parse_metadata_class(data, call) + time_type <- parse_metadata_time_type(data, call) list(class = class, name = parse_metadata_name(data, call) %||% class, - time_type = parse_metadata_time_type(data, call), + time_type = time_type, + default_dt = parse_metadata_default_dt(data, time_type, call), has_compare = parse_metadata_has_compare(data, call), has_adjoint = parse_metadata_has_adjoint(data, call), parameters = parse_metadata_parameters(data, call)) @@ -95,6 +97,33 @@ parse_metadata_time_type <- function(data, call = NULL) { } +parse_metadata_default_dt <- function(data, time_type, call = NULL) { + data <- find_attribute_value_single(data, "dust2::default_dt", + required = FALSE, call = call) + if (is.null(data)) { + return(if (time_type == "discrete") 1 else NULL) + } + if (time_type == "continuous") { + cli::cli_abort( + "Can't use '[[dust::default_dt()]]' with continuous-time systems", + call = call) + } + if (length(data) != 1 || nzchar(names(data))) { + cli::cli_abort( + "Expected a single unnamed argument to '[[dust2::default_dt()]]'", + call = call) + } + if (!is.numeric(data[[1]])) { + cli::cli_abort( + "Expected a numerical argument to '[[dust2::default_dt()]]'", + call = call) + } + value <- data[[1]] + check_dt(value, "[[dust2::default_dt()]]", call = call) + value +} + + parse_metadata_has_compare <- function(data, call = NULL) { parse_metadata_has_feature("compare", data, call) } diff --git a/R/package.R b/R/package.R index c8e5ce31..10c89d95 100644 --- a/R/package.R +++ b/R/package.R @@ -218,7 +218,8 @@ package_validate_makevars_openmp <- function(text, call) { package_generate <- function(filename, call) { config <- parse_metadata(filename, call = call) system <- read_lines(filename) - data <- dust_template_data(config$name, config$class, config$time_type) + data <- dust_template_data(config$name, config$class, config$time_type, + config$default_dt) list(r = substitute_dust_template(data, "dust.R"), cpp = dust_generate_cpp(system, config, data)) } diff --git a/inst/template/dust.R b/inst/template/dust.R index 70fdf6c9..41b0bdea 100644 --- a/inst/template/dust.R +++ b/inst/template/dust.R @@ -1,3 +1,3 @@ {{name}} <- function() { - dust2::dust_system_generator("{{name}}", "{{time_type}}") + dust2::dust_system_generator("{{name}}", "{{time_type}}", {{default_dt}}) } diff --git a/man/dust_filter_create.Rd b/man/dust_filter_create.Rd index 3c244dd5..22139e90 100644 --- a/man/dust_filter_create.Rd +++ b/man/dust_filter_create.Rd @@ -10,7 +10,7 @@ dust_filter_create( data, n_particles, n_groups = NULL, - dt = 1, + dt = NULL, index_state = NULL, n_threads = 1, preserve_group_dimension = FALSE, diff --git a/man/dust_system_generator.Rd b/man/dust_system_generator.Rd index 2793f8fd..dc5769cf 100644 --- a/man/dust_system_generator.Rd +++ b/man/dust_system_generator.Rd @@ -4,7 +4,12 @@ \alias{dust_system_generator} \title{Create a system generator} \usage{ -dust_system_generator(name, time_type, env = parent.env(parent.frame())) +dust_system_generator( + name, + time_type, + default_dt, + env = parent.env(parent.frame()) +) } \arguments{ \item{name}{The name of the generator} @@ -13,6 +18,8 @@ dust_system_generator(name, time_type, env = parent.env(parent.frame())) the wrong time here will lead to crashes or failure to create the generator.} +\item{default_dt}{The default value for \code{dt} on initialisation} + \item{env}{The environment where the generator is defined.} } \value{ diff --git a/man/dust_unfilter_create.Rd b/man/dust_unfilter_create.Rd index ab3d6c16..ede1003e 100644 --- a/man/dust_unfilter_create.Rd +++ b/man/dust_unfilter_create.Rd @@ -10,7 +10,7 @@ dust_unfilter_create( data, n_particles = 1, n_groups = NULL, - dt = 1, + dt = NULL, n_threads = 1, index_state = NULL, preserve_particle_dimension = FALSE, diff --git a/tests/testthat/test-compile.R b/tests/testthat/test-compile.R index f5f7dcd4..6c2fc0f2 100644 --- a/tests/testthat/test-compile.R +++ b/tests/testthat/test-compile.R @@ -1,42 +1,53 @@ test_that("can construct template data", { expect_equal( - dust_template_data("foo", "foo", "discrete"), - list(name = "foo", class = "foo", time_type = "discrete", package = "foo", - linking_to = "cpp11, dust2, monty", cpp_std = NULL, + dust_template_data("foo", "foo", "discrete", 1), + list(name = "foo", class = "foo", time_type = "discrete", default_dt = "1", + package = "foo", linking_to = "cpp11, dust2, monty", cpp_std = NULL, compiler_options = "")) expect_equal( - dust_template_data("foo", "bar", "discrete", mangle = "abc"), - list(name = "foo", class = "bar", time_type = "discrete", + dust_template_data("foo", "bar", "discrete", 1, mangle = "abc"), + list(name = "foo", class = "bar", time_type = "discrete", default_dt = "1", package = "fooabc", linking_to = "cpp11, dust2, monty", cpp_std = NULL, compiler_options = "")) expect_equal( - dust_template_data("foo", "foo", "discrete", linking_to = "baz"), - list(name = "foo", class = "foo", time_type = "discrete", package = "foo", - linking_to = "cpp11, dust2, monty, baz", cpp_std = NULL, - compiler_options = "")) + dust_template_data("foo", "foo", "discrete", 1, linking_to = "baz"), + list(name = "foo", class = "foo", time_type = "discrete", default_dt = "1", + package = "foo", linking_to = "cpp11, dust2, monty, baz", + cpp_std = NULL, compiler_options = "")) expect_equal( - dust_template_data("foo", "foo", "discrete", + dust_template_data("foo", "foo", "discrete", 1, linking_to = c("x", "dust2", "y")), - list(name = "foo", class = "foo", time_type = "discrete", package = "foo", - linking_to = "cpp11, dust2, monty, x, y", cpp_std = NULL, - compiler_options = "")) + list(name = "foo", class = "foo", time_type = "discrete", default_dt = "1", + package = "foo", linking_to = "cpp11, dust2, monty, x, y", + cpp_std = NULL, compiler_options = "")) expect_equal( - dust_template_data("foo", "foo", time_type = "discrete", - compiler_options = "-Xf"), - list(name = "foo", class = "foo", time_type = "discrete", package = "foo", - linking_to = "cpp11, dust2, monty", cpp_std = NULL, + dust_template_data("foo", "foo", "discrete", 1, compiler_options = "-Xf"), + list(name = "foo", class = "foo", time_type = "discrete", default_dt = "1", + package = "foo", linking_to = "cpp11, dust2, monty", cpp_std = NULL, compiler_options = "-Xf")) expect_equal( - dust_template_data("foo", "foo", "discrete", optimisation_level = "none", + dust_template_data("foo", "foo", "discrete", 1, optimisation_level = "none", compiler_options = "-Xf"), - list(name = "foo", class = "foo", time_type = "discrete", package = "foo", - linking_to = "cpp11, dust2, monty", cpp_std = NULL, + list(name = "foo", class = "foo", time_type = "discrete", default_dt = "1", + package = "foo", linking_to = "cpp11, dust2, monty", cpp_std = NULL, compiler_options = "-Xf -O0")) expect_equal( - dust_template_data("foo", "foo", "discrete", cpp_std = "c++14"), - list(name = "foo", class = "foo", time_type = "discrete", package = "foo", - linking_to = "cpp11, dust2, monty", cpp_std = "c++14", + dust_template_data("foo", "foo", "discrete", 1, cpp_std = "c++14"), + list(name = "foo", class = "foo", time_type = "discrete", default_dt = "1", + package = "foo", linking_to = "cpp11, dust2, monty", cpp_std = "c++14", + compiler_options = "")) + expect_equal( + dust_template_data("foo", "foo", "discrete", 0.25), + list(name = "foo", class = "foo", time_type = "discrete", + default_dt = "0.25", + package = "foo", linking_to = "cpp11, dust2, monty", cpp_std = NULL, + compiler_options = "")) + expect_equal( + dust_template_data("foo", "foo", "continuous", NULL), + list(name = "foo", class = "foo", time_type = "continuous", + default_dt = "NULL", package = "foo", + linking_to = "cpp11, dust2, monty", cpp_std = NULL, compiler_options = "")) }) diff --git a/tests/testthat/test-filter-support.R b/tests/testthat/test-filter-support.R index d1933d20..e78f3e72 100644 --- a/tests/testthat/test-filter-support.R +++ b/tests/testthat/test-filter-support.R @@ -1,11 +1,11 @@ test_that("can validate that 'dt' is reasonable", { expect_no_error(check_dt(1)) expect_no_error(check_dt(1 / 5)) - expect_error(check_dt(-1), + expect_error(check_dt(-1, "dt"), "Expected 'dt' to be greater than 0") - expect_error(check_dt(2), + expect_error(check_dt(2, "dt"), "Expected 'dt' to be at most 1") - expect_error(check_dt(1 / 3.5), + expect_error(check_dt(1 / 3.5, "dt"), "Expected 'dt' to be the inverse of an integer") }) diff --git a/tests/testthat/test-metadata.R b/tests/testthat/test-metadata.R index eb888c87..f08b8c51 100644 --- a/tests/testthat/test-metadata.R +++ b/tests/testthat/test-metadata.R @@ -3,6 +3,7 @@ test_that("can read sir metadata", { expect_equal(meta$class, "sir") expect_equal(meta$name, "sir") expect_true(meta$has_compare) + expect_equal(meta$default_dt, 1) expect_equal(meta$parameters, data.frame(name = c("I0", "N", "beta", "gamma", "exp_noise"))) }) @@ -129,3 +130,28 @@ test_that("require that file exists", { parse_metadata(tempfile()), "File '.+' does not exist") }) + + +test_that("can specify default dt in discrete time models", { + tmp <- withr::local_tempfile() + writeLines(c( + "// [[dust2::class(a)]]", + "// [[dust2::time_type(discrete)]]", + "// [[dust2::default_dt(0.25)]]"), + tmp) + expect_equal(parse_metadata(tmp)$default_dt, 0.25) +}) + + +test_that("can validate default dt in discrete time models", { + tmp <- withr::local_tempfile() + writeLines(c( + "// [[dust2::class(a)]]", + "// [[dust2::time_type(discrete)]]", + "// [[dust2::default_dt(0.32)]]"), + tmp) + expect_error( + parse_metadata(tmp), + "Expected '[[dust2::default_dt()]]' to be the inverse of an integer", + fixed = TRUE) +}) From a47bdd90b2c6ef8818a8740f7dca4594e33f819a Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 24 Sep 2024 12:25:50 +0100 Subject: [PATCH 2/2] Fix example --- R/interface.R | 2 +- man/dust_system_generator.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/interface.R b/R/interface.R index c799f5e8..9de99586 100644 --- a/R/interface.R +++ b/R/interface.R @@ -20,7 +20,7 @@ ##' @keywords internal ##' @examples ##' # This creates the "sir" generator -##' dust_system_generator("sir", "discrete", asNamespace("dust2")) +##' dust_system_generator("sir", "discrete", 1, asNamespace("dust2")) ##' ##' # This is the same code as in "dust2:::sir", except there we find ##' # the correct environment automatically diff --git a/man/dust_system_generator.Rd b/man/dust_system_generator.Rd index dc5769cf..3cce6d47 100644 --- a/man/dust_system_generator.Rd +++ b/man/dust_system_generator.Rd @@ -32,7 +32,7 @@ using \link{dust_package} } \examples{ # This creates the "sir" generator -dust_system_generator("sir", "discrete", asNamespace("dust2")) +dust_system_generator("sir", "discrete", 1, asNamespace("dust2")) # This is the same code as in "dust2:::sir", except there we find # the correct environment automatically