Commit 68650a2

Raphael Sonabend authored
Merge pull request #173 from mlr-org/surv.dcalib
add surv.dcalib
2 parents 5c82385 + d2d2a94 commit 68650a2

30 files changed: +356 -14 lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 Package: mlr3proba
 Title: Probabilistic Supervised Learning for 'mlr3'
-Version: 0.4.0.9000
+Version: 0.4.0
 Authors@R:
 c(person(given = "Raphael",
 family = "Sonabend",
@@ -87,6 +87,7 @@ Collate:
 'MeasureSurvCalibrationBeta.R'
 'MeasureSurvChamblessAUC.R'
 'MeasureSurvCindex.R'
+'MeasureSurvDCalibration.R'
 'MeasureSurvGraf.R'
 'MeasureSurvHungAUC.R'
 'MeasureSurvIntLogloss.R'

NAMESPACE

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ export(MeasureSurvCalibrationAlpha)
 export(MeasureSurvCalibrationBeta)
 export(MeasureSurvChamblessAUC)
 export(MeasureSurvCindex)
+export(MeasureSurvDCalibration)
 export(MeasureSurvGraf)
 export(MeasureSurvHungAUC)
 export(MeasureSurvIntLogloss)

NEWS.md

Lines changed: 1 addition & 4 deletions
@@ -1,14 +1,11 @@
-# mlr3proba 0.4.0.9000
-
-- Internal changes only.
-
 
 # mlr3proba 0.4.0
 
 * Deprecated measures from 0.2.0 have now been deleted.
 * IPCW measures such as `surv.graf`, `surv.schmid`, and `surv.intlogloss` now allow training data to be passed to the score function with `task` and `train_set` to allow the censoring distribution to be estimated on the training data. This is automatically applied for resample and benchmark results.
 * IPCW measures such as `surv.graf`, `surv.schmid`, and `surv.intlogloss` now include a parameter `proper` to determine what weighting scheme should be applied by the estimated censoring distribution. The current method (Graf, 1999), `proper = FALSE`, weights observations either by their event time or 'current' time depending on whether they're dead or not; the new method, `proper = TRUE`, weights observations by event time. The `proper = TRUE` method is strictly proper when censoring and survival times are independent and G is estimated on large enough data. The `proper = FALSE` method is never proper. The default is currently `proper = FALSE` to enable backward compatibility; this will be changed to `proper = TRUE` in v0.6.0.
 * The `rm_cens` parameter in `surv.logloss` has been deprecated in favour of `IPCW`. `rm_cens` will be removed in v0.6.0. If `rm_cens` or `IPCW` are `TRUE` then censored observations are removed and the score is weighted by an estimate of the censoring distribution at individual event times. Otherwise if `rm_cens` and `IPCW` are `FALSE` then no deletion or weighting takes place. The `IPCW = TRUE` method is strictly proper when censoring and survival times are independent and G is estimated on large enough data. The `IPCW = FALSE` method is never proper.
+* Add `surv.dcalib` for the D-Calibration measure from Haider et al. (2020).
 
 # mlr3proba 0.3.2
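
As a rough illustration of what the new `surv.dcalib` measure reports, the sketch below computes the D-calibration statistic and its p-value in base R. The bucket counts are made-up numbers chosen for the example, not output from mlr3proba, and the variable names are hypothetical.

```r
# Hypothetical bucket counts for n = 100 predictions and B = 10 buckets
# (illustrative numbers, not output from mlr3proba).
counts = c(12, 9, 11, 8, 10, 10, 13, 7, 9, 11)
B = length(counts)
n = sum(counts)

# D-calibration statistic: s = B/n * sum_i (P_i - n/B)^2
s = (B / n) * sum((counts - n / B)^2)

# p-value computed manually from the chi-squared distribution with B - 1 df
p_manual = pchisq(s, B - 1, lower.tail = FALSE)

# the same p-value from a goodness-of-fit test against uniform buckets
p_test = unname(stats::chisq.test(counts)$p.value)

print(s)                            # 3 for these counts
print(all.equal(p_manual, p_test))  # TRUE
```

The statistic is the quantity suited to comparing models, while the p-value is the quantity suited to judging whether a single model is well-calibrated.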

R/MeasureSurvCindex.R

Lines changed: 6 additions & 1 deletion
@@ -18,7 +18,12 @@
 #' * `"SG"` = Weights concordance by S/G (Schemper et al.)
 #' * `"S"` = Weights concordance by S (Peto and Peto)
 #'
-#' The last three require training data.
+#' The last three require training data. `"GH"` is only applicable to [LearnerSurvCoxPH].
+#'
+#' @details
+#' The implementation is slightly different from [survival::concordance]. Firstly, this
+#' implementation is faster, and secondly, the weights are computed on the training dataset whereas
+#' in [survival::concordance] the weights are computed on the same testing data.
 #'
 #' @references
 #' `r format_bib("peto_1972", "harrell_1982", "goenen_2005", "schemper_2009", "uno_2011")`

R/MeasureSurvDCalibration.R

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+#' @template surv_measure
+#' @templateVar title D-Calibration
+#' @templateVar fullname MeasureSurvDCalibration
+#'
+#' @description
+#' This calibration method is defined by calculating
+#' \deqn{s = B/n \sum_i (P_i - n/B)^2}
+#' where \eqn{B} is the number of 'buckets', \eqn{n} is the number of predictions,
+#' and \eqn{P_i} is the predicted number of deaths in the \eqn{i}th interval
+#' [0, 1/B), [1/B, 2/B), ..., [(B - 1)/B, 1).
+#'
+#' A model is well-calibrated if `s ~ Unif(B)`, tested with `chisq.test`
+#' (`p > 0.05` if well-calibrated).
+#' Model `i` is better calibrated than model `j` if `s_i < s_j`.
+#'
+#' @details
+#' This measure can either return the test statistic or the p-value from the `chisq.test`.
+#' The former is useful for model comparison whereas the latter is useful for determining if a model
+#' is well-calibrated. If `chisq = FALSE` and `m` is the predicted value then you can manually
+#' compute the p-value with `pchisq(m, B - 1, lower.tail = FALSE)`.
+#'
+#' NOTE: This measure is still experimental both theoretically and in implementation. Results
+#' should therefore only be taken as an indicator of performance and not for
+#' conclusive judgements about model calibration.
+#'
+#' @references
+#' `r format_bib("haider_2020")`
+#'
+#' @family calibration survival measures
+#' @family distr survival measures
+#' @export
+MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
+inherit = MeasureSurv,
+public = list(
+#' @description Creates a new instance of this [R6][R6::R6Class] class.
+#' @param B (`integer(1)`) \cr
+#' Number of buckets to test for uniform predictions over. Default of `10` is recommended by
+#' Haider et al. (2020).
+#' @param chisq (`logical(1)`) \cr
+#' If `TRUE` returns the p.value of the corresponding chisq.test instead of the measure.
+#' Otherwise this can be performed manually with `pchisq(m, B - 1, lower.tail = FALSE)`.
+#' `p > 0.05` indicates well-calibrated.
+initialize = function(B = 10L, chisq = FALSE) {
+super$initialize(
+id = "surv.dcalib",
+range = c(0, Inf),
+minimize = TRUE,
+predict_type = "distr",
+man = "mlr3proba::mlr_measures_surv.dcalib",
+)
+
+private$.B = assert_integerish(B)
+private$.chisq = assert_flag(chisq)
+}
+),
+
+active = list(
+#' @field B (`integer(1)`) \cr
+#' Number of buckets to test for uniform predictions over. Default of `10` is recommended by
+#' Haider et al. (2020).
+B = function(x) {
+if (!missing(x)) {
+private$.B = assert_integerish(x)
+} else {
+return(private$.B)
+}
+},
+
+#' @field chisq `(logical(1))` \cr
+#' If `TRUE` returns the p.value of the corresponding chisq.test instead of the measure.
+#' Otherwise this can be performed manually with `pchisq(m, B - 1, lower.tail = FALSE)`.
+#' `p > 0.05` indicates well-calibrated.
+chisq = function(x) {
+if (!missing(x)) {
+private$.chisq = assert_flag(x)
+} else {
+return(private$.chisq)
+}
+}
+),
+
+private = list(
+.B = 10L,
+.chisq = FALSE,
+.score = function(prediction, ...) {
+
+# initialize buckets
+bj = numeric(self$B)
+# predict individual probability of death at observed event time
+si = as.numeric(prediction$distr$survival(data = matrix(prediction$truth[, 1L], nrow = 1L)))
+# remove zeros
+si = map_dbl(si, function(.x) max(.x, 1e-5))
+# index of associated bucket
+js = ceiling(self$B * si)
+
+# could remove loop for dead observations but needed for censored ones and minimal overhead
+# in combining both
+for (i in seq_along(si)) {
+ji = js[[i]]
+if (prediction$truth[i, 2L] == 1L) {
+# dead observations contribute 1 to their index
+bj[ji] = bj[ji] + 1
+} else {
+# censored observations spread across buckets with most weighting on penultimate
+for (k in seq_len(ji - 1)) {
+bj[k] = bj[k] + 1/(self$B * si[[i]])
+}
+bj[ji] = bj[ji] + (1 - (ji - 1)/(self$B * si[[i]]))
+}
+}
+
+if (self$chisq) {
+return(stats::chisq.test(bj)$p.value)
+} else {
+return((self$B/length(si)) * sum((bj - length(si)/self$B)^2))
+}
+}
+)
+)
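
The bucket logic in `.score` can be hard to follow inside the R6 class, so here is a minimal standalone sketch of it in base R. The function name `dcalib_buckets` and its arguments are hypothetical and not part of mlr3proba; it assumes a vector of predicted survival probabilities at the observed times and a 0/1 status vector (1 = dead, 0 = censored).

```r
# Hypothetical helper (not part of mlr3proba): D-calibration bucket counts
# from predicted survival probabilities evaluated at the observed times.
dcalib_buckets = function(surv_prob, status, B = 10L) {
  # guard against probabilities of exactly zero
  surv_prob = pmax(surv_prob, 1e-5)
  # index of the bucket each probability falls into
  js = ceiling(B * surv_prob)
  counts = numeric(B)
  for (i in seq_along(surv_prob)) {
    ji = js[i]
    if (status[i] == 1) {
      # dead observations contribute 1 to their own bucket
      counts[ji] = counts[ji] + 1
    } else {
      # censored observations are spread over their own and all lower buckets
      lower = seq_len(ji - 1L)
      counts[lower] = counts[lower] + 1 / (B * surv_prob[i])
      counts[ji] = counts[ji] + (1 - (ji - 1) / (B * surv_prob[i]))
    }
  }
  counts
}

# one death per bucket gives perfectly uniform counts
uniform_counts = dcalib_buckets(seq(0.05, 0.95, by = 0.1), rep(1, 10))
# a single censored observation with survival probability 0.25
censored_counts = dcalib_buckets(0.25, 0)
print(uniform_counts)
print(censored_counts[1:3])  # 0.4 0.4 0.2
```

Each observation adds a total weight of 1 across the buckets, whether it is dead or censored. The sketch uses `seq_len()` so that a censored observation in the first bucket adds nothing to any lower bucket.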

R/bibentries.R

Lines changed: 11 additions & 0 deletions
@@ -599,5 +599,16 @@ bibentries = c( # nolint start
 url = "https://www.jstor.org/stable/2335161",
 volume = "66",
 year = "1979"
+),
+
+haider_2020 = bibentry("article",
+author = "Haider, Humza and Hoehn, Bret and Davis, Sarah and Greiner, Russell",
+journal = "Journal of Machine Learning Research",
+volume = "21",
+number = "85",
+pages = "1--63",
+title = "Effective Ways to Build and Evaluate Individual Survival Distributions",
+url = "http://jmlr.org/papers/v21/18-772.html",
+year = "2020"
 )
 ) # nolint end

R/zzz.R

Lines changed: 6 additions & 5 deletions
@@ -97,6 +97,7 @@ register_mlr3 = function() {
 
 x$add("surv.cindex", MeasureSurvCindex)
 
+x$add("surv.dcalib", MeasureSurvDCalibration)
 x$add("surv.calib_beta", MeasureSurvCalibrationBeta)
 x$add("surv.calib_alpha", MeasureSurvCalibrationAlpha)
 
@@ -166,12 +167,12 @@ register_mlr3pipelines = function() {
 pkgname = vapply(hooks[-1], function(x) environment(x)$pkgname, NA_character_)
 setHook(event, hooks[pkgname != "mlr3proba"], action = "replace")
 
-event = packageEvent("mlr3pipelines", "onLoad")
-hooks = getHook(event)
-pkgname = vapply(hooks[-1], function(x) environment(x)$pkgname, NA_character_)
-setHook(event, hooks[pkgname != "mlr3proba"], action = "replace")
+event = packageEvent("mlr3pipelines", "onLoad")
+hooks = getHook(event)
+pkgname = vapply(hooks[-1], function(x) environment(x)$pkgname, NA_character_)
+setHook(event, hooks[pkgname != "mlr3proba"], action = "replace")
 
-library.dynam.unload("mlr3proba", libpath)
+library.dynam.unload("mlr3proba", libpath)
 }
 
 leanify_package()

man/mlr_measures_surv.calib_alpha.Rd

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default.

man/mlr_measures_surv.calib_beta.Rd

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default.

man/mlr_measures_surv.chambless_auc.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.
