-
-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #391 from mlr-org/surv_to_classif_pipeline
Draft Surv to classif pipeline #194
- Loading branch information
Showing
38 changed files
with
970 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
Package: mlr3proba | ||
Title: Probabilistic Supervised Learning for 'mlr3' | ||
Version: 0.6.4 | ||
Version: 0.6.5 | ||
Authors@R: | ||
c(person(given = "Raphael", | ||
family = "Sonabend", | ||
|
@@ -35,6 +35,10 @@ Authors@R: | |
email = "[email protected]", | ||
role = "ctb", | ||
comment = c(ORCID = "0000-0001-7528-3795")), | ||
person(given = "Philip", | ||
family = "Studener", | ||
role = "aut", | ||
email = "[email protected]"), | ||
person(given = "Maximilian", | ||
family = "Muecke", | ||
email = "[email protected]", | ||
|
@@ -79,7 +83,9 @@ Suggests: | |
vdiffr, | ||
abind, | ||
Ecdat, | ||
coxed | ||
coxed, | ||
mlr3learners, | ||
pammtools | ||
LinkingTo: | ||
Rcpp | ||
Remotes: | ||
|
@@ -133,13 +139,15 @@ Collate: | |
'PipeOpBreslow.R' | ||
'PipeOpCrankCompositor.R' | ||
'PipeOpDistrCompositor.R' | ||
'PipeOpPredClassifSurvDiscTime.R' | ||
'PipeOpTransformer.R' | ||
'PipeOpPredTransformer.R' | ||
'PipeOpPredRegrSurv.R' | ||
'PipeOpPredSurvRegr.R' | ||
'PipeOpProbregrCompositor.R' | ||
'PipeOpSurvAvg.R' | ||
'PipeOpTaskRegrSurv.R' | ||
'PipeOpTaskSurvClassifDiscTime.R' | ||
'PipeOpTaskSurvRegr.R' | ||
'PipeOpTaskTransformer.R' | ||
'PredictionDataDens.R' | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
#' @title PipeOpPredClassifSurvDiscTime
#' @name mlr_pipeops_trafopred_classifsurv_disctime
#'
#' @description
#' Transform [PredictionClassif] to [PredictionSurv] by converting
#' event probabilities of a pseudo status variable (discrete time hazards)
#' to survival probabilities using the product rule (Tutz et al. 2016):
#'
#' \deqn{P_k = p_k\cdot ... \cdot p_1}
#'
#' Where:
#' - We assume that continuous time is divided into time intervals
#' \eqn{[0, t_1), [t_1, t_2), ..., [t_n, \infty)}
#' - \eqn{P_k = P(T > t_k)} is the survival probability at time \eqn{t_k}
#' - \eqn{h_k} is the discrete-time hazard (classifier prediction), i.e. the
#' conditional probability for an event in the \eqn{k}-th interval.
#' - \eqn{p_k = 1 - h_k = P(T \ge t_k | T \ge t_{k-1})}
#'
#' @section Input and Output Channels:
#' The input is a [PredictionClassif] and a [data.table][data.table::data.table]
#' with the transformed data, both generated by [PipeOpTaskSurvClassifDiscTime].
#' The output is the input [PredictionClassif] transformed to a [PredictionSurv].
#' Only works during prediction phase.
#'
#' @references
#' `r format_bib("tutz_2016")`
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpPredClassifSurvDiscTime = R6Class(
  "PipeOpPredClassifSurvDiscTime",
  inherit = mlr3pipelines::PipeOp,

  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #' @param id (character(1))\cr
    #' Identifier of the resulting object.
    initialize = function(id = "trafopred_classifsurv_disctime") {
      super$initialize(
        id = id,
        input = data.table(
          name = c("input", "transformed_data"),
          train = c("NULL", "data.table"),
          predict = c("PredictionClassif", "data.table")
        ),
        output = data.table(
          name = "output",
          train = "NULL",
          predict = "PredictionSurv"
        )
      )
    }
  ),

  private = list(
    # Converts the classifier's per-interval event probabilities into one
    # survival curve per id and wraps the result in a PredictionSurv.
    # `input[[1]]` is the classification prediction, `input[[2]]` the
    # transformed (long format) data it was made on.
    .predict = function(input) {
      pred = input[[1]]
      data = input[[2]]
      assert_true(!is.null(pred$prob))
      # probability of having the event (1) in each respective interval
      # is the discrete-time hazard
      data = cbind(data, dt_hazard = pred$prob[, "1"])

      # From theory, convert hazards to surv as prod(1 - h(t))
      ids = unique(data$id)
      # every id is expected to contribute the same number of rows (intervals)
      rows_per_id = nrow(data) / length(ids)
      surv = vapply(ids, function(unique_id) {
        cumprod(1 - data[data$id == unique_id, ][["dt_hazard"]])
      }, numeric(rows_per_id))
      # vapply() yields an (intervals x ids) matrix only when there is more
      # than one interval per id; with a single interval it returns a plain
      # vector, and t() of that would be a 1 x ids matrix. Always build an
      # (ids x intervals) matrix instead.
      surv = if (is.matrix(surv)) {
        t(surv)
      } else {
        matrix(surv, ncol = 1L, dimnames = list(names(surv), NULL))
      }

      unique_end_times = sort(unique(data$tend))
      # coerce to distribution and crank
      pred_list = .surv_return(times = unique_end_times, surv = surv)

      # select the real tend values by only selecting the last row of each id
      # basically a slightly more complex unique()
      # (relies on the rows of each id being stored contiguously)
      real_tend = data$time2[seq_len(nrow(data)) %% rows_per_id == 0]

      # select last row for every id
      data = as.data.table(data)
      id = ped_status = NULL # to fix note
      data = data[, .SD[.N, list(ped_status)], by = id]

      # create prediction object
      p = PredictionSurv$new(
        row_ids = seq_row(data),
        crank = pred_list$crank, distr = pred_list$distr,
        truth = Surv(real_tend, as.integer(as.character(data$ped_status))))

      list(p)
    },

    # Nothing is learned here: the state is set to an empty list and the
    # training input is handed on wrapped in a list.
    .train = function(input) {
      self$state = list()
      list(input)
    }
  )
)

# Make the PipeOp available under its dictionary key.
register_pipeop("trafopred_classifsurv_disctime", PipeOpPredClassifSurvDiscTime)
Oops, something went wrong.