Commit a35a4c4

Merge pull request #429 from mlr-org/brier_fix
Brier fix
2 parents e4a478b + e3766e8 commit a35a4c4

33 files changed, 636 additions and 143 deletions

NEWS.md

Lines changed: 6 additions & 2 deletions
@@ -1,7 +1,11 @@
 # mlr3proba 0.7.1

-* Removed all `PipeOp`s and pipelines related to survival => regression reduction techniques (see #414)
-* Bug fix: `$predict_type` of `survtoclassif_disctime` and `survtoclassif_IPCW` was `prob` (classification type) and not `crank` (survival type)
+* cleanup: removed all `PipeOp`s and pipelines related to survival => regression reduction techniques (see #414)
+* fix: `$predict_type` of `survtoclassif_disctime` and `survtoclassif_IPCW` was `prob` (classification type) and not `crank` (survival type)
+* fix: G(t) is not filtered when `t_max|p_max` is specified in scoring rules (didn't influence evaluation at all)
+* docs: Clarified the use and impact of using `t_max` in scoring rules, added examples in scoring rules and AUC scores
+* feat: Added new argument `remove_obs` in scoring rules to remove observations with observed time `t > t_max` as a processing step to alleviate IPCW issues.
+  This behaviour was previously hard-coded, which made the Integrated Brier Score (`msr("surv.graf")`) differ slightly from other implementations and the original definition.

 # mlr3proba 0.7.0
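
To make the `remove_obs` entry concrete, here is a minimal base-R sketch of the preprocessing step it describes: dropping observations whose observed time exceeds `t_max` before any scoring happens. The vectors `times` and `status` and the cutoff value are made up for illustration, and this is not the package's internal code.

```r
# Toy test set: observed times and event indicators (1 = event, 0 = censored).
# All values are hypothetical.
times  = c(5, 12, 30, 44, 61, 80)
status = c(1, 0, 1, 1, 0, 1)
t_max  = 50

# Keep only observations with observed time t <= t_max,
# i.e. remove those with t > t_max, as the NEWS entry describes.
keep = times <= t_max
times_kept  = times[keep]
status_kept = status[keep]

print(data.frame(time = times_kept, status = status_kept))
```

With `remove_obs = FALSE` (the default added in this commit) no such filtering is applied, so the score integrates up to `t_max` over all observations.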

R/MeasureSurvChamblessAUC.R

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 #'
 #' @family AUC survival measures
 #' @family lp survival measures
+#' @template example_auc_measures
 #' @export
 MeasureSurvChamblessAUC = R6Class("MeasureSurvChamblessAUC",
 inherit = MeasureSurvAUC,

R/MeasureSurvCindex.R

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@
 #'
 #' # Harrell's C-index evaluated up to a specific time horizon
 #' p$score(msr("surv.cindex", t_max = 97))
+#'
 #' # Harrell's C-index evaluated up to the time corresponding to 30% of censoring
 #' p$score(msr("surv.cindex", p_max = 0.3))
 #'

R/MeasureSurvDCalibration.R

Lines changed: 11 additions & 8 deletions
@@ -3,6 +3,8 @@
 #' @templateVar fullname MeasureSurvDCalibration
 #'
 #' @description
+#' `r lifecycle::badge("experimental")`
+#'
 #' This calibration method is defined by calculating the following statistic:
 #' \deqn{s = B/n \sum_i (P_i - n/B)^2}
 #' where \eqn{B} is number of 'buckets' (that equally divide \eqn{[0,1]} into intervals),
@@ -12,8 +14,8 @@
 #' falls within the corresponding interval.
 #' This statistic assumes that censoring time is independent of death time.
 #'
-#' A model is well-calibrated if \eqn{s \sim Unif(B)}, tested with `chisq.test`
-#' (\eqn{p > 0.05} if well-calibrated).
+#' A model is well D-calibrated if \eqn{s \sim Unif(B)}, tested with `chisq.test`
+#' (\eqn{p > 0.05} if well-calibrated, i.e. higher p-values are preferred).
 #' Model \eqn{i} is better calibrated than model \eqn{j} if \eqn{s(i) < s(j)},
 #' meaning that *lower values* of this measure are preferred.
 #'
@@ -23,7 +25,7 @@
 #' is well-calibrated. If `chisq = FALSE` and `s` is the predicted value then you can manually
 #' compute the p.value with `pchisq(s, B - 1, lower.tail = FALSE)`.
 #'
-#' NOTE: This measure is still experimental both theoretically and in implementation. Results
+#' **NOTE**: This measure is still experimental both theoretically and in implementation. Results
 #' should therefore only be taken as an indicator of performance and not for
 #' conclusive judgements about model calibration.
 #'
@@ -38,11 +40,12 @@
 #' You can manually get the p-value by executing `pchisq(s, B - 1, lower.tail = FALSE)`.
 #' The null hypothesis is that the model is D-calibrated.
 #' - `truncate` (`double(1)`) \cr
-#' This parameter controls the upper bound of the output statistic,
-#' when `chisq` is `FALSE`. We use `truncate = Inf` by default but \eqn{10} may be sufficient
-#' for most purposes, which corresponds to a p-value of 0.35 for the chisq.test using
-#' \eqn{B = 10} buckets. Values \eqn{>10} translate to even lower p-values and thus
-#' less calibrated models. If the number of buckets \eqn{B} changes, you probably will want to
+#' This parameter controls the upper bound of the output statistic, when `chisq` is `FALSE`.
+#' We use `truncate = Inf` by default but values between \eqn{10} and \eqn{16} are sufficient
+#' for most purposes, which correspond to p-values between \eqn{0.35} and \eqn{0.06} for the `chisq.test` using
+#' the default \eqn{B = 10} buckets.
+#' Larger values of the statistic translate to even lower p-values and thus less D-calibrated models.
+#' If the number of buckets \eqn{B} changes, you probably will want to
 #' change the `truncate` value as well to correspond to the same p-value significance.
 #' Note that truncation may severely limit automated tuning with this measure.
 #'
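
The relationship between `truncate` and the p-value can be checked with base R alone. The sketch below evaluates the documented `pchisq(s, B - 1, lower.tail = FALSE)` call at the two ends of the suggested range and computes the statistic for two toy bucket vectors. The helper names `p_value` and `dcal_stat` are hypothetical, and treating \eqn{P_i} as the number of observations in bucket \eqn{i} is an assumption, since that part of the definition is not visible in the hunks above.

```r
B = 10  # default number of buckets

# p-value implied by a value s of the statistic, as given in the documentation
p_value = function(s, B) pchisq(s, df = B - 1, lower.tail = FALSE)

p10 = p_value(10, B)  # roughly 0.35
p16 = p_value(16, B)  # between 0.06 and 0.07

# D-calibration statistic s = B/n * sum_i (P_i - n/B)^2.
# Assumption: `counts` holds the number of observations per bucket,
# so n/B is the expected count under perfect calibration.
dcal_stat = function(counts) {
  B = length(counts)
  n = sum(counts)
  B / n * sum((counts - n / B)^2)
}

s_uniform = dcal_stat(rep(10, 10))        # evenly filled buckets
s_skewed  = dcal_stat(c(19, rep(9, 9)))   # one over-full bucket

print(c(p10 = p10, p16 = p16, s_uniform = s_uniform, s_skewed = s_skewed))
```

Evenly filled buckets give a statistic of zero, and any imbalance increases it, which is why lower values are preferred.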

R/MeasureSurvGraf.R

Lines changed: 9 additions & 6 deletions
@@ -11,6 +11,7 @@
 #' @templateVar eps 1e-3
 #' @template param_eps
 #' @template param_erv
+#' @template param_remove_obs
 #'
 #' @aliases MeasureSurvBrier mlr_measures_surv.brier
 #'
@@ -25,13 +26,13 @@
 #' survival function \eqn{S_i(t)}, the *observation-wise* loss integrated across
 #' the time dimension up to the time cutoff \eqn{\tau^*}, is:
 #'
-#' \deqn{L_{ISBS}(S_i, t_i, \delta_i) = \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{S_i^2(\tau) \text{I}(t_i \leq \tau, \delta=1)}{G(t_i)} + \frac{(1-S_i(\tau))^2 \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
+#' \deqn{L_{ISBS}(S_i, t_i, \delta_i) = \int^{\tau^*}_0 \frac{S_i^2(\tau) \text{I}(t_i \leq \tau, \delta_i=1)}{G(t_i)} + \frac{(1-S_i(\tau))^2 \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
 #'
 #' where \eqn{G} is the Kaplan-Meier estimate of the censoring distribution.
 #'
 #' The **re-weighted ISBS** (RISBS) is:
 #'
-#' \deqn{L_{RISBS}(S_i, t_i, \delta_i) = \delta_i \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{S_i^2(\tau) \text{I}(t_i \leq \tau) + (1-S_i(\tau))^2 \text{I}(t_i > \tau)}{G(t_i)} \ d\tau}
+#' \deqn{L_{RISBS}(S_i, t_i, \delta_i) = \delta_i \frac{\int^{\tau^*}_0 S_i^2(\tau) \text{I}(t_i \leq \tau) + (1-S_i(\tau))^2 \text{I}(t_i > \tau) \ d\tau}{G(t_i)}}
 #'
 #' which is always weighted by \eqn{G(t_i)} and is equal to zero for a censored subject.
 #'
@@ -48,10 +49,11 @@
 #' @template details_tmax
 #'
 #' @references
-#' `r format_bib("graf_1999")`
+#' `r format_bib("graf_1999", "sonabend2024", "kvamme2023")`
 #'
 #' @family Probabilistic survival measures
 #' @family distr survival measures
+#' @template example_scoring_rules
 #' @export
 MeasureSurvGraf = R6Class("MeasureSurvGraf",
 inherit = MeasureSurv,
@@ -73,11 +75,12 @@ MeasureSurvGraf = R6Class("MeasureSurvGraf",
 se = p_lgl(default = FALSE),
 proper = p_lgl(default = FALSE),
 eps = p_dbl(0, 1, default = 1e-3),
-ERV = p_lgl(default = FALSE)
+ERV = p_lgl(default = FALSE),
+remove_obs = p_lgl(default = FALSE)
 )
 ps$set_values(
 integrated = TRUE, method = 2L, se = FALSE,
-proper = FALSE, eps = 1e-3, ERV = ERV
+proper = FALSE, eps = 1e-3, ERV = ERV, remove_obs = FALSE
 )

 range = if (ERV) c(-Inf, 1) else c(0, Inf)
@@ -132,7 +135,7 @@ MeasureSurvGraf = R6Class("MeasureSurvGraf",
 truth = prediction$truth,
 distribution = prediction$data$distr, times = times,
 t_max = ps$t_max, p_max = ps$p_max, proper = ps$proper, train = train,
-eps = ps$eps
+eps = ps$eps, remove_obs = ps$remove_obs
 )

 if (ps$se) {
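
The updated observation-wise ISBS formula in this file can be evaluated numerically for a single subject. The following base-R sketch does so with the trapezoidal rule. The function name `isbs_obs`, the constant prediction `S`, and the no-censoring weight `G` are all hypothetical choices for illustration; the package's own implementation differs (for instance, it uses `eps` to guard against division by very small \eqn{G} values).

```r
# Observation-wise ISBS for one subject, integrated from 0 to tau_star.
# S, G: vectorised functions of time (predicted survival, censoring survival).
isbs_obs = function(S, G, t_i, delta_i, tau_star, n_grid = 1001) {
  tau = seq(0, tau_star, length.out = n_grid)
  # Event observed by tau: penalise predicted survival, weight 1/G(t_i)
  event_term = S(tau)^2 * (t_i <= tau & delta_i == 1) / G(t_i)
  # Still at risk at tau: penalise predicted death, weight 1/G(tau)
  alive_term = (1 - S(tau))^2 * (t_i > tau) / G(tau)
  integrand = event_term + alive_term
  # Trapezoidal rule over the time grid
  sum(diff(tau) * (integrand[-1] + integrand[-n_grid]) / 2)
}

# Hypothetical inputs: constant prediction S = 0.8 and no censoring (G = 1)
S = function(tau) rep(0.8, length(tau))
G = function(tau) rep(1, length(tau))

# Subject outlives the cutoff: (1 - 0.8)^2 * 5
loss_alive = isbs_obs(S, G, t_i = 10, delta_i = 1, tau_star = 5)
# Event at time 0: 0.8^2 * 5
loss_event = isbs_obs(S, G, t_i = 0, delta_i = 1, tau_star = 5)
# Censored at time 0: contributes nothing
loss_cens = isbs_obs(S, G, t_i = 0, delta_i = 0, tau_star = 5)

print(c(alive = loss_alive, event = loss_event, censored = loss_cens))
```

Note that the subject with `t_i = 10` still contributes to the loss even though its observed time lies beyond the cutoff, which matches the removal of the \eqn{\text{I}(t_i \leq \tau^*)} factor in this commit.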

R/MeasureSurvHungAUC.R

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 #'
 #' @family AUC survival measures
 #' @family lp survival measures
+#' @template example_auc_measures
 #' @export
 MeasureSurvHungAUC = R6Class("MeasureSurvHungAUC",
 inherit = MeasureSurvAUC,

R/MeasureSurvIntLogloss.R

Lines changed: 9 additions & 6 deletions
@@ -11,6 +11,7 @@
 #' @templateVar eps 1e-3
 #' @template param_eps
 #' @template param_erv
+#' @template param_remove_obs
 #'
 #' @description
 #' Calculates the **Integrated Survival Log-Likelihood** (ISLL) or Integrated
@@ -23,13 +24,13 @@
 #' survival function \eqn{S_i(t)}, the *observation-wise* loss integrated across
 #' the time dimension up to the time cutoff \eqn{\tau^*}, is:
 #'
-#' \deqn{L_{ISLL}(S_i, t_i, \delta_i) = -\text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{log[1-S_i(\tau)] \text{I}(t_i \leq \tau, \delta=1)}{G(t_i)} + \frac{\log[S_i(\tau)] \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
+#' \deqn{L_{ISLL}(S_i, t_i, \delta_i) = - \int^{\tau^*}_0 \frac{\log[1-S_i(\tau)] \text{I}(t_i \leq \tau, \delta_i=1)}{G(t_i)} + \frac{\log[S_i(\tau)] \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
 #'
 #' where \eqn{G} is the Kaplan-Meier estimate of the censoring distribution.
 #'
 #' The **re-weighted ISLL** (RISLL) is:
 #'
-#' \deqn{L_{RISLL}(S_i, t_i, \delta_i) = -\delta_i \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{\log[1-S_i(\tau)]) \text{I}(t_i \leq \tau) + \log[S_i(\tau)] \text{I}(t_i > \tau)}{G(t_i)} \ d\tau}
+#' \deqn{L_{RISLL}(S_i, t_i, \delta_i) = -\delta_i \frac{\int^{\tau^*}_0 \log[1-S_i(\tau)] \text{I}(t_i \leq \tau) + \log[S_i(\tau)] \text{I}(t_i > \tau) \ d\tau}{G(t_i)}}
 #'
 #' which is always weighted by \eqn{G(t_i)} and is equal to zero for a censored subject.
 #'
@@ -46,10 +47,11 @@
 #' @template details_tmax
 #'
 #' @references
-#' `r format_bib("graf_1999")`
+#' `r format_bib("graf_1999", "sonabend2024", "kvamme2023")`
 #'
 #' @family Probabilistic survival measures
 #' @family distr survival measures
+#' @template example_scoring_rules
 #' @export
 MeasureSurvIntLogloss = R6Class("MeasureSurvIntLogloss",
 inherit = MeasureSurv,
@@ -71,11 +73,12 @@ MeasureSurvIntLogloss = R6Class("MeasureSurvIntLogloss",
 se = p_lgl(default = FALSE),
 proper = p_lgl(default = FALSE),
 eps = p_dbl(0, 1, default = 1e-3),
-ERV = p_lgl(default = FALSE)
+ERV = p_lgl(default = FALSE),
+remove_obs = p_lgl(default = FALSE)
 )
 ps$set_values(
 integrated = TRUE, method = 2L, se = FALSE,
-proper = FALSE, eps = 1e-3, ERV = ERV
+proper = FALSE, eps = 1e-3, ERV = ERV, remove_obs = FALSE
 )

 range = if (ERV) c(-Inf, 1) else c(0, Inf)
@@ -130,7 +133,7 @@ MeasureSurvIntLogloss = R6Class("MeasureSurvIntLogloss",
 truth = prediction$truth,
 distribution = prediction$data$distr, times = times,
 t_max = ps$t_max, p_max = ps$p_max, proper = ps$proper, train = train,
-eps = ps$eps
+eps = ps$eps, remove_obs = ps$remove_obs
 )

 if (ps$se) {

R/MeasureSurvLogloss.R

Lines changed: 4 additions & 1 deletion
@@ -10,7 +10,7 @@
 #' Calculates the cross-entropy, or negative log-likelihood (NLL) or logarithmic (log), loss.
 #' @section Parameter details:
 #' - `IPCW` (`logical(1)`)\cr
-#' If `TRUE` (default) then returns the \eqn{L_{RNLL}} score (which is proper), otherwise the \eqn{L_{NLL}} score (improper).
+#' If `TRUE` (default) then returns the \eqn{L_{RNLL}} score (which is proper), otherwise the \eqn{L_{NLL}} score (improper). See Sonabend et al. (2024) for more details.
 #'
 #' @details
 #' The Log Loss, in the context of probabilistic predictions, is defined as the
@@ -33,6 +33,9 @@
 #'
 #' @template details_trainG
 #'
+#' @references
+#' `r format_bib("sonabend2024")`
+#'
 #' @family Probabilistic survival measures
 #' @family distr survival measures
 #' @export

R/MeasureSurvSchmid.R

Lines changed: 9 additions & 13 deletions
@@ -11,6 +11,7 @@
 #' @templateVar eps 1e-3
 #' @template param_eps
 #' @template param_erv
+#' @template param_remove_obs
 #'
 #' @description
 #' Calculates the **Integrated Schmid Score** (ISS), aka integrated absolute loss.
@@ -22,27 +23,20 @@
 #' survival function \eqn{S_i(t)}, the *observation-wise* loss integrated across
 #' the time dimension up to the time cutoff \eqn{\tau^*}, is:
 #'
-#' \deqn{L_{ISS}(S_i, t_i, \delta_i) = \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{S_i(\tau) \text{I}(t_i \leq \tau, \delta=1)}{G(t_i)} + \frac{(1-S_i(\tau)) \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
+#' \deqn{L_{ISS}(S_i, t_i, \delta_i) = \int^{\tau^*}_0 \frac{S_i(\tau) \text{I}(t_i \leq \tau, \delta_i=1)}{G(t_i)} + \frac{(1-S_i(\tau)) \text{I}(t_i > \tau)}{G(\tau)} \ d\tau}
 #'
 #' where \eqn{G} is the Kaplan-Meier estimate of the censoring distribution.
 #'
 #' The **re-weighted ISS** (RISS) is:
 #'
-#' \deqn{L_{RISS}(S_i, t_i, \delta_i) = \delta_i \text{I}(t_i \leq \tau^*) \int^{\tau^*}_0 \frac{S_i(\tau) \text{I}(t_i \leq \tau) + (1-S_i(\tau)) \text{I}(t_i > \tau)}{G(t_i)} \ d\tau}
+#' \deqn{L_{RISS}(S_i, t_i, \delta_i) = \delta_i \frac{\int^{\tau^*}_0 S_i(\tau) \text{I}(t_i \leq \tau) + (1-S_i(\tau)) \text{I}(t_i > \tau) \ d\tau}{G(t_i)}}
 #'
 #' which is always weighted by \eqn{G(t_i)} and is equal to zero for a censored subject.
 #'
 #' To get a single score across all \eqn{N} observations of the test set, we
 #' return the average of the time-integrated observation-wise scores:
 #' \deqn{\sum_{i=1}^N L(S_i, t_i, \delta_i) / N}
 #'
-#'
-#' \deqn{L_{ISS}(S,t|t^*) = [(S(t^*))I(t \le t^*, \delta = 1)(1/G(t))] + [((1 - S(t^*)))I(t > t^*)(1/G(t^*))]}
-#' where \eqn{G} is the Kaplan-Meier estimate of the censoring distribution.
-#'
-#' The re-weighted ISS, RISS is given by
-#' \deqn{L_{RISS}(S,t|t^*) = [(S(t^*))I(t \le t^*, \delta = 1)(1/G(t))] + [((1 - S(t^*)))I(t > t^*)(1/G(t))]}
-#'
 #' @template properness
 #' @templateVar improper_id ISS
 #' @templateVar proper_id RISS
@@ -52,10 +46,11 @@
 #' @template details_tmax
 #'
 #' @references
-#' `r format_bib("schemper_2000", "schmid_2011")`
+#' `r format_bib("schemper_2000", "schmid_2011", "sonabend2024", "kvamme2023")`
 #'
 #' @family Probabilistic survival measures
 #' @family distr survival measures
+#' @template example_scoring_rules
 #' @export
 MeasureSurvSchmid = R6Class("MeasureSurvSchmid",
 inherit = MeasureSurv,
@@ -77,11 +72,12 @@ MeasureSurvSchmid = R6Class("MeasureSurvSchmid",
 se = p_lgl(default = FALSE),
 proper = p_lgl(default = FALSE),
 eps = p_dbl(0, 1, default = 1e-3),
-ERV = p_lgl(default = FALSE)
+ERV = p_lgl(default = FALSE),
+remove_obs = p_lgl(default = FALSE)
 )
 ps$set_values(
 integrated = TRUE, method = 2L, se = FALSE,
-proper = FALSE, eps = 1e-3, ERV = ERV
+proper = FALSE, eps = 1e-3, ERV = ERV, remove_obs = FALSE
 )

 range = if (ERV) c(-Inf, 1) else c(0, Inf)
@@ -135,7 +131,7 @@ MeasureSurvSchmid = R6Class("MeasureSurvSchmid",
 truth = prediction$truth,
 distribution = prediction$data$distr, times = times,
 t_max = ps$t_max, p_max = ps$p_max, proper = ps$proper, train = train,
-eps = ps$eps
+eps = ps$eps, remove_obs = ps$remove_obs
 )

 if (ps$se) {

R/MeasureSurvSongAUC.R

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
 #'
 #' @family AUC survival measures
 #' @family lp survival measures
+#' @template example_auc_measures
 #' @export
 MeasureSurvSongAUC = R6Class("MeasureSurvSongAUC",
 inherit = MeasureSurvAUC,
