Skip to content

Commit dc2832a

Browse files
authored
Merge pull request #338 from mlr-org/breslow
Breslow estimator
2 parents af918e9 + b6cc254 commit dc2832a

30 files changed

+1156
-986
lines changed

DESCRIPTION

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: mlr3proba
22
Title: Probabilistic Supervised Learning for 'mlr3'
3-
Version: 0.5.6
3+
Version: 0.5.7
44
Authors@R:
55
c(person(given = "Raphael",
66
family = "Sonabend",
@@ -117,6 +117,7 @@ Collate:
117117
'MeasureSurvUnoTNR.R'
118118
'MeasureSurvUnoTPR.R'
119119
'MeasureSurvXuR2.R'
120+
'PipeOpBreslow.R'
120121
'PipeOpCrankCompositor.R'
121122
'PipeOpDistrCompositor.R'
122123
'PipeOpTransformer.R'
@@ -146,6 +147,7 @@ Collate:
146147
'assertions.R'
147148
'autoplot.R'
148149
'bibentries.R'
150+
'breslow.R'
149151
'cindex.R'
150152
'data.R'
151153
'helpers.R'

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ export(MeasureSurvUnoAUC)
6969
export(MeasureSurvUnoTNR)
7070
export(MeasureSurvUnoTPR)
7171
export(MeasureSurvXuR2)
72+
export(PipeOpBreslow)
7273
export(PipeOpCrankCompositor)
7374
export(PipeOpDistrCompositor)
7475
export(PipeOpPredRegrSurv)
@@ -91,6 +92,7 @@ export(as_prediction_surv)
9192
export(as_task_dens)
9293
export(as_task_surv)
9394
export(assert_surv)
95+
export(breslow)
9496
export(crankcompositor)
9597
export(distrcompositor)
9698
export(pecs)

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# mlr3proba 0.5.7
2+
3+
* Add `breslow` function for estimating the cumulative baseline hazard of proportional hazard models
4+
* Add `PipeOpBreslow` to wrap a survival learner and generate `distr` predictions from `lp` predictions
5+
* Add option `breslow` estimator option in `distrcompositor`
6+
17
# mlr3proba 0.5.6
28

39
* Add `extend_quantile` to `autoplot.PredictionSurv` for `type = "dcalib"`, which imputes NAs with the maximum observed survival time

R/PipeOpBreslow.R

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
#' @title Wrap a learner into a PipeOp with survival predictions estimated by the Breslow estimator
2+
#' @name mlr_pipeops_compose_breslow_distr
3+
#' @template param_pipelines
4+
#' @description
5+
#' Composes a survival distribution (`distr`) using the linear predictor
6+
#' predictions (`lp`) from a given [LearnerSurv] during training and prediction,
7+
#' utilizing the [breslow estimator][breslow]. The specified `learner` must be
8+
#' capable of generating `lp`-type predictions (e.g., a Cox-type model).
9+
#'
10+
#' @section Dictionary:
11+
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
12+
#' [Dictionary][mlr3misc::Dictionary] [mlr_pipeops][mlr3pipelines::mlr_pipeops]
13+
#' or with the associated sugar function [po()][mlr3pipelines::po]:
14+
#' ```
15+
#' PipeOpBreslow$new(learner)
16+
#' mlr_pipeops$get("breslowcompose", learner)
17+
#' po("breslowcompose", learner, breslow.overwrite = TRUE)
18+
#' ```
19+
#'
20+
#' @section Input and Output Channels:
21+
#' [PipeOpBreslow] is like a [LearnerSurv].
22+
#' It has one input channel, named `input` that takes a [TaskSurv] during training
23+
#' and another [TaskSurv] during prediction.
24+
#' [PipeOpBreslow] has one output channel named `output`, producing `NULL` during
25+
#' training and a [PredictionSurv] during prediction.
26+
#'
27+
#' @section State:
28+
#' The `$state` slot stores the `times` and `status` survival target variables of
29+
#' the train [TaskSurv] as well as the `lp` predictions on the train set.
30+
#'
31+
#' @section Parameters:
32+
#' The parameters are:
33+
#' * `breslow.overwrite` :: `logical(1)` \cr
34+
#' If `FALSE` (default) then the compositor does nothing and returns the
35+
#' input `learner`'s [PredictionSurv].
36+
#' If `TRUE` or in the case that the input `learner` doesn't have `distr`
37+
#' predictions, then the `distr` is overwritten with the `distr` composed
38+
#' from `lp` and the train set information using [breslow].
39+
#' This is useful for changing the prediction `distr` from one model form to
40+
#' another.
41+
#' @seealso [pipeline_distrcompositor]
42+
#' @export
43+
#' @family survival compositors
44+
#' @examples
45+
#' \dontrun{
46+
#' if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
47+
#' library(mlr3)
48+
#' library(mlr3pipelines)
49+
#' task = tsk("rats")
50+
#' part = partition(task, ratio = 0.8)
51+
#' train_task = task$clone()$filter(part$train)
52+
#' test_task = task$clone()$filter(part$test)
53+
#'
54+
#' learner = lrn("surv.coxph") # learner with lp predictions
55+
#' b = po("breslowcompose", learner = learner, breslow.overwrite = TRUE)
56+
#'
57+
#' b$train(list(train_task))
58+
#' p = b$predict(list(test_task))[[1L]]
59+
#' }
60+
#' }
61+
PipeOpBreslow = R6Class("PipeOpBreslow",
62+
inherit = mlr3pipelines::PipeOp,
63+
public = list(
64+
#' @description
65+
#' Creates a new instance of this [R6][R6::R6Class] class.
66+
#' @param learner ([LearnerSurv])\cr
67+
#' Survival learner which must provide `lp`-type predictions
68+
#' @param id (character(1))\cr
69+
#' Identifier of the resulting object. If `NULL` (default), it will be set
70+
#' as the `id` of the input `learner`.
71+
initialize = function(learner, id = NULL, param_vals = list()) {
72+
assert_learner(learner, task_type = "surv")
73+
if ("lp" %nin% learner$predict_types) {
74+
stopf("Learner %s must provide lp predictions", learner$id)
75+
}
76+
77+
# id of the PipeOp is the id of the learner
78+
private$.learner = as_learner(learner, clone = TRUE)
79+
private$.learner$param_set$set_id = ""
80+
id = id %??% private$.learner$id
81+
82+
# define `breslow.overwrite` parameter
83+
private$.breslow_ps = ps(
84+
overwrite = p_lgl(default = FALSE, tags = c("predict", "required"))
85+
)
86+
private$.breslow_ps$values = list(overwrite = FALSE)
87+
private$.breslow_ps$set_id = "breslow"
88+
89+
super$initialize(
90+
id = id,
91+
param_set = alist(private$.breslow_ps, private$.learner$param_set),
92+
param_vals = param_vals,
93+
input = data.table(name = "input", train = "TaskSurv", predict = "TaskSurv"),
94+
output = data.table(name = "output", train = "NULL", predict = "PredictionSurv"),
95+
packages = private$.learner$packages
96+
)
97+
}
98+
),
99+
100+
active = list(
101+
#' @field learner \cr
102+
#' The input survival learner.
103+
learner = function(rhs) {
104+
assert_ro_binding(rhs)
105+
private$.learner
106+
}
107+
),
108+
109+
private = list(
110+
.train = function(inputs) {
111+
task = inputs[[1]]
112+
learner = private$.learner
113+
114+
# train learner
115+
learner$train(task)
116+
117+
# predictions on the train set
118+
p = learner$predict(task)
119+
120+
# Breslow works only with non-NA `lp` predictions
121+
if (anyMissing(p$lp)) {
122+
stopf("Missing lp predictions")
123+
}
124+
125+
# keep the training data that Breslow estimator needs
126+
self$state = list(
127+
times = task$times(),
128+
status = task$status(),
129+
lp_train = p$lp
130+
)
131+
132+
list(NULL)
133+
},
134+
135+
.predict = function(inputs) {
136+
task = inputs[[1]]
137+
learner = private$.learner
138+
139+
if (is.null(learner$model)) {
140+
stopf("Cannot predict, Learner '%s' has not been trained yet", learner$id)
141+
}
142+
143+
# predictions on the test set
144+
p = learner$predict(task)
145+
146+
pv = self$param_set$get_values(tags = "predict")
147+
overwrite = pv$breslow.overwrite
148+
149+
# If learner predicts `distr` and overwrite is FALSE don't use breslow
150+
if ("distr" %in% learner$predict_types & !overwrite) {
151+
pred = list(p)
152+
} else {
153+
# Breslow works only with non-NA `lp` predictions
154+
if (anyMissing(p$lp)) {
155+
stopf("Missing lp predictions!")
156+
}
157+
158+
distr = breslow(
159+
times = self$state$times,
160+
status = self$state$status,
161+
lp_train = self$state$lp_train,
162+
lp_test = p$lp
163+
)
164+
165+
pred = list(PredictionSurv$new(
166+
row_ids = p$row_ids,
167+
truth = p$truth,
168+
crank = p$crank,
169+
lp = p$lp,
170+
distr = distr
171+
))
172+
}
173+
174+
pred
175+
},
176+
177+
.breslow_ps = NULL,
178+
.learner = NULL
179+
)
180+
)

0 commit comments

Comments
 (0)