Skip to content

Commit

Permalink
Merge pull request #9 from mrc-ide/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
pwinskill authored Mar 29, 2023
2 parents 56974f7 + d3ada89 commit 144d4d7
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 73 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: cali
Title: Good vibes and model calibration
Version: 0.2.1
Version: 0.2.3
Authors@R: c(
person("Pete", "Winskill", email = "[email protected]", role = c("aut", "cre"))
)
Expand Down
62 changes: 16 additions & 46 deletions R/calibrate.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
#' @param target Values of target variable to calibrate to.
#' @param summary_function A function that take the raw model output as an argument and
#' produces a vector of the target variable.
#' @param tolerance The routine will complete when the absolute sum of the weighted difference between the target variable
#' @param tolerance The routine will complete when the average absolute weighted difference between the target variable
#' and the target values falls below this value. Tolerance is specified in
#' the units of the target variable (e.g. if my target variable is prevalence, and my tolerance is 0.05,
#' then the absolute sum of the weighted difference between the target variable and the target values must be <5% for the routine to succeed).
#' then the average absolute sum of the weighted difference between the target variable and the target values must be <5% for the routine to succeed).
#' @param weights A numerical vector of weights the same length as target giving the weights to use for elements of target.
#' @param elimination_check Check transmission is maintained for all target points with ongoing transmission before exiting early.
#' @param low Lower boundof EIRs
Expand All @@ -26,48 +26,18 @@ calibrate <- function(target,
elimination_check = TRUE,
maxiter = 20,
low = 0.001, high = 2000){

x <- proposal(low, high)

for(i in 1:maxiter){
if(low == high) break

y <- objective(x = x,
parameters = parameters,
summary_function = summary_function)

difference <- y - target
weighted_difference <- difference * weights
print(signif(rbind(y, target, difference, weighted_difference)), 3)
diff <- sum(weighted_difference)
abs_diff <- sum(abs(weighted_difference))

# Can stop early if close enough and transmission maintained
within_tol <- abs_diff < tolerance
transmisison <- TRUE
if(elimination_check){
transmission <- all(y[target > 0] > 0)
}
if(within_tol & transmission) break

if(diff < 0){
low <- x
x <- proposal(x, high)
}
if(diff > 0){
high <- x
x <- proposal(low, x)
}
}
return(x)
}

#' Propose new EIR, moving on log scale
#'
#' @param a lower EIR
#' @param b upper EIR
#'
#' @return EIR
proposal <- function(a, b){
exp(log(a) + (log(b) - log(a)) / 2)
}
x <- stats::uniroot(
f = objective,
lower = low,
upper = high,
maxiter = maxiter,
parameters = parameters,
summary_function = summary_function,
target = target,
weights = weights,
tolerance = tolerance,
elimination_check = elimination_check
)
return(x$root)
}
27 changes: 24 additions & 3 deletions R/objective.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,33 @@
#' @inheritParams calibrate
#'
#' @return Difference between output and target.
objective <- function(x, parameters, summary_function){
objective <- function(x, parameters, summary_function, target, weights, tolerance, elimination_check){
message("\nTrying EIR: ", signif(x, 5))

p <- malariasimulation::set_equilibrium(parameters, init_EIR = x)
raw_output <- malariasimulation::run_simulation(timesteps = p$timesteps, parameters = p)
model_output <- summary_function(raw_output)

return(model_output)
difference <- model_output - target
weighted_difference <- difference * weights

print(signif(rbind(model_output, target, difference, weighted_difference)), 3)

if(elimination_check){
bad_eliminated <- model_output == 0 & target != 0
if(sum(bad_eliminated) > 0){
message("Unwanted elimination")
weighted_difference[bad_eliminated] <- -1e6
}
}

mean_weighted_difference <- mean(weighted_difference)

message("\nSum squared weighted difference: ", signif(mean_weighted_difference, 5))

if(mean(abs(weighted_difference)) < tolerance){
message("Mean absolute difference < tolerance")
mean_weighted_difference <- 0
}

return(mean_weighted_difference)
}
4 changes: 2 additions & 2 deletions man/calibrate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 20 additions & 1 deletion man/objective.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 0 additions & 19 deletions man/proposal.Rd

This file was deleted.

34 changes: 33 additions & 1 deletion tests/testthat/test-objective.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,38 @@ test_that("objective works", {
p <- malariasimulation::get_parameters(list(human_population = 1000))
mockery::stub(objective, "malariasimulation::run_simulation", x)
expect_equal(objective(x = 1, parameters = p,
summary_function = summary_mean_pfpr_2_10), 0.5)
summary_function = summary_mean_pfpr_2_10,
target = 0.5,
weights = 1,
tolerance = 0,
elimination_check = FALSE), 0)
mockery::stub(objective, "malariasimulation::run_simulation", x)
expect_equal(objective(x = 1, parameters = p,
summary_function = summary_mean_pfpr_2_10,
target = 0,
weights = 1,
tolerance = 0,
elimination_check = FALSE), 0.5)

x <- data.frame(n_detect_730_3650 = c(0, 0, 0),
n_730_3650 = 100)
mockery::stub(objective, "malariasimulation::run_simulation", x)
expect_equal(objective(x = 1,
parameters = p,
summary_function = summary_mean_pfpr_2_10,
target = 0.5,
weights = 1,
tolerance = 0.02,
elimination_check = TRUE), -1e6)

x <- data.frame(n_detect_730_3650 = c(0, 0, 0),
n_730_3650 = 100)
mockery::stub(objective, "malariasimulation::run_simulation", x)
expect_equal(objective(x = 1,
parameters = p,
summary_function = summary_mean_pfpr_2_10,
target = 0.001,
weights = 1,
tolerance = 0.02,
elimination_check = TRUE), -1e6)
})

0 comments on commit 144d4d7

Please sign in to comment.