Cmprsk improvements #438

Merged
merged 32 commits into from
Aug 11, 2025
32 commits
3da27a3
document with noRd helper functions
bblodfon Jul 7, 2025
3409002
refactor a bit MeasureCompRisksAUC and add better doc
bblodfon Jul 7, 2025
548ba0c
add doc about which time points are used for prediction in AJ estimator
bblodfon Jul 9, 2025
35f9a70
better param check (cause)
bblodfon Jul 9, 2025
b1ffd8f
name refactor
bblodfon Jul 9, 2025
91add54
add example template for cmprsk measures
bblodfon Jul 9, 2025
fec2601
update examples
bblodfon Jul 9, 2025
f6276b4
some refactoring, cause = "mean" by default
bblodfon Jul 9, 2025
3d42d9d
return event-weighted mean AUC(t) as default
bblodfon Jul 10, 2025
ba6574a
improve sanity task for cmprsk
bblodfon Jul 10, 2025
0dfa1d8
add new paper
bblodfon Jul 10, 2025
66a6c9d
increase N for AJ autotest
bblodfon Jul 10, 2025
ed8dc32
better doc about why AUC(t) can return NaN
bblodfon Jul 10, 2025
f1e450b
improve example template
bblodfon Jul 10, 2025
7322149
update doc
bblodfon Jul 10, 2025
cde5fa8
Merge branch 'main' into cmprsk_improvements
bblodfon Jul 10, 2025
9705794
update version
bblodfon Jul 10, 2025
b6d58b5
replace non-ASCII character
bblodfon Jul 10, 2025
2924f6f
tidy up DESCRIPTION
bblodfon Jul 10, 2025
a79fc8d
fix: coxed out of CRAN
bblodfon Jul 10, 2025
be8bd00
update description doc
bblodfon Jul 10, 2025
910f9eb
remove coxed from test
bblodfon Jul 10, 2025
d85f309
update doc
bblodfon Jul 10, 2025
8a44c5b
get function of non-CRAN coxed package from namespace reflection
bblodfon Jul 11, 2025
4429f47
use require_namespaces
bblodfon Jul 11, 2025
5d463c7
add coxed in Suggests, refactor class code
bblodfon Jul 11, 2025
76dcbeb
increase task sample size in pipelines test
bblodfon Jul 11, 2025
ecd3478
add specific commit of coxed package to installs version 0.3.3 so tha…
bblodfon Jul 11, 2025
41d783c
update news
bblodfon Jul 11, 2025
e5df62d
add Prediction Error paper for CRs
bblodfon Jul 22, 2025
e462c20
relax autotest C-index requirement
bblodfon Aug 11, 2025
d2436a9
fix doc
bblodfon Aug 11, 2025
14 changes: 7 additions & 7 deletions DESCRIPTION
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.8.1
Version: 0.8.2
Authors@R: c(
person("Raphael", "Sonabend", , "raphaelsonabend@gmail.com", role = "aut",
comment = c(ORCID = "0000-0001-9225-4654")),
@@ -22,9 +22,8 @@ Authors@R: c(
person("Markus", "Goeswein", , "markus.goeswein@outlook.de", role = "ctb")
)
Description: Provides extensions for probabilistic supervised learning for
'mlr3'. This includes extending the regression task to probabilistic
and interval regression, adding a survival task, and other specialized
models, predictions, and measures.
'mlr3'. This currently includes survival analysis, competing risks and
density estimation.
License: LGPL-3
URL: https://mlr3proba.mlr-org.com, https://github.com/mlr-org/mlr3proba
BugReports: https://github.com/mlr-org/mlr3proba/issues
@@ -46,6 +45,7 @@ Suggests:
abind,
coxed,
GGally,
glmnet,
knitr,
lgr,
lifecycle,
@@ -60,14 +60,14 @@ Suggests:
set6 (>= 0.2.6),
simsurv,
survAUC,
testthat (>= 3.0.0),
glmnet
testthat (>= 3.0.0)
LinkingTo:
Rcpp
Remotes:
jkropko/coxed@bc92e25,
xoopR/distr6,
xoopR/param6,
xoopR/set6,
xoopR/set6
ByteCompile: true
Config/testthat/edition: 3
Encoding: UTF-8
4 changes: 3 additions & 1 deletion NEWS.md
@@ -1,5 +1,7 @@
# mlr3proba 0.8.1.9000
# mlr3proba 0.8.2

* fix: `coxed` package was removed from CRAN so now we install the latest working CRAN version (`0.3.3`) from GitHub
* feat: event-weighted mean AUC(t) as default score in `msr("cmprsk.auc")`
* feat: The pipeops `PipeOpTaskSurvRegrPEM` and `PipeOpTaskSurvClassifDiscTime` now apply the data transformation on an internal validation task if it exists. This enables the use of e.g `xgboost` regression or classification learners with `early_stopping` on the corresponding pipelines.

# mlr3proba 0.8.1
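A note on the new default score: the "event-weighted mean AUC(t)" from the NEWS entry can be written as below. This is a sketch that follows the weighting code in `R/MeasureCompRisksAUC.R` from this PR. The symbols $K$, $w_k$ and $n_k$ are notation assumed here for illustration; they do not appear in the package.

```latex
% Assumed notation (not from the package):
%   K       number of competing causes
%   AUC_k   cause-specific AUC at the time horizon t
%   n_k     number of test-set observations whose event is cause k
%           (censored observations are excluded before counting)
\[
\mathrm{AUC}(t) \;=\; \sum_{k=1}^{K} w_k \,\mathrm{AUC}_k(t),
\qquad
w_k \;=\; \frac{n_k}{\sum_{j=1}^{K} n_j}
\]
```

Because the weights sum to one, the summary score stays in $[0, 1]$, and it becomes `NaN` as soon as any cause-specific AUC is `NaN`.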
16 changes: 10 additions & 6 deletions R/LearnerCompRisksAalenJohansen.R
@@ -7,12 +7,16 @@
#'
#' This learner estimates the Cumulative Incidence Function (CIF) for competing
#' risks using the empirical Aalen-Johansen (AJ) estimator.
#' The probability of transitioning to each competing event is computed via the
#' [survfit][survival::survfit.formula()] function.
#'
#' Transition probabilities to each event are computed from the training data via
#' the [survfit][survival::survfit.formula()] function and predictions are made
#' at all unique times (both events and censoring) observed in the training set.
#'
#' @references
#' `r format_bib("aalen_1978")`
#'
#' @templateVar msr_id all
#' @template example_cmprsk
#' @export
LearnerCompRisksAalenJohansen = R6Class("LearnerCompRisksAalenJohansen",
inherit = LearnerCompRisks,
@@ -54,18 +58,18 @@ LearnerCompRisksAalenJohansen = R6Class("LearnerCompRisksAalenJohansen",

times = self$model$time # unique train set time points
n_obs = task$nrow # number of test observations
CIF = stats::setNames(vector("list", ncol(trans_mat)), colnames(trans_mat))
cif_list = stats::setNames(vector("list", ncol(trans_mat)), colnames(trans_mat))

for (i in seq_along(CIF)) {
CIF[[i]] = matrix(
for (i in seq_along(cif_list)) {
cif_list[[i]] = matrix(
data = rep(trans_mat[, i], times = n_obs),
nrow = n_obs,
byrow = TRUE,
dimnames = list(NULL, times)
)
}

list(cif = CIF)
list(cif = cif_list)
}
)
)
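The predict step above repeats the same population-level CIF curve for every test observation, since the Aalen-Johansen estimator uses no covariates. A minimal base-R sketch of that broadcasting follows. The values in `trans_mat`, the `times` vector and `n_obs` are made-up toy inputs, not output from `survfit`.

```r
# Toy stand-in for the fitted transition probabilities (values are made up):
# rows = unique training time points, columns = competing causes
times = c(1, 2, 5)
trans_mat = cbind("1" = c(0.10, 0.20, 0.30), "2" = c(0.05, 0.15, 0.25))
n_obs = 4 # hypothetical number of test observations

# one empty slot per cause, named by cause
cif_list = stats::setNames(vector("list", ncol(trans_mat)), colnames(trans_mat))

for (i in seq_along(cif_list)) {
  # repeat the cause-specific curve n_obs times and fill row by row,
  # so every row (observation) holds the same CIF over the time points
  cif_list[[i]] = matrix(
    data = rep(trans_mat[, i], times = n_obs),
    nrow = n_obs,
    byrow = TRUE,
    dimnames = list(NULL, times)
  )
}

str(cif_list)
```

Each list element is an `n_obs` x `length(times)` matrix whose rows are identical, which matches the shape of the `cif` list returned by the learner's predict method in the diff.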
96 changes: 52 additions & 44 deletions R/MeasureCompRisksAUC.R
@@ -5,37 +5,42 @@
#' @aliases MeasureCompRisksAUC mlr_measures_cmprsk.auc
#'
#' @description
#' Calculates the cause-specific ROC-AUC(t) at a **specific time point**,
#' see Blanche et al. (2013).
#' Can also return the mean AUC(t) over all competing causes.
#' Calculates the cause-specific time-dependent ROC-AUC at a **specific time point**,
#' as described in Blanche et al. (2013).
#'
#' By default, this measure returns a **cause-independent AUC(t)** score,
#' calculated as a weighted average of the cause-specific AUCs.
#' The weights correspond to the relative event frequencies of each cause,
#' following Equation (7) in Heyard et al. (2020).
#'
#' @details
#' Calls [riskRegression::Score()] with:
#' - `metric = "auc"`
#' - `cens.method = "ipcw"`
#' - `cens.model = "km"`
#'
#' Note that the IPC weights (estimated via the Kaplan-Meier) are calculated
#' using the test data.
#' Notes on the `riskRegression` implementation:
#' 1. IPCW weights are estimated using the **test data only**.
#' 2. No extrapolation is supported: if `time_horizon` exceeds the maximum observed
#' time on the test data, an error is thrown.
#' 3. The choice of `time_horizon` is critical: if, at that time, no events of a
#' given cause have occurred and all predicted CIFs are zero, `riskRegression`
#' will return `NaN` for that cause-specific AUC (and subsequently for the
#' summary AUC).
#'
#' @section Parameter details:
#' - `cause` (`numeric(1)`)\cr
#' Integer number indicating which cause to use (Default: `1`).
#' If `"mean"`, then the mean AUC(t) over all causes is returned.
#' - `cause` (`numeric(1)|"mean"`)\cr
#' Integer number indicating which cause to use.
#' Default value is `"mean"` which returns a weighted mean of the cause-specific AUCs.
#' - `time_horizon` (`numeric(1)`)\cr
#' Single time point at which to return the score.
#' If `NULL`, we issue a warning and the median time from the test set is used.
#' If `NULL`, the **median time point** from the test set is used.
#'
#' @references
#' `r format_bib("blanche_2013")`
#'
#' @examplesIf mlr3misc::require_namespaces(c("riskRegression"), quietly = TRUE)
#' t = tsk("pbc")
#' l = lrn("cmprsk.aalen")
#' p = l$train(t)$predict(t)
#'
#' p$score(msr("cmprsk.auc", time_horizon = 42))
#' `r format_bib("blanche_2013", "heyard_2020")`
#'
#' @templateVar msr_id auc
#' @template example_cmprsk
#' @export
MeasureCompRisksAUC = R6Class(
"MeasureCompRisksAUC",
@@ -45,18 +50,16 @@ MeasureCompRisksAUC = R6Class(
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(
cause = p_int(lower = 1, default = 1, special_vals = list("mean")),
cause = p_int(lower = 1, init = "mean", special_vals = list("mean")),
time_horizon = p_dbl(lower = 0, default = NULL, special_vals = list(NULL))
)

param_set$set_values(cause = 1)

super$initialize(
id = "cmprsk.auc",
param_set = param_set,
range = c(0, 1),
minimize = FALSE,
#properties = "requires_task", (only if we want `cen.model = cox`)
properties = "na_score",
packages = "riskRegression",
label = "Blanche's Time-dependent IPCW ROC-AUC score",
man = "mlr3proba::mlr_measures_cmprsk.auc"
@@ -68,33 +71,37 @@ MeasureCompRisksAUC = R6Class(
.score = function(prediction, task, ...) {
pv = self$param_set$values

# data with (time, event) columns for IPCW calculation
# uses test set data as it needs to match predicted CIF rows/observations
data = data.table(time = prediction$truth[, 1L], event = prediction$truth[, 2L])
lhs = "Hist(time, event)"
form = formulate(lhs, rhs = "1", env = getNamespace("prodlim"))

# single time point for AUC or median time
if (is.null(pv$time)) {
# TODO: add the warning again when this is not the default measure
# warning("No time horizon specified. We use median time from the test set")
time_horizon = median(data$time)
# Prepare test set data (for IPCW)
# uses test set observations as it needs to match exactly the number of
# rows (observations) in the predicted CIF matrix
data = data.table(
time = prediction$truth[, 1L],
event = prediction$truth[, 2L]
)
form = formulate(lhs = "Hist(time, event)", rhs = "1", env = getNamespace("prodlim"))

# Define evaluation time (single time point for AUC)
time_horizon = if (is.null(pv$time_horizon)) {
median(data$time)
} else {
time_horizon = assert_number(pv$time_horizon, lower = 0, finite = TRUE, na.ok = FALSE)
assert_number(pv$time_horizon, lower = 0, finite = TRUE, na.ok = FALSE)
}

# list of predicted CIF matrices
cif = prediction$cif
cif_list = prediction$cif
causes = names(cif_list)

cause = pv$cause
if (test_integerish(cause)) {
if (test_int(cause)) {
cause = as.character(cause)

# check if cause exists
if (cause %nin% names(cif)) {
stopf("Given cause (%i) is not included in the CIF causes", cause)
if (cause %nin% causes) {
stopf("Invalid cause. Use one of: %s", paste(causes, collapse = ", "))
}

# get cause-specific CIF
cif_mat = cif[[as.character(cause)]]
cif_mat = cif_list[[cause]]

# get CIF on the time horizon
mat = .interp_cif(cif_mat, eval_times = time_horizon)
@@ -112,10 +119,10 @@ MeasureCompRisksAUC = R6Class(
times = NULL # fix: no global binding
res$AUC$score[times == time_horizon][["AUC"]]
} else {
# iterate through cause-specific CIFs, get AUC(t), return the mean
AUCs = sapply(names(cif), function(cause) {
# iterate through cause-specific CIFs, get AUC(t)
aucs = vapply(causes, function(cause) {
# get cause-specific CIF
cif_mat = cif[[cause]]
cif_mat = cif_list[[cause]]

# get CIF on the time horizon
mat = .interp_cif(cif_mat, eval_times = time_horizon)
@@ -132,10 +139,11 @@ MeasureCompRisksAUC = R6Class(

times = NULL # fix: no global binding
res$AUC$score[times == time_horizon][["AUC"]]
})
}, numeric(1L))

# return mean (weighted?)
mean(AUCs)
event = data[event != 0, event] # remove censored obs (if they exist)
w = prop.table(table(event)) # observed proportions per cause
sum(w[names(aucs)] * aucs) # weighted mean
}
}
)
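The last three added lines of `.score()` replace the plain `mean()` with an event-weighted mean. The base-R sketch below isolates that step so the weighting is easy to check by hand. The `event` vector and the `aucs` values are invented for illustration; in the measure itself they come from the test-set truth and from `riskRegression::Score()`.

```r
# Hypothetical test-set event indicator: 0 = censored, 1 and 2 = competing causes
event = c(0, 1, 1, 1, 2, 0, 2, 1)

# Made-up cause-specific AUC(t) values, named by cause
aucs = c("1" = 0.80, "2" = 0.60)

# Drop censored observations, then compute observed proportions per cause
event = event[event != 0]
w = prop.table(table(event))

# Weighted mean: align the weights to the AUC vector by cause name
weighted_auc = sum(w[names(aucs)] * aucs)
weighted_auc
```

Here cause 1 accounts for 4 of the 6 observed events and cause 2 for the other 2, so the weights are 2/3 and 1/3 and the summary score leans towards the AUC of the more frequent cause (about 0.733, versus 0.70 for the unweighted mean).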
6 changes: 3 additions & 3 deletions R/PipeOpPredClassifSurvIPCW.R
@@ -26,15 +26,15 @@
#' to a [PredictionSurv].
#' Each input classification probability prediction corresponds to the
#' probability of having the event up to the specified cutoff time
#' \eqn{\hat{\pi}(\bold{X}_i) = P(T_i < \tau|\bold{X}_i)},
#' \eqn{\hat{\pi}(\textbf{X}_i) = P(T_i < \tau|\textbf{X}_i)},
#' see Vock et al. (2016) and [PipeOpTaskSurvClassifIPCW].
#' Therefore, these predictions serve as **continuous risk scores** that can be
#' directly interpreted as `crank` predictions in the right-censored survival
#' setting. We also map them to the survival distribution prediction `distr`,
#' at the specified cutoff time point \eqn{\tau}, i.e. as
#' \eqn{S_i(\tau) = 1 - \hat{\pi}(\bold{X}_i)}.
#' \eqn{S_i(\tau) = 1 - \hat{\pi}(\textbf{X}_i)}.
#' Survival measures that use the survival distribution (eg [ISBS][mlr_measures_surv.brier])
#' should be evaluated exactly at the cutoff time point \eqn{\tau}, see example.
#' should be evaluated exactly at the cutoff time point \eqn{\tau}.
#'
#' @references
#' `r format_bib("vock_2016")`
9 changes: 4 additions & 5 deletions R/TaskGeneratorCoxed.R
@@ -5,8 +5,8 @@
#' @description
#' A [TaskGenerator][mlr3::TaskGenerator] calling [coxed::sim.survdata()].
#'
#' This generator creates a survival dataset using \CRANpkg{coxed}, and exposes
#' some parameters from the `sim.survdata()` function.
#' This generator creates a survival dataset using `coxed`, and exposes
#' some parameters from the [coxed::sim.survdata()] function.
#' We don't include the parameters `X` (user-specified variables), `covariate`,
#' `low`, `high`, `compare`, `beta` and `hazard.fun` for this generator.
#' The latter means that no user-specified hazard function can be used and the
@@ -18,7 +18,7 @@
#' @template seealso_task_generator
#' @references
#' `r format_bib("harden_2019")`
#' @examplesIf mlr3misc::require_namespaces(c("coxed"), quietly = TRUE)
#' @examplesIf mlr3misc::require_namespaces("coxed", quietly = TRUE)
#' library(mlr3)
#'
#' # time horizon = 365 days, censoring proportion = 60%, 6 covariates normally
@@ -72,8 +72,7 @@ TaskGeneratorCoxed = R6Class("TaskGeneratorCoxed",
.generate = function(n) {
require_namespaces("coxed")

pv = self$param_set$values
data = invoke(coxed::sim.survdata, N = n, .args = pv)[[1]]
data = invoke(coxed::sim.survdata, N = n, .args = self$param_set$values)[[1]]
data = map_at(data, "failed", as.integer)

TaskSurv$new(id = self$id, backend = data, time = "y",
20 changes: 20 additions & 0 deletions R/bibentries.R
@@ -449,5 +449,25 @@ bibentries = c(
pages = "299--321",
url = "https://doi.org/10.1177/1471082X17748083",
year = "2018"
),
heyard_2020 = bibentry("article",
author = "Heyard, Rachel and Timsit, Jean-Francois and Held, Leonhard",
title = "Validation of discrete time-to-event prediction models in the presence of competing risks",
journal = "Biometrical Journal",
volume = "62",
number = "3",
pages = "643--657",
year = "2020",
url = "https://doi.org/10.1002/BIMJ.201800293"
),
schoop_2011 = bibentry("article",
author = "Schoop, Roland and Beyersmann, Jan and Schumacher, Martin and Binder, Harald",
title = "Quantifying the predictive accuracy of time-to-event models in the presence of competing risks",
journal = "Biometrical Journal",
volume = "53",
number = "1",
pages = "88--112",
year = "2011",
url = "https://doi.org/10.1002/BIMJ.201000073"
)
)