Skip to content

Commit

Permalink
Merge pull request #352 from mlr-org/calib_alpha_diff
Browse files Browse the repository at this point in the history
measure optimizations and documentation updates
  • Loading branch information
bblodfon authored Feb 21, 2024
2 parents 4bc2513 + b552e6d commit ed6c351
Show file tree
Hide file tree
Showing 102 changed files with 1,644 additions and 685 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.5.9
Version: 0.6.0
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down Expand Up @@ -79,10 +79,10 @@ Encoding: UTF-8
LazyData: true
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.1.9000
Collate:
'aaa.R'
'LearnerDens.R'
'aaa.R'
'LearnerDensHistogram.R'
'LearnerDensKDE.R'
'LearnerSurv.R'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ importFrom(stats,reformulate)
importFrom(stats,sd)
importFrom(survival,Surv)
importFrom(utils,data)
importFrom(utils,getFromNamespace)
importFrom(utils,head)
importFrom(utils,tail)
useDynLib(mlr3proba)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# mlr3proba 0.6.0

* Optimized `surv.logloss` and `calib_alpha` measures (bypassing `distr6`)
* Update/refine all measure docs (naming conventions from upcoming scoring rules paper) + doc templates
* Fix very rare bugs in `calib_alpha`, `surv.logloss` and `surv.graf` (version with `proper = FALSE`)

# mlr3proba 0.5.9

* Fix several old issues (#348, #301, #281)
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerDensHistogram.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ LearnerDensHistogram = R6::R6Class("LearnerDensHistogram",
)
)

#' @include zzz.R
#' @include aaa.R
register_learner("dens.hist", LearnerDensHistogram)
9 changes: 4 additions & 5 deletions R/MeasureDensLogloss.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
#' @template dens_measure
#' @templateVar title Log loss
#' @templateVar title Log Loss
#' @templateVar inherit [MeasureDens]
#' @templateVar fullname MeasureDensLogloss
#' @templateVar pars eps = 1e-15
#' @templateVar eps_par TRUE
#'
#' @templateVar eps 1e-15
#' @template param_eps
#'
#' @description
#' Calculates the cross-entropy, or logarithmic (log), loss.
#'
#' The logloss, in the context of probabilistic predictions, is defined as the negative log
#' @details
#' The Log Loss, in the context of probabilistic predictions, is defined as the negative log
#' probability density function, \eqn{f}, evaluated at the observed value, \eqn{y},
#' \deqn{L(f, y) = -\log(f(y))}{L(f, y) = -log(f(y))}
#'
Expand Down
9 changes: 4 additions & 5 deletions R/MeasureRegrLogloss.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
#' @template regr_measure
#' @templateVar title Log loss
#' @templateVar title Log Loss
#' @templateVar inherit [MeasureRegr]
#' @templateVar fullname MeasureRegrLogloss
#' @templateVar pars eps = 1e-15
#' @templateVar eps_par TRUE
#'
#' @templateVar eps 1e-15
#' @template param_eps
#'
#' @description
#' Calculates the cross-entropy, or logarithmic (log), loss.
#'
#' The logloss, in the context of probabilistic predictions, is defined as the negative log
#' @details
#' The Log Loss, in the context of probabilistic predictions, is defined as the negative log
#' probability density function, \eqn{f}, evaluated at the observed value, \eqn{y},
#' \deqn{L(f, y) = -\log(f(y))}{L(f, y) = -log(f(y))}
#'
Expand Down
3 changes: 2 additions & 1 deletion R/MeasureSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#' @template param_packages
#' @template param_label
#' @template param_man
#' @template param_se
#'
#' @family Measure
#' @seealso
Expand All @@ -32,6 +31,8 @@ MeasureSurv = R6Class("MeasureSurv",
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' @param se If `TRUE` then returns the standard error of the
#' measure, otherwise returns the mean (default).
initialize = function(id, param_set = ps(), range, minimize = NA, aggregator = NULL,
properties = character(), predict_type = "distr", task_properties = character(),
packages = character(), label = NA_character_, man = NA_character_, se = FALSE) {
Expand Down
106 changes: 83 additions & 23 deletions R/MeasureSurvCalibrationAlpha.R
Original file line number Diff line number Diff line change
@@ -1,19 +1,37 @@
#' @template surv_measure
#' @templateVar title Van Houwelingen's Alpha
#' @templateVar title Van Houwelingen's Calibration Alpha
#' @templateVar fullname MeasureSurvCalibrationAlpha
#'
#' @template param_se
#' @templateVar eps 1e-3
#' @template param_eps
#'
#' @description
#' This calibration method is defined by estimating
#' \deqn{\alpha = \sum \delta_i / \sum H_i(t_i)}
#' where \eqn{\delta} is the observed censoring indicator from the test data, \eqn{H_i} is the
#' predicted cumulative hazard, and \eqn{t_i} is the observed survival time.
#' \deqn{\hat{\alpha} = \sum \delta_i / \sum H_i(T_i)}
#' where \eqn{\delta} is the observed censoring indicator from the test data,
#' \eqn{H_i} is the predicted cumulative hazard, and \eqn{T_i} is the observed
#' survival time (event or censoring).
#'
#' The standard error is given by
#' \deqn{exp(1/\sqrt{\sum \delta_i})}
#' \deqn{\hat{\alpha_{se}} = exp(1/\sqrt{\sum \delta_i})}
#'
#' The model is well calibrated if the estimated \eqn{\hat{\alpha}} coefficient
#' (returned score) is equal to 1.
#'
#' The model is well calibrated if the estimated \eqn{\alpha} coefficient is equal to 1.
#' @section Parameter details:
#' - `se` (`logical(1)`)\cr
#' If `TRUE` then return standard error of the measure, otherwise the score
#' itself (default).
#' - `method` (`character(1)`)\cr
#' Returns \eqn{\hat{\alpha}} if equal to `ratio` (default) and
#' \eqn{|1-\hat{\alpha}|} if equal to `diff`.
#' With `diff`, the output score can be minimized and for example be used for
#' tuning purposes. This parameter takes effect only if `se` is `FALSE`.
#' - `truncate` (`double(1)`) \cr
#' This parameter controls the upper bound of the output score.
#' We use `truncate = Inf` by default (so no truncation) and it's up to the user
#' **to set this up reasonably** given the chosen `method`.
#' Note that truncation may severely limit automated tuning with this measure
#' using `method = diff`.
#'
#' @references
#' `r format_bib("vanhouwelingen_2000")`
Expand All @@ -25,16 +43,25 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
inherit = MeasureSurv,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
#' @param method defines which output score to return, see "Parameter
#' details" section.
initialize = function(method = "ratio") {
assert_choice(method, choices = c("ratio", "diff"))

ps = ps(
se = p_lgl(default = FALSE)
eps = p_dbl(0, 1, default = 1e-3),
se = p_lgl(default = FALSE),
method = p_fct(c("ratio", "diff"), default = "ratio"),
truncate = p_dbl(lower = -Inf, upper = Inf, default = Inf)
)
ps$values$se = FALSE
ps$values = list(eps = 1e-3, se = FALSE, method = method, truncate = Inf)
range = if (method == "ratio") c(-Inf, Inf) else c(0, Inf)
minimize = ifelse(method == "ratio", FALSE, TRUE)

super$initialize(
id = "surv.calib_alpha",
range = c(-Inf, Inf),
minimize = FALSE,
range = range,
minimize = minimize,
predict_type = "distr",
label = "Van Houwelingen's Alpha",
man = "mlr3proba::mlr_measures_surv.calib_alpha",
Expand All @@ -45,21 +72,54 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",

private = list(
.score = function(prediction, ...) {
deaths = sum(prediction$truth[, 2])
truth = prediction$truth
all_times = truth[, 1] # both event times and censoring times
status = truth[, 2]
deaths = sum(status)

if (self$param_set$values$se) {
ps = self$param_set$values
if (ps$se) {
return(exp(1 / sqrt(deaths)))
} else {
if (inherits(prediction$distr, "VectorDistribution")) {
haz = as.numeric(prediction$distr$cumHazard(
data = matrix(prediction$truth[, 1], nrow = 1)
))
distr = prediction$data$distr

# Bypass distr6 construction if underlying distr represented by array
if (inherits(distr, "array")) {
surv = distr
if (length(dim(surv)) == 3) {
# survival 3d array, extract median
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
}
times = as.numeric(colnames(surv))

extend_times_cdf = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6")
# get survival probability for each test obs at observed time
surv_all = diag(
extend_times_cdf(all_times, times, cdf = t(1 - surv), FALSE, FALSE)
)

# H(t) = -log(S(t))
cumhaz = -log(surv_all)
} else {
haz = diag(prediction$distr$cumHazard(prediction$truth[, 1]))
if (inherits(distr, "VectorDistribution")) {
cumhaz = as.numeric(
distr$cumHazard(data = matrix(all_times, nrow = 1))
)
} else {
cumhaz = diag(as.matrix(distr$cumHazard(all_times)))
}
}
# cumulative hazard should only be infinite if only censoring occurs at the final time-point
haz[haz == Inf] = 0
return(deaths / sum(haz))

# Inf => case where censoring occurs at last time point
# 0 => case where survival probabilities are all 1
cumhaz[cumhaz == Inf | cumhaz == 0] = ps$eps
out = deaths / sum(cumhaz)

if (ps$method == "diff") {
out = abs(1 - out)
}

return(min(ps$truncate, out))
}
}
)
Expand Down
63 changes: 45 additions & 18 deletions R/MeasureSurvCalibrationBeta.R
Original file line number Diff line number Diff line change
@@ -1,18 +1,30 @@
#' @template surv_measure
#' @templateVar title Van Houwelingen's Beta
#' @templateVar title Van Houwelingen's Calibration Beta
#' @templateVar fullname MeasureSurvCalibrationBeta
#'
#' @template param_se
#'
#' @description
#' This calibration method fits the predicted linear predictor from a Cox PH model as the only
#' predictor in a new Cox PH model with the test data as the response.
#' \deqn{h(t|x) = h_0(t)exp(l\beta)}
#' where \eqn{l} is the predicted linear predictor.
#' This calibration method fits the predicted linear predictor from a Cox PH
#' model as the only predictor in a new Cox PH model with the test data as
#' the response.
#' \deqn{h(t|x) = h_0(t)exp(\beta \times lp)}
#' where \eqn{lp} is the predicted linear predictor on the test data.
#'
#' The model is well calibrated if the estimated \eqn{\hat{\beta}} coefficient
#' (returned score) is equal to 1.
#'
#' The model is well calibrated if the estimated \eqn{\beta} coefficient is equal to 1.
#' **Note**: Assumes fitted model is Cox PH (i.e. has an `lp` prediction type).
#'
#' Assumes fitted model is Cox PH.
#' @section Parameter details:
#' - `se` (`logical(1)`)\cr
#' If `TRUE` then return standard error of the measure which is the standard
#' error of the estimated coefficient \eqn{se_{\hat{\beta}}} from the Cox PH model.
#' If `FALSE` (default) then returns the estimated coefficient \eqn{\hat{\beta}}.
#' - `method` (`character(1)`)\cr
#' Returns \eqn{\hat{\beta}} if equal to `ratio` (default) and \eqn{|1-\hat{\beta}|}
#' if `diff`.
#' With `diff`, the output score can be minimized and for example be used for
#' tuning purposes.
#' This parameter takes effect only if `se` is `FALSE`.
#'
#' @references
#' `r format_bib("vanhouwelingen_2000")`
Expand All @@ -24,16 +36,23 @@ MeasureSurvCalibrationBeta = R6Class("MeasureSurvCalibrationBeta",
inherit = MeasureSurv,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
#' @param method defines which output score to return, see "Parameter
#' details" section.
initialize = function(method = "ratio") {
assert_choice(method, choices = c("ratio", "diff"))

ps = ps(
se = p_lgl(default = FALSE)
se = p_lgl(default = FALSE),
method = p_fct(c("ratio", "diff"), default = "ratio")
)
ps$values$se = FALSE
ps$values = list(se = FALSE, method = method)
range = if (method == "ratio") c(-Inf, Inf) else c(0, Inf)
minimize = ifelse(method == "ratio", FALSE, TRUE)

super$initialize(
id = "surv.calib_beta",
range = c(-Inf, Inf),
minimize = FALSE,
range = range,
minimize = minimize,
predict_type = "lp",
label = "Van Houwelingen's Beta",
man = "mlr3proba::mlr_measures_surv.calib_beta",
Expand All @@ -44,16 +63,24 @@ MeasureSurvCalibrationBeta = R6Class("MeasureSurvCalibrationBeta",

private = list(
.score = function(prediction, ...) {

df = data.frame(truth = prediction$truth, lp = prediction$lp)
fit = try(summary(survival::coxph(truth ~ lp, data = df)), silent = TRUE)

if (class(fit)[1] == "try-error") {
return(NA)
} else {
if (self$param_set$values$se) {
return(fit$coefficients[3])
ps = self$param_set$values

if (ps$se) {
return(fit$coefficients[,"se(coef)"])
} else {
return(fit$coefficients[1])
out = fit$coefficients[,"coef"]

if (ps$method == "diff") {
out = abs(1 - out)
}

return(out)
}
}
}
Expand Down
11 changes: 5 additions & 6 deletions R/MeasureSurvChamblessAUC.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
#' @template surv_measure
#' @templateVar title Chambless and Diao's AUC
#' @templateVar fullname MeasureSurvChamblessAUC
#' @template measure_survAUC
#' @template param_integrated
#' @template param_times
#'
#' @description
#' Calls [survAUC::AUC.cd()].
#'
#' Assumes Cox PH model specification.
#'
#' @template param_integrated
#' @template param_times
#' @template measure_survAUC
#'
#' @references
#' `r format_bib("chambless_2006")`
#'
Expand All @@ -24,8 +23,8 @@ MeasureSurvChamblessAUC = R6Class("MeasureSurvChamblessAUC",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
ps = ps(
times = p_uty(),
integrated = p_lgl(default = TRUE)
integrated = p_lgl(default = TRUE),
times = p_uty()
)
ps$values$integrated = TRUE

Expand Down
Loading

0 comments on commit ed6c351

Please sign in to comment.