Skip to content

Commit 6548c5b

Browse files
authored
Merge pull request #391 from mlr-org/surv_to_classif_pipeline
Draft Surv to classif pipeline #194
2 parents fbcc8a2 + 80bf473 commit 6548c5b

38 files changed

+970
-72
lines changed

DESCRIPTION

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: mlr3proba
22
Title: Probabilistic Supervised Learning for 'mlr3'
3-
Version: 0.6.4
3+
Version: 0.6.5
44
Authors@R:
55
c(person(given = "Raphael",
66
family = "Sonabend",
@@ -35,6 +35,10 @@ Authors@R:
3535
email = "github@quantenbrot.de",
3636
role = "ctb",
3737
comment = c(ORCID = "0000-0001-7528-3795")),
38+
person(given = "Philip",
39+
family = "Studener",
40+
role = "aut",
41+
email = "philip.studener@gmx.de"),
3842
person(given = "Maximilian",
3943
family = "Muecke",
4044
email = "muecke.maximilian@gmail.com",
@@ -79,7 +83,9 @@ Suggests:
7983
vdiffr,
8084
abind,
8185
Ecdat,
82-
coxed
86+
coxed,
87+
mlr3learners,
88+
pammtools
8389
LinkingTo:
8490
Rcpp
8591
Remotes:
@@ -133,13 +139,15 @@ Collate:
133139
'PipeOpBreslow.R'
134140
'PipeOpCrankCompositor.R'
135141
'PipeOpDistrCompositor.R'
142+
'PipeOpPredClassifSurvDiscTime.R'
136143
'PipeOpTransformer.R'
137144
'PipeOpPredTransformer.R'
138145
'PipeOpPredRegrSurv.R'
139146
'PipeOpPredSurvRegr.R'
140147
'PipeOpProbregrCompositor.R'
141148
'PipeOpSurvAvg.R'
142149
'PipeOpTaskRegrSurv.R'
150+
'PipeOpTaskSurvClassifDiscTime.R'
143151
'PipeOpTaskSurvRegr.R'
144152
'PipeOpTaskTransformer.R'
145153
'PredictionDataDens.R'

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,14 @@ export(MeasureSurvXuR2)
7272
export(PipeOpBreslow)
7373
export(PipeOpCrankCompositor)
7474
export(PipeOpDistrCompositor)
75+
export(PipeOpPredClassifSurvDiscTime)
7576
export(PipeOpPredRegrSurv)
7677
export(PipeOpPredSurvRegr)
7778
export(PipeOpPredTransformer)
7879
export(PipeOpProbregr)
7980
export(PipeOpSurvAvg)
8081
export(PipeOpTaskRegrSurv)
82+
export(PipeOpTaskSurvClassifDiscTime)
8183
export(PipeOpTaskSurvRegr)
8284
export(PipeOpTaskTransformer)
8385
export(PipeOpTransformer)
@@ -95,6 +97,7 @@ export(as_task_surv)
9597
export(assert_surv)
9698
export(breslow)
9799
export(pecs)
100+
export(pipeline_survtoclassif_disctime)
98101
export(pipeline_survtoregr)
99102
export(plot_probregr)
100103
import(checkmate)

NEWS.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
# mlr3proba 0.6.5
2+
3+
* Add support for discrete-time survival analysis
4+
* New `PipeOp`s: `PipeOpTaskSurvClassifDiscTime`, `PipeOpPredClassifSurvDiscTime`
5+
* New pipeline: `pipeline_survtoclassif`
6+
17
# mlr3proba 0.6.4
28

39
* Add useR! 2024 tutorial
4-
* Lots of refactoring, improve code quality (thanks to @m-muecke)
10+
* Lots of refactoring, improving code quality, migration to testthat v3, etc. (thanks to @m-muecke)
511

612
# mlr3proba 0.6.3
713

R/LearnerSurvCoxPH.R

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,30 +42,20 @@ LearnerSurvCoxPH = R6Class("LearnerSurvCoxPH",
4242
pv$weights = task$weights$weight
4343
}
4444

45-
invoke(survival::coxph, formula = task$formula(), data = task$data(), .args = pv, x = TRUE)
45+
invoke(survival::coxph, formula = task$formula(), data = task$data(),
46+
.args = pv, x = TRUE)
4647
},
4748

4849
.predict = function(task) {
49-
50-
newdata = task$data(cols = task$feature_names)
51-
52-
# We move the missingness checks here manually as if any NAs are made in predictions then the
53-
# distribution object cannot be create (initialization of distr6 objects does not handle NAs)
54-
if (anyMissing(newdata)) {
55-
stopf(
56-
"Learner %s on task %s failed to predict: Missing values in new data (line(s) %s)\n",
57-
self$id, task$id,
58-
toString(which(!complete.cases(newdata)))
59-
)
60-
}
61-
50+
newdata = ordered_features(task, self)
6251
pv = self$param_set$get_values(tags = "predict")
6352

64-
# Get predicted values
53+
# Get survival predictions via `survfit`
6554
fit = invoke(survival::survfit, formula = self$model, newdata = newdata,
66-
se.fit = FALSE, .args = pv)
55+
se.fit = FALSE, .args = pv)
6756

68-
lp = predict(self$model, type = "lp", newdata = newdata)
57+
# Get linear predictors
58+
lp = invoke(predict, self$model, type = "lp", newdata = newdata)
6959

7060
.surv_return(times = fit$time, surv = t(fit$surv), lp = lp)
7161
}

R/LearnerSurvRpart.R

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,15 @@ LearnerSurvRpart = R6Class("LearnerSurvRpart",
7979
pv = insert_named(pv, list(weights = task$weights$weight))
8080
}
8181

82-
invoke(rpart::rpart,
83-
formula = task$formula(), data = task$data(),
84-
method = "exp", .args = pv)
82+
invoke(rpart::rpart, formula = task$formula(), data = task$data(),
83+
method = "exp", .args = pv)
8584
},
8685

8786
.predict = function(task) {
88-
preds = invoke(predict, object = self$model, newdata = task$data(cols = task$feature_names))
89-
list(crank = preds)
87+
newdata = ordered_features(task, self)
88+
p = invoke(predict, object = self$model, newdata = newdata)
89+
90+
list(crank = p)
9091
}
9192
)
9293
)

R/PipeOpPredClassifSurvDiscTime.R

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#' @title PipeOpPredClassifSurvDiscTime
2+
#' @name mlr_pipeops_trafopred_classifsurv_disctime
3+
#'
4+
#' @description
5+
#' Transform [PredictionClassif] to [PredictionSurv] by converting
6+
#' event probabilities of a pseudo status variable (discrete time hazards)
7+
#' to survival probabilities using the product rule (Tutz et al. 2016):
8+
#'
9+
#' \deqn{P_k = p_k\cdot ... \cdot p_1}
10+
#'
11+
#' Where:
12+
#' - We assume that continuous time is divided into time intervals
13+
#' \eqn{[0, t_1), [t_1, t_2), ..., [t_n, \infty)}
14+
#' - \eqn{P_k = P(T > t_k)} is the survival probability at time \eqn{t_k}
15+
#' - \eqn{h_k} is the discrete-time hazard (classifier prediction), i.e. the
16+
#' conditional probability for an event in the \eqn{k}-interval.
17+
#' - \eqn{p_k = 1 - h_k = P(T \ge t_k | T \ge t_{k-1})}
18+
#'
19+
#' @section Input and Output Channels:
20+
#' The input is a [PredictionClassif] and a [data.table][data.table::data.table]
21+
#' with the transformed data both generated by [PipeOpTaskSurvClassifDiscTime].
22+
#' The output is the input [PredictionClassif] transformed to a [PredictionSurv].
23+
#' Only works during prediction phase.
24+
#'
25+
#' @references
26+
#' `r format_bib("tutz_2016")`
27+
#' @family PipeOps
28+
#' @family Transformation PipeOps
29+
#' @export
30+
PipeOpPredClassifSurvDiscTime = R6Class(
31+
"PipeOpPredClassifSurvDiscTime",
32+
inherit = mlr3pipelines::PipeOp,
33+
34+
public = list(
35+
#' @description
36+
#' Creates a new instance of this [R6][R6::R6Class] class.
37+
#' @param id (character(1))\cr
38+
#' Identifier of the resulting object.
39+
initialize = function(id = "trafopred_classifsurv_disctime") {
40+
super$initialize(
41+
id = id,
42+
input = data.table(
43+
name = c("input", "transformed_data"),
44+
train = c("NULL", "data.table"),
45+
predict = c("PredictionClassif", "data.table")
46+
),
47+
output = data.table(
48+
name = "output",
49+
train = "NULL",
50+
predict = "PredictionSurv"
51+
)
52+
)
53+
}
54+
),
55+
56+
private = list(
57+
.predict = function(input) {
58+
pred = input[[1]]
59+
data = input[[2]]
60+
assert_true(!is.null(pred$prob))
61+
# probability of having the event (1) in each respective interval
62+
# is the discrete-time hazard
63+
data = cbind(data, dt_hazard = pred$prob[, "1"])
64+
65+
# From theory, convert hazards to surv as prod(1 - h(t))
66+
rows_per_id = nrow(data) / length(unique(data$id))
67+
surv = t(vapply(unique(data$id), function(unique_id) {
68+
cumprod(1 - data[data$id == unique_id, ][["dt_hazard"]])
69+
}, numeric(rows_per_id)))
70+
71+
pred_list = list()
72+
unique_end_times = sort(unique(data$tend))
73+
# coerce to distribution and crank
74+
pred_list = .surv_return(times = unique_end_times, surv = surv)
75+
76+
# select the real tend values by only selecting the last row of each id
77+
# basically a slightly more complex unique()
78+
real_tend = data$time2[seq_len(nrow(data)) %% rows_per_id == 0]
79+
80+
# select last row for every id
81+
data = as.data.table(data)
82+
id = ped_status = NULL # to fix note
83+
data = data[, .SD[.N, list(ped_status)], by = id]
84+
85+
# create prediction object
86+
p = PredictionSurv$new(
87+
row_ids = seq_row(data),
88+
crank = pred_list$crank, distr = pred_list$distr,
89+
truth = Surv(real_tend, as.integer(as.character(data$ped_status))))
90+
91+
list(p)
92+
},
93+
94+
.train = function(input) {
95+
self$state = list()
96+
list(input)
97+
}
98+
)
99+
)
100+
101+
register_pipeop("trafopred_classifsurv_disctime", PipeOpPredClassifSurvDiscTime)

0 commit comments

Comments
 (0)