Skip to content

Commit

Permalink
Merge pull request #391 from mlr-org/surv_to_classif_pipeline
Browse files Browse the repository at this point in the history
Draft Surv to classif pipeline #194
  • Loading branch information
bblodfon authored Jul 25, 2024
2 parents fbcc8a2 + 80bf473 commit 6548c5b
Show file tree
Hide file tree
Showing 38 changed files with 970 additions and 72 deletions.
12 changes: 10 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.6.4
Version: 0.6.5
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down Expand Up @@ -35,6 +35,10 @@ Authors@R:
email = "[email protected]",
role = "ctb",
comment = c(ORCID = "0000-0001-7528-3795")),
person(given = "Philip",
family = "Studener",
role = "aut",
email = "[email protected]"),
person(given = "Maximilian",
family = "Muecke",
email = "[email protected]",
Expand Down Expand Up @@ -79,7 +83,9 @@ Suggests:
vdiffr,
abind,
Ecdat,
coxed
coxed,
mlr3learners,
pammtools
LinkingTo:
Rcpp
Remotes:
Expand Down Expand Up @@ -133,13 +139,15 @@ Collate:
'PipeOpBreslow.R'
'PipeOpCrankCompositor.R'
'PipeOpDistrCompositor.R'
'PipeOpPredClassifSurvDiscTime.R'
'PipeOpTransformer.R'
'PipeOpPredTransformer.R'
'PipeOpPredRegrSurv.R'
'PipeOpPredSurvRegr.R'
'PipeOpProbregrCompositor.R'
'PipeOpSurvAvg.R'
'PipeOpTaskRegrSurv.R'
'PipeOpTaskSurvClassifDiscTime.R'
'PipeOpTaskSurvRegr.R'
'PipeOpTaskTransformer.R'
'PredictionDataDens.R'
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,14 @@ export(MeasureSurvXuR2)
export(PipeOpBreslow)
export(PipeOpCrankCompositor)
export(PipeOpDistrCompositor)
export(PipeOpPredClassifSurvDiscTime)
export(PipeOpPredRegrSurv)
export(PipeOpPredSurvRegr)
export(PipeOpPredTransformer)
export(PipeOpProbregr)
export(PipeOpSurvAvg)
export(PipeOpTaskRegrSurv)
export(PipeOpTaskSurvClassifDiscTime)
export(PipeOpTaskSurvRegr)
export(PipeOpTaskTransformer)
export(PipeOpTransformer)
Expand All @@ -95,6 +97,7 @@ export(as_task_surv)
export(assert_surv)
export(breslow)
export(pecs)
export(pipeline_survtoclassif_disctime)
export(pipeline_survtoregr)
export(plot_probregr)
import(checkmate)
Expand Down
8 changes: 7 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# mlr3proba 0.6.5

* Add support for discrete-time survival analysis
* New `PipeOp`s: `PipeOpTaskSurvClassifDiscTime`, `PipeOpPredClassifSurvDiscTime`
* New pipeline: `pipeline_survtoclassif`

# mlr3proba 0.6.4

* Add useR! 2024 tutorial
* Lots of refactoring, improve code quality (thanks to @m-muecke)
* Lots of refactoring, improving code quality, migration to testthat v3, etc. (thanks to @m-muecke)

# mlr3proba 0.6.3

Expand Down
24 changes: 7 additions & 17 deletions R/LearnerSurvCoxPH.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,30 +42,20 @@ LearnerSurvCoxPH = R6Class("LearnerSurvCoxPH",
pv$weights = task$weights$weight
}

invoke(survival::coxph, formula = task$formula(), data = task$data(), .args = pv, x = TRUE)
invoke(survival::coxph, formula = task$formula(), data = task$data(),
.args = pv, x = TRUE)
},

.predict = function(task) {

newdata = task$data(cols = task$feature_names)

# We move the missingness checks here manually as if any NAs are made in predictions then the
# distribution object cannot be create (initialization of distr6 objects does not handle NAs)
if (anyMissing(newdata)) {
stopf(
"Learner %s on task %s failed to predict: Missing values in new data (line(s) %s)\n",
self$id, task$id,
toString(which(!complete.cases(newdata)))
)
}

newdata = ordered_features(task, self)
pv = self$param_set$get_values(tags = "predict")

# Get predicted values
# Get survival predictions via `survfit`
fit = invoke(survival::survfit, formula = self$model, newdata = newdata,
se.fit = FALSE, .args = pv)
se.fit = FALSE, .args = pv)

lp = predict(self$model, type = "lp", newdata = newdata)
# Get linear predictors
lp = invoke(predict, self$model, type = "lp", newdata = newdata)

.surv_return(times = fit$time, surv = t(fit$surv), lp = lp)
}
Expand Down
11 changes: 6 additions & 5 deletions R/LearnerSurvRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,15 @@ LearnerSurvRpart = R6Class("LearnerSurvRpart",
pv = insert_named(pv, list(weights = task$weights$weight))
}

invoke(rpart::rpart,
formula = task$formula(), data = task$data(),
method = "exp", .args = pv)
invoke(rpart::rpart, formula = task$formula(), data = task$data(),
method = "exp", .args = pv)
},

.predict = function(task) {
preds = invoke(predict, object = self$model, newdata = task$data(cols = task$feature_names))
list(crank = preds)
newdata = ordered_features(task, self)
p = invoke(predict, object = self$model, newdata = newdata)

list(crank = p)
}
)
)
Expand Down
101 changes: 101 additions & 0 deletions R/PipeOpPredClassifSurvDiscTime.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#' @title PipeOpPredClassifSurvDiscTime
#' @name mlr_pipeops_trafopred_classifsurv_disctime
#'
#' @description
#' Transform [PredictionClassif] to [PredictionSurv] by converting
#' event probabilities of a pseudo status variable (discrete time hazards)
#' to survival probabilities using the product rule (Tutz et al. 2016):
#'
#' \deqn{P_k = p_k\cdot ... \cdot p_1}
#'
#' Where:
#' - We assume that continuous time is divided into time intervals
#' \eqn{[0, t_1), [t_1, t_2), ..., [t_n, \infty)}
#' - \eqn{P_k = P(T > t_k)} is the survival probability at time \eqn{t_k}
#' - \eqn{h_k} is the discrete-time hazard (classifier prediction), i.e. the
#' conditional probability for an event in the \eqn{k}-interval.
#' - \eqn{p_k = 1 - h_k = P(T \ge t_k | T \ge t_{k-1})}
#'
#' @section Input and Output Channels:
#' The input is a [PredictionClassif] and a [data.table][data.table::data.table]
#' with the transformed data both generated by [PipeOpTaskSurvClassifDiscTime].
#' The output is the input [PredictionClassif] transformed to a [PredictionSurv].
#' Only works during prediction phase.
#'
#' @references
#' `r format_bib("tutz_2016")`
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpPredClassifSurvDiscTime = R6Class(
"PipeOpPredClassifSurvDiscTime",
inherit = mlr3pipelines::PipeOp,

public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' @param id (character(1))\cr
#' Identifier of the resulting object.
initialize = function(id = "trafopred_classifsurv_disctime") {
super$initialize(
id = id,
input = data.table(
name = c("input", "transformed_data"),
train = c("NULL", "data.table"),
predict = c("PredictionClassif", "data.table")
),
output = data.table(
name = "output",
train = "NULL",
predict = "PredictionSurv"
)
)
}
),

private = list(
.predict = function(input) {
pred = input[[1]]
data = input[[2]]
assert_true(!is.null(pred$prob))
# probability of having the event (1) in each respective interval
# is the discrete-time hazard
data = cbind(data, dt_hazard = pred$prob[, "1"])

# From theory, convert hazards to surv as prod(1 - h(t))
rows_per_id = nrow(data) / length(unique(data$id))
surv = t(vapply(unique(data$id), function(unique_id) {
cumprod(1 - data[data$id == unique_id, ][["dt_hazard"]])
}, numeric(rows_per_id)))

pred_list = list()
unique_end_times = sort(unique(data$tend))
# coerce to distribution and crank
pred_list = .surv_return(times = unique_end_times, surv = surv)

# select the real tend values by only selecting the last row of each id
# basically a slightly more complex unique()
real_tend = data$time2[seq_len(nrow(data)) %% rows_per_id == 0]

# select last row for every id
data = as.data.table(data)
id = ped_status = NULL # to fix note
data = data[, .SD[.N, list(ped_status)], by = id]

# create prediction object
p = PredictionSurv$new(
row_ids = seq_row(data),
crank = pred_list$crank, distr = pred_list$distr,
truth = Surv(real_tend, as.integer(as.character(data$ped_status))))

list(p)
},

.train = function(input) {
self$state = list()
list(input)
}
)
)

register_pipeop("trafopred_classifsurv_disctime", PipeOpPredClassifSurvDiscTime)
Loading

0 comments on commit 6548c5b

Please sign in to comment.