Skip to content

Commit 8930d7f

Browse files
authored
Merge pull request #438 from mlr-org/cmprsk_improvements
Cmprsk improvements
2 parents 9bc5698 + d2436a9 commit 8930d7f

18 files changed

+308
-139
lines changed

DESCRIPTION

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: mlr3proba
22
Title: Probabilistic Supervised Learning for 'mlr3'
3-
Version: 0.8.1
3+
Version: 0.8.2
44
Authors@R: c(
55
person("Raphael", "Sonabend", , "raphaelsonabend@gmail.com", role = "aut",
66
comment = c(ORCID = "0000-0001-9225-4654")),
@@ -22,9 +22,8 @@ Authors@R: c(
2222
person("Markus", "Goeswein", , "markus.goeswein@outlook.de", role = "ctb")
2323
)
2424
Description: Provides extensions for probabilistic supervised learning for
25-
'mlr3'. This includes extending the regression task to probabilistic
26-
and interval regression, adding a survival task, and other specialized
27-
models, predictions, and measures.
25+
'mlr3'. This currently includes survival analysis, competing risks and
26+
density estimation.
2827
License: LGPL-3
2928
URL: https://mlr3proba.mlr-org.com, https://github.com/mlr-org/mlr3proba
3029
BugReports: https://github.com/mlr-org/mlr3proba/issues
@@ -46,6 +45,7 @@ Suggests:
4645
abind,
4746
coxed,
4847
GGally,
48+
glmnet,
4949
knitr,
5050
lgr,
5151
lifecycle,
@@ -60,14 +60,14 @@ Suggests:
6060
set6 (>= 0.2.6),
6161
simsurv,
6262
survAUC,
63-
testthat (>= 3.0.0),
64-
glmnet
63+
testthat (>= 3.0.0)
6564
LinkingTo:
6665
Rcpp
6766
Remotes:
67+
jkropko/coxed@bc92e25,
6868
xoopR/distr6,
6969
xoopR/param6,
70-
xoopR/set6,
70+
xoopR/set6
7171
ByteCompile: true
7272
Config/testthat/edition: 3
7373
Encoding: UTF-8

NEWS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
# mlr3proba 0.8.1.9000
1+
# mlr3proba 0.8.2
22

3+
* fix: `coxed` package was removed from CRAN so now we install the latest working CRAN version (`0.3.3`) from GitHub
4+
* feat: event-weighted mean AUC(t) as default score in `msr("cmprsk.auc")`
35
* feat: The pipeops `PipeOpTaskSurvRegrPEM` and `PipeOpTaskSurvClassifDiscTime` now apply the data transformation on an internal validation task if it exists. This enables the use of e.g `xgboost` regression or classification learners with `early_stopping` on the corresponding pipelines.
46

57
# mlr3proba 0.8.1

R/LearnerCompRisksAalenJohansen.R

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77
#'
88
#' This learner estimates the Cumulative Incidence Function (CIF) for competing
99
#' risks using the empirical Aalen-Johansen (AJ) estimator.
10-
#' The probability of transitioning to each competing event is computed via the
11-
#' [survfit][survival::survfit.formula()] function.
10+
#'
11+
#' Transition probabilities to each event are computed from the training data via
12+
#' the [survfit][survival::survfit.formula()] function and predictions are made
13+
#' at all unique times (both events and censoring) observed in the training set.
1214
#'
1315
#' @references
1416
#' `r format_bib("aalen_1978")`
1517
#'
18+
#' @templateVar msr_id all
19+
#' @template example_cmprsk
1620
#' @export
1721
LearnerCompRisksAalenJohansen = R6Class("LearnerCompRisksAalenJohansen",
1822
inherit = LearnerCompRisks,
@@ -54,18 +58,18 @@ LearnerCompRisksAalenJohansen = R6Class("LearnerCompRisksAalenJohansen",
5458

5559
times = self$model$time # unique train set time points
5660
n_obs = task$nrow # number of test observations
57-
CIF = stats::setNames(vector("list", ncol(trans_mat)), colnames(trans_mat))
61+
cif_list = stats::setNames(vector("list", ncol(trans_mat)), colnames(trans_mat))
5862

59-
for (i in seq_along(CIF)) {
60-
CIF[[i]] = matrix(
63+
for (i in seq_along(cif_list)) {
64+
cif_list[[i]] = matrix(
6165
data = rep(trans_mat[, i], times = n_obs),
6266
nrow = n_obs,
6367
byrow = TRUE,
6468
dimnames = list(NULL, times)
6569
)
6670
}
6771

68-
list(cif = CIF)
72+
list(cif = cif_list)
6973
}
7074
)
7175
)

R/MeasureCompRisksAUC.R

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,42 @@
55
#' @aliases MeasureCompRisksAUC mlr_measures_cmprsk.auc
66
#'
77
#' @description
8-
#' Calculates the cause-specific ROC-AUC(t) at a **specific time point**,
9-
#' see Blanche et al. (2013).
10-
#' Can also return the mean AUC(t) over all competing causes.
8+
#' Calculates the cause-specific time-dependent ROC-AUC at a **specific time point**,
9+
#' as described in Blanche et al. (2013).
10+
#'
11+
#' By default, this measure returns a **cause-independent AUC(t)** score,
12+
#' calculated as a weighted average of the cause-specific AUCs.
13+
#' The weights correspond to the relative event frequencies of each cause,
14+
#' following Equation (7) in Heyard et al. (2020).
1115
#'
1216
#' @details
1317
#' Calls [riskRegression::Score()] with:
1418
#' - `metric = "auc"`
1519
#' - `cens.method = "ipcw"`
1620
#' - `cens.model = "km"`
1721
#'
18-
#' Note that the IPC weights (estimated via the Kaplan-Meier) are calculated
19-
#' using the test data.
22+
#' Notes on the `riskRegression` implementation:
23+
#' 1. IPCW weights are estimated using the **test data only**.
24+
#' 2. No extrapolation is supported: if `time_horizon` exceeds the maximum observed
25+
#' time on the test data, an error is thrown.
26+
#' 3. The choice of `time_horizon` is critical: if, at that time, no events of a
27+
#' given cause have occurred and all predicted CIFs are zero, `riskRegression`
28+
#' will return `NaN` for that cause-specific AUC (and subsequently for the
29+
#' summary AUC).
2030
#'
2131
#' @section Parameter details:
22-
#' - `cause` (`numeric(1)`)\cr
23-
#' Integer number indicating which cause to use (Default: `1`).
24-
#' If `"mean"`, then the mean AUC(t) over all causes is returned.
32+
#' - `cause` (`numeric(1)|"mean"`)\cr
33+
#' Integer number indicating which cause to use.
34+
#' Default value is `"mean"` which returns a weighted mean of the cause-specific AUCs.
2535
#' - `time_horizon` (`numeric(1)`)\cr
2636
#' Single time point at which to return the score.
27-
#' If `NULL`, we issue a warning and the median time from the test set is used.
37+
#' If `NULL`, the **median time point** from the test set is used.
2838
#'
2939
#' @references
30-
#' `r format_bib("blanche_2013")`
31-
#'
32-
#' @examplesIf mlr3misc::require_namespaces(c("riskRegression"), quietly = TRUE)
33-
#' t = tsk("pbc")
34-
#' l = lrn("cmprsk.aalen")
35-
#' p = l$train(t)$predict(t)
36-
#'
37-
#' p$score(msr("cmprsk.auc", time_horizon = 42))
40+
#' `r format_bib("blanche_2013", "heyard_2020")`
3841
#'
42+
#' @templateVar msr_id auc
43+
#' @template example_cmprsk
3944
#' @export
4045
MeasureCompRisksAUC = R6Class(
4146
"MeasureCompRisksAUC",
@@ -45,18 +50,16 @@ MeasureCompRisksAUC = R6Class(
4550
#' Creates a new instance of this [R6][R6::R6Class] class.
4651
initialize = function() {
4752
param_set = ps(
48-
cause = p_int(lower = 1, default = 1, special_vals = list("mean")),
53+
cause = p_int(lower = 1, init = "mean", special_vals = list("mean")),
4954
time_horizon = p_dbl(lower = 0, default = NULL, special_vals = list(NULL))
5055
)
5156

52-
param_set$set_values(cause = 1)
53-
5457
super$initialize(
5558
id = "cmprsk.auc",
5659
param_set = param_set,
5760
range = c(0, 1),
5861
minimize = FALSE,
59-
#properties = "requires_task", (only if we want `cen.model = cox`)
62+
properties = "na_score",
6063
packages = "riskRegression",
6164
label = "Blanche's Time-dependent IPCW ROC-AUC score",
6265
man = "mlr3proba::mlr_measures_cmprsk.auc"
@@ -68,33 +71,37 @@ MeasureCompRisksAUC = R6Class(
6871
.score = function(prediction, task, ...) {
6972
pv = self$param_set$values
7073

71-
# data with (time, event) columns for IPCW calculation
72-
# uses test set data as it needs to match predicted CIF rows/observations
73-
data = data.table(time = prediction$truth[, 1L], event = prediction$truth[, 2L])
74-
lhs = "Hist(time, event)"
75-
form = formulate(lhs, rhs = "1", env = getNamespace("prodlim"))
76-
77-
# single time point for AUC or median time
78-
if (is.null(pv$time)) {
79-
# TODO: add the warning again when this is not the default measure
80-
# warning("No time horizon specified. We use median time from the test set")
81-
time_horizon = median(data$time)
74+
# Prepare test set data (for IPCW)
75+
# uses test set observations as it needs to match exactly the number of
76+
# rows (observations) in the predicted CIF matrix
77+
data = data.table(
78+
time = prediction$truth[, 1L],
79+
event = prediction$truth[, 2L]
80+
)
81+
form = formulate(lhs = "Hist(time, event)", rhs = "1", env = getNamespace("prodlim"))
82+
83+
# Define evaluation time (single time point for AUC)
84+
time_horizon = if (is.null(pv$time_horizon)) {
85+
median(data$time)
8286
} else {
83-
time_horizon = assert_number(pv$time_horizon, lower = 0, finite = TRUE, na.ok = FALSE)
87+
assert_number(pv$time_horizon, lower = 0, finite = TRUE, na.ok = FALSE)
8488
}
8589

8690
# list of predicted CIF matrices
87-
cif = prediction$cif
91+
cif_list = prediction$cif
92+
causes = names(cif_list)
8893

8994
cause = pv$cause
90-
if (test_integerish(cause)) {
95+
if (test_int(cause)) {
96+
cause = as.character(cause)
97+
9198
# check if cause exists
92-
if (cause %nin% names(cif)) {
93-
stopf("Given cause (%i) is not included in the CIF causes", cause)
99+
if (cause %nin% causes) {
100+
stopf("Invalid cause. Use one of: %s", paste(causes, collapse = ", "))
94101
}
95102

96103
# get cause-specific CIF
97-
cif_mat = cif[[as.character(cause)]]
104+
cif_mat = cif_list[[cause]]
98105

99106
# get CIF on the time horizon
100107
mat = .interp_cif(cif_mat, eval_times = time_horizon)
@@ -112,10 +119,10 @@ MeasureCompRisksAUC = R6Class(
112119
times = NULL # fix: no global binding
113120
res$AUC$score[times == time_horizon][["AUC"]]
114121
} else {
115-
# iterate through cause-specific CIFs, get AUC(t), return the mean
116-
AUCs = sapply(names(cif), function(cause) {
122+
# iterate through cause-specific CIFs, get AUC(t)
123+
aucs = vapply(causes, function(cause) {
117124
# get cause-specific CIF
118-
cif_mat = cif[[cause]]
125+
cif_mat = cif_list[[cause]]
119126

120127
# get CIF on the time horizon
121128
mat = .interp_cif(cif_mat, eval_times = time_horizon)
@@ -132,10 +139,11 @@ MeasureCompRisksAUC = R6Class(
132139

133140
times = NULL # fix: no global binding
134141
res$AUC$score[times == time_horizon][["AUC"]]
135-
})
142+
}, numeric(1L))
136143

137-
# return mean (weighted?)
138-
mean(AUCs)
144+
event = data[event != 0, event] # remove censored obs (if they exist)
145+
w = prop.table(table(event)) # observed proportions per cause
146+
sum(w[names(aucs)] * aucs) # weighted mean
139147
}
140148
}
141149
)

R/PipeOpPredClassifSurvIPCW.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@
2626
#' to a [PredictionSurv].
2727
#' Each input classification probability prediction corresponds to the
2828
#' probability of having the event up to the specified cutoff time
29-
#' \eqn{\hat{\pi}(\bold{X}_i) = P(T_i < \tau|\bold{X}_i)},
29+
#' \eqn{\hat{\pi}(\textbf{X}_i) = P(T_i < \tau|\textbf{X}_i)},
3030
#' see Vock et al. (2016) and [PipeOpTaskSurvClassifIPCW].
3131
#' Therefore, these predictions serve as **continuous risk scores** that can be
3232
#' directly interpreted as `crank` predictions in the right-censored survival
3333
#' setting. We also map them to the survival distribution prediction `distr`,
3434
#' at the specified cutoff time point \eqn{\tau}, i.e. as
35-
#' \eqn{S_i(\tau) = 1 - \hat{\pi}(\bold{X}_i)}.
35+
#' \eqn{S_i(\tau) = 1 - \hat{\pi}(\textbf{X}_i)}.
3636
#' Survival measures that use the survival distribution (eg [ISBS][mlr_measures_surv.brier])
37-
#' should be evaluated exactly at the cutoff time point \eqn{\tau}, see example.
37+
#' should be evaluated exactly at the cutoff time point \eqn{\tau}.
3838
#'
3939
#' @references
4040
#' `r format_bib("vock_2016")`

R/TaskGeneratorCoxed.R

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
#' @description
66
#' A [TaskGenerator][mlr3::TaskGenerator] calling [coxed::sim.survdata()].
77
#'
8-
#' This generator creates a survival dataset using \CRANpkg{coxed}, and exposes
9-
#' some parameters from the `sim.survdata()` function.
8+
#' This generator creates a survival dataset using `coxed`, and exposes
9+
#' some parameters from the [coxed::sim.survdata()] function.
1010
#' We don't include the parameters `X` (user-specified variables), `covariate`,
1111
#' `low`, `high`, `compare`, `beta` and `hazard.fun` for this generator.
1212
#' The latter means that no user-specified hazard function can be used and the
@@ -18,7 +18,7 @@
1818
#' @template seealso_task_generator
1919
#' @references
2020
#' `r format_bib("harden_2019")`
21-
#' @examplesIf mlr3misc::require_namespaces(c("coxed"), quietly = TRUE)
21+
#' @examplesIf mlr3misc::require_namespaces("coxed", quietly = TRUE)
2222
#' library(mlr3)
2323
#'
2424
#' # time horizon = 365 days, censoring proportion = 60%, 6 covariates normally
@@ -72,8 +72,7 @@ TaskGeneratorCoxed = R6Class("TaskGeneratorCoxed",
7272
.generate = function(n) {
7373
require_namespaces("coxed")
7474

75-
pv = self$param_set$values
76-
data = invoke(coxed::sim.survdata, N = n, .args = pv)[[1]]
75+
data = invoke(coxed::sim.survdata, N = n, .args = self$param_set$values)[[1]]
7776
data = map_at(data, "failed", as.integer)
7877

7978
TaskSurv$new(id = self$id, backend = data, time = "y",

R/bibentries.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,5 +449,25 @@ bibentries = c(
449449
pages = "299--321",
450450
url = "https://doi.org/10.1177/1471082X17748083",
451451
year = "2018"
452+
),
453+
heyard_2020 = bibentry("article",
454+
author = "Heyard, Rachel and Timsit, Jean-Francois and Held, Leonhard",
455+
title = "Validation of discrete time-to-event prediction models in the presence of competing risks",
456+
journal = "Biometrical Journal",
457+
volume = "62",
458+
number = "3",
459+
pages = "643--657",
460+
year = "2020",
461+
url = "https://doi.org/10.1002/BIMJ.201800293"
462+
),
463+
schoop_2011 = bibentry("article",
464+
author = "Schoop, Roland and Beyersmann, Jan and Schumacher, Martin and Binder, Harald",
465+
title = "Quantifying the predictive accuracy of time-to-event models in the presence of competing risks",
466+
journal = "Biometrical Journal",
467+
volume = "53",
468+
number = "1",
469+
pages = "88--112",
470+
year = "2011",
471+
url = "https://doi.org/10.1002/BIMJ.201000073"
452472
)
453473
)

0 commit comments

Comments
 (0)