5
5
# ' @aliases MeasureCompRisksAUC mlr_measures_cmprsk.auc
6
6
# '
7
7
# ' @description
8
- # ' Calculates the cause-specific ROC-AUC(t) at a **specific time point**,
9
- # ' see Blanche et al. (2013).
10
- # ' Can also return the mean AUC(t) over all competing causes.
8
+ # ' Calculates the cause-specific time-dependent ROC-AUC at a **specific time point**,
9
+ # ' as described in Blanche et al. (2013).
10
+ # '
11
+ # ' By default, this measure returns a **cause-independent AUC(t)** score,
12
+ # ' calculated as a weighted average of the cause-specific AUCs.
13
+ # ' The weights correspond to the relative event frequencies of each cause,
14
+ # ' following Equation (7) in Heyard et al. (2020).
11
15
# '
12
16
# ' @details
13
17
# ' Calls [riskRegression::Score()] with:
14
18
# ' - `metric = "auc"`
15
19
# ' - `cens.method = "ipcw"`
16
20
# ' - `cens.model = "km"`
17
21
# '
18
- # ' Note that the IPC weights (estimated via the Kaplan-Meier) are calculated
19
- # ' using the test data.
22
+ # ' Notes on the `riskRegression` implementation:
23
+ # ' 1. IPC weights are estimated using the **test data only**.
24
+ # ' 2. No extrapolation is supported: if `time_horizon` exceeds the maximum observed
25
+ # ' time on the test data, an error is thrown.
26
+ # ' 3. The choice of `time_horizon` is critical: if, at that time, no events of a
27
+ # ' given cause have occurred and all predicted CIFs are zero, `riskRegression`
28
+ # ' will return `NaN` for that cause-specific AUC (and subsequently for the
29
+ # ' summary AUC).
20
30
# '
21
31
# ' @section Parameter details:
22
- # ' - `cause` (`numeric(1)`)\cr
23
- # ' Integer number indicating which cause to use (Default: `1`) .
24
- # ' If `"mean"`, then the mean AUC(t) over all causes is returned .
32
+ # ' - `cause` (`numeric(1)` | `"mean"`)\cr
33
+ # ' Integer number indicating which cause to use.
34
+ # ' Default value is `"mean"`, which returns a weighted mean of the cause-specific AUCs.
25
35
# ' - `time_horizon` (`numeric(1)`)\cr
26
36
# ' Single time point at which to return the score.
27
- # ' If `NULL`, we issue a warning and the median time from the test set is used.
37
+ # ' If `NULL`, the **median time point** from the test set is used.
28
38
# '
29
39
# ' @references
30
- # ' `r format_bib("blanche_2013")`
31
- # '
32
- # ' @examplesIf mlr3misc::require_namespaces(c("riskRegression"), quietly = TRUE)
33
- # ' t = tsk("pbc")
34
- # ' l = lrn("cmprsk.aalen")
35
- # ' p = l$train(t)$predict(t)
36
- # '
37
- # ' p$score(msr("cmprsk.auc", time_horizon = 42))
40
+ # ' `r format_bib("blanche_2013", "heyard_2020")`
38
41
# '
42
+ # ' @templateVar msr_id auc
43
+ # ' @template example_cmprsk
39
44
# ' @export
40
45
MeasureCompRisksAUC = R6Class(
41
46
" MeasureCompRisksAUC" ,
@@ -45,18 +50,16 @@ MeasureCompRisksAUC = R6Class(
45
50
# ' Creates a new instance of this [R6][R6::R6Class] class.
46
51
initialize = function () {
47
52
param_set = ps(
48
- cause = p_int(lower = 1 , default = 1 , special_vals = list (" mean" )),
53
+ cause = p_int(lower = 1 , init = " mean " , special_vals = list (" mean" )),
49
54
time_horizon = p_dbl(lower = 0 , default = NULL , special_vals = list (NULL ))
50
55
)
51
56
52
- param_set $ set_values(cause = 1 )
53
-
54
57
super $ initialize(
55
58
id = " cmprsk.auc" ,
56
59
param_set = param_set ,
57
60
range = c(0 , 1 ),
58
61
minimize = FALSE ,
59
- # properties = "requires_task", (only if we want `cen.model = cox`)
62
+ properties = " na_score " ,
60
63
packages = " riskRegression" ,
61
64
label = " Blanche's Time-dependent IPCW ROC-AUC score" ,
62
65
man = " mlr3proba::mlr_measures_cmprsk.auc"
@@ -68,33 +71,37 @@ MeasureCompRisksAUC = R6Class(
68
71
.score = function (prediction , task , ... ) {
69
72
pv = self $ param_set $ values
70
73
71
- # data with (time, event) columns for IPCW calculation
72
- # uses test set data as it needs to match predicted CIF rows/observations
73
- data = data.table(time = prediction $ truth [, 1L ], event = prediction $ truth [, 2L ])
74
- lhs = " Hist(time, event)"
75
- form = formulate(lhs , rhs = " 1" , env = getNamespace(" prodlim" ))
76
-
77
- # single time point for AUC or median time
78
- if (is.null(pv $ time )) {
79
- # TODO: add the warning again when this is not the default measure
80
- # warning("No time horizon specified. We use median time from the test set")
81
- time_horizon = median(data $ time )
74
+ # Prepare test set data (for IPCW)
75
+ # uses test set observations as it needs to match exactly the number of
76
+ # rows (observations) in the predicted CIF matrix
77
+ data = data.table(
78
+ time = prediction $ truth [, 1L ],
79
+ event = prediction $ truth [, 2L ]
80
+ )
81
+ form = formulate(lhs = " Hist(time, event)" , rhs = " 1" , env = getNamespace(" prodlim" ))
82
+
83
+ # Define evaluation time (single time point for AUC)
84
+ time_horizon = if (is.null(pv $ time_horizon )) {
85
+ median(data $ time )
82
86
} else {
83
- time_horizon = assert_number(pv $ time_horizon , lower = 0 , finite = TRUE , na.ok = FALSE )
87
+ assert_number(pv $ time_horizon , lower = 0 , finite = TRUE , na.ok = FALSE )
84
88
}
85
89
86
90
# list of predicted CIF matrices
87
- cif = prediction $ cif
91
+ cif_list = prediction $ cif
92
+ causes = names(cif_list )
88
93
89
94
cause = pv $ cause
90
- if (test_integerish(cause )) {
95
+ if (test_int(cause )) {
96
+ cause = as.character(cause )
97
+
91
98
# check if cause exists
92
- if (cause %nin % names( cif ) ) {
93
- stopf(" Given cause (%i) is not included in the CIF causes " , cause )
99
+ if (cause %nin % causes ) {
100
+ stopf(" Invalid cause. Use one of: %s " , paste( causes , collapse = " , " ) )
94
101
}
95
102
96
103
# get cause-specific CIF
97
- cif_mat = cif [[as.character( cause ) ]]
104
+ cif_mat = cif_list [[ cause ]]
98
105
99
106
# get CIF on the time horizon
100
107
mat = .interp_cif(cif_mat , eval_times = time_horizon )
@@ -112,10 +119,10 @@ MeasureCompRisksAUC = R6Class(
112
119
times = NULL # fix: no global binding
113
120
res $ AUC $ score [times == time_horizon ][[" AUC" ]]
114
121
} else {
115
- # iterate through cause-specific CIFs, get AUC(t), return the mean
116
- AUCs = sapply(names( cif ) , function (cause ) {
122
+ # iterate through cause-specific CIFs, get AUC(t)
123
+ aucs = vapply( causes , function (cause ) {
117
124
# get cause-specific CIF
118
- cif_mat = cif [[cause ]]
125
+ cif_mat = cif_list [[cause ]]
119
126
120
127
# get CIF on the time horizon
121
128
mat = .interp_cif(cif_mat , eval_times = time_horizon )
@@ -132,10 +139,11 @@ MeasureCompRisksAUC = R6Class(
132
139
133
140
times = NULL # fix: no global binding
134
141
res $ AUC $ score [times == time_horizon ][[" AUC" ]]
135
- })
142
+ }, numeric ( 1L ) )
136
143
137
- # return mean (weighted?)
138
- mean(AUCs )
144
+ event = data [event != 0 , event ] # remove censored obs (if they exist)
145
+ w = prop.table(table(event )) # observed proportions per cause
146
+ sum(w [names(aucs )] * aucs ) # weighted mean
139
147
}
140
148
}
141
149
)
0 commit comments