Skip to content

Commit ed6c351

Browse files
authored
Merge pull request #352 from mlr-org/calib_alpha_diff
measure optimizations and documentation updates
2 parents 4bc2513 + b552e6d commit ed6c351

File tree

102 files changed

+1644
-685
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

102 files changed

+1644
-685
lines changed

DESCRIPTION

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: mlr3proba
22
Title: Probabilistic Supervised Learning for 'mlr3'
3-
Version: 0.5.9
3+
Version: 0.6.0
44
Authors@R:
55
c(person(given = "Raphael",
66
family = "Sonabend",
@@ -79,10 +79,10 @@ Encoding: UTF-8
7979
LazyData: true
8080
NeedsCompilation: no
8181
Roxygen: list(markdown = TRUE, r6 = TRUE)
82-
RoxygenNote: 7.3.1
82+
RoxygenNote: 7.3.1.9000
8383
Collate:
84-
'aaa.R'
8584
'LearnerDens.R'
85+
'aaa.R'
8686
'LearnerDensHistogram.R'
8787
'LearnerDensKDE.R'
8888
'LearnerSurv.R'

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ importFrom(stats,reformulate)
116116
importFrom(stats,sd)
117117
importFrom(survival,Surv)
118118
importFrom(utils,data)
119+
importFrom(utils,getFromNamespace)
119120
importFrom(utils,head)
120121
importFrom(utils,tail)
121122
useDynLib(mlr3proba)

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# mlr3proba 0.6.0
2+
3+
* Optimized `surv.logloss` and `calib_alpha` measures (bypassing `distr6`)
4+
* Update/refine all measure docs (naming conventions from upcoming scoring rules paper) + doc templates
5+
* Fix very rare bugs in `calib_alpha`, `surv.logloss` and `surv.graf` (the version with `proper = FALSE`)
6+
17
# mlr3proba 0.5.9
28

39
* Fix several old issues (#348, #301, #281)

R/LearnerDensHistogram.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,5 @@ LearnerDensHistogram = R6::R6Class("LearnerDensHistogram",
3939
)
4040
)
4141

42-
#' @include zzz.R
42+
#' @include aaa.R
4343
register_learner("dens.hist", LearnerDensHistogram)

R/MeasureDensLogloss.R

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
#' @template dens_measure
2-
#' @templateVar title Log loss
2+
#' @templateVar title Log Loss
33
#' @templateVar inherit [MeasureDens]
44
#' @templateVar fullname MeasureDensLogloss
5-
#' @templateVar pars eps = 1e-15
6-
#' @templateVar eps_par TRUE
7-
#'
5+
#' @templateVar eps 1e-15
86
#' @template param_eps
97
#'
108
#' @description
119
#' Calculates the cross-entropy, or logarithmic (log), loss.
1210
#'
13-
#' The logloss, in the context of probabilistic predictions, is defined as the negative log
11+
#' @details
12+
#' The Log Loss, in the context of probabilistic predictions, is defined as the negative log
1413
#' probability density function, \eqn{f}, evaluated at the observed value, \eqn{y},
1514
#' \deqn{L(f, y) = -\log(f(y))}{L(f, y) = -log(f(y))}
1615
#'

R/MeasureRegrLogloss.R

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
#' @template regr_measure
2-
#' @templateVar title Log loss
2+
#' @templateVar title Log Loss
33
#' @templateVar inherit [MeasureRegr]
44
#' @templateVar fullname MeasureRegrLogloss
5-
#' @templateVar pars eps = 1e-15
6-
#' @templateVar eps_par TRUE
7-
#'
5+
#' @templateVar eps 1e-15
86
#' @template param_eps
97
#'
108
#' @description
119
#' Calculates the cross-entropy, or logarithmic (log), loss.
1210
#'
13-
#' The logloss, in the context of probabilistic predictions, is defined as the negative log
11+
#' @details
12+
#' The Log Loss, in the context of probabilistic predictions, is defined as the negative log
1413
#' probability density function, \eqn{f}, evaluated at the observed value, \eqn{y},
1514
#' \deqn{L(f, y) = -\log(f(y))}{L(f, y) = -log(f(y))}
1615
#'

R/MeasureSurv.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
#' @template param_packages
2222
#' @template param_label
2323
#' @template param_man
24-
#' @template param_se
2524
#'
2625
#' @family Measure
2726
#' @seealso
@@ -32,6 +31,8 @@ MeasureSurv = R6Class("MeasureSurv",
3231
public = list(
3332
#' @description
3433
#' Creates a new instance of this [R6][R6::R6Class] class.
34+
#' @param se If `TRUE` then returns standard error of the
35+
#' measure otherwise returns the mean (default).
3536
initialize = function(id, param_set = ps(), range, minimize = NA, aggregator = NULL,
3637
properties = character(), predict_type = "distr", task_properties = character(),
3738
packages = character(), label = NA_character_, man = NA_character_, se = FALSE) {

R/MeasureSurvCalibrationAlpha.R

Lines changed: 83 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,37 @@
11
#' @template surv_measure
2-
#' @templateVar title Van Houwelingen's Alpha
2+
#' @templateVar title Van Houwelingen's Calibration Alpha
33
#' @templateVar fullname MeasureSurvCalibrationAlpha
4-
#'
5-
#' @template param_se
4+
#' @templateVar eps 1e-3
5+
#' @template param_eps
66
#'
77
#' @description
88
#' This calibration method is defined by estimating
9-
#' \deqn{\alpha = \sum \delta_i / \sum H_i(t_i)}
10-
#' where \eqn{\delta} is the observed censoring indicator from the test data, \eqn{H_i} is the
11-
#' predicted cumulative hazard, and \eqn{t_i} is the observed survival time.
9+
#' \deqn{\hat{\alpha} = \sum \delta_i / \sum H_i(T_i)}
10+
#' where \eqn{\delta} is the observed censoring indicator from the test data,
11+
#' \eqn{H_i} is the predicted cumulative hazard, and \eqn{T_i} is the observed
12+
#' survival time (event or censoring).
1213
#'
1314
#' The standard error is given by
14-
#' \deqn{exp(1/\sqrt{\sum \delta_i})}
15+
#' \deqn{\hat{\alpha_{se}} = exp(1/\sqrt{\sum \delta_i})}
16+
#'
17+
#' The model is well calibrated if the estimated \eqn{\hat{\alpha}} coefficient
18+
#' (returned score) is equal to 1.
1519
#'
16-
#' The model is well calibrated if the estimated \eqn{\alpha} coefficient is equal to 1.
20+
#' @section Parameter details:
21+
#' - `se` (`logical(1)`)\cr
22+
#' If `TRUE` then return standard error of the measure, otherwise the score
23+
#' itself (default).
24+
#' - `method` (`character(1)`)\cr
25+
#' Returns \eqn{\hat{\alpha}} if equal to `ratio` (default) and
26+
#' \eqn{|1-\hat{\alpha}|} if equal to `diff`.
27+
#' With `diff`, the output score can be minimized and for example be used for
28+
#' tuning purposes. This parameter takes effect only if `se` is `FALSE`.
29+
#' - `truncate` (`double(1)`) \cr
30+
#' This parameter controls the upper bound of the output score.
31+
#' We use `truncate = Inf` by default (so no truncation) and it's up to the user
32+
#' **to set this up reasonably** given the chosen `method`.
33+
#' Note that truncation may severely limit automated tuning with this measure
34+
#' using `method = diff`.
1735
#'
1836
#' @references
1937
#' `r format_bib("vanhouwelingen_2000")`
@@ -25,16 +43,25 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
2543
inherit = MeasureSurv,
2644
public = list(
2745
#' @description Creates a new instance of this [R6][R6::R6Class] class.
28-
initialize = function() {
46+
#' @param method defines which output score to return, see "Parameter
47+
#' details" section.
48+
initialize = function(method = "ratio") {
49+
assert_choice(method, choices = c("ratio", "diff"))
50+
2951
ps = ps(
30-
se = p_lgl(default = FALSE)
52+
eps = p_dbl(0, 1, default = 1e-3),
53+
se = p_lgl(default = FALSE),
54+
method = p_fct(c("ratio", "diff"), default = "ratio"),
55+
truncate = p_dbl(lower = -Inf, upper = Inf, default = Inf)
3156
)
32-
ps$values$se = FALSE
57+
ps$values = list(eps = 1e-3, se = FALSE, method = method, truncate = Inf)
58+
range = if (method == "ratio") c(-Inf, Inf) else c(0, Inf)
59+
minimize = ifelse(method == "ratio", FALSE, TRUE)
3360

3461
super$initialize(
3562
id = "surv.calib_alpha",
36-
range = c(-Inf, Inf),
37-
minimize = FALSE,
63+
range = range,
64+
minimize = minimize,
3865
predict_type = "distr",
3966
label = "Van Houwelingen's Alpha",
4067
man = "mlr3proba::mlr_measures_surv.calib_alpha",
@@ -45,21 +72,54 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
4572

4673
private = list(
4774
.score = function(prediction, ...) {
48-
deaths = sum(prediction$truth[, 2])
75+
truth = prediction$truth
76+
all_times = truth[, 1] # both event times and censoring times
77+
status = truth[, 2]
78+
deaths = sum(status)
4979

50-
if (self$param_set$values$se) {
80+
ps = self$param_set$values
81+
if (ps$se) {
5182
return(exp(1 / sqrt(deaths)))
5283
} else {
53-
if (inherits(prediction$distr, "VectorDistribution")) {
54-
haz = as.numeric(prediction$distr$cumHazard(
55-
data = matrix(prediction$truth[, 1], nrow = 1)
56-
))
84+
distr = prediction$data$distr
85+
86+
# Bypass distr6 construction if underlying distr represented by array
87+
if (inherits(distr, "array")) {
88+
surv = distr
89+
if (length(dim(surv)) == 3) {
90+
# survival 3d array, extract median
91+
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
92+
}
93+
times = as.numeric(colnames(surv))
94+
95+
extend_times_cdf = getFromNamespace("C_Vec_WeightedDiscreteCdf", ns = "distr6")
96+
# get survival probability for each test obs at observed time
97+
surv_all = diag(
98+
extend_times_cdf(all_times, times, cdf = t(1 - surv), FALSE, FALSE)
99+
)
100+
101+
# H(t) = -log(S(t))
102+
cumhaz = -log(surv_all)
57103
} else {
58-
haz = diag(prediction$distr$cumHazard(prediction$truth[, 1]))
104+
if (inherits(distr, "VectorDistribution")) {
105+
cumhaz = as.numeric(
106+
distr$cumHazard(data = matrix(all_times, nrow = 1))
107+
)
108+
} else {
109+
cumhaz = diag(as.matrix(distr$cumHazard(all_times)))
110+
}
59111
}
60-
# cumulative hazard should only be infinite if only censoring occurs at the final time-point
61-
haz[haz == Inf] = 0
62-
return(deaths / sum(haz))
112+
113+
# Inf => case where censoring occurs at last time point
114+
# 0 => case where survival probabilities are all 1
115+
cumhaz[cumhaz == Inf | cumhaz == 0] = ps$eps
116+
out = deaths / sum(cumhaz)
117+
118+
if (ps$method == "diff") {
119+
out = abs(1 - out)
120+
}
121+
122+
return(min(ps$truncate, out))
63123
}
64124
}
65125
)

R/MeasureSurvCalibrationBeta.R

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,30 @@
11
#' @template surv_measure
2-
#' @templateVar title Van Houwelingen's Beta
2+
#' @templateVar title Van Houwelingen's Calibration Beta
33
#' @templateVar fullname MeasureSurvCalibrationBeta
44
#'
5-
#' @template param_se
6-
#'
75
#' @description
8-
#' This calibration method fits the predicted linear predictor from a Cox PH model as the only
9-
#' predictor in a new Cox PH model with the test data as the response.
10-
#' \deqn{h(t|x) = h_0(t)exp(l\beta)}
11-
#' where \eqn{l} is the predicted linear predictor.
6+
#' This calibration method fits the predicted linear predictor from a Cox PH
7+
#' model as the only predictor in a new Cox PH model with the test data as
8+
#' the response.
9+
#' \deqn{h(t|x) = h_0(t)exp(\beta \times lp)}
10+
#' where \eqn{lp} is the predicted linear predictor on the test data.
11+
#'
12+
#' The model is well calibrated if the estimated \eqn{\hat{\beta}} coefficient
13+
#' (returned score) is equal to 1.
1214
#'
13-
#' The model is well calibrated if the estimated \eqn{\beta} coefficient is equal to 1.
15+
#' **Note**: Assumes fitted model is Cox PH (i.e. has an `lp` prediction type).
1416
#'
15-
#' Assumes fitted model is Cox PH.
17+
#' @section Parameter details:
18+
#' - `se` (`logical(1)`)\cr
19+
#' If `TRUE` then return standard error of the measure which is the standard
20+
#' error of the estimated coefficient \eqn{se_{\hat{\beta}}} from the Cox PH model.
21+
#' If `FALSE` (default) then returns the estimated coefficient \eqn{\hat{\beta}}.
22+
#' - `method` (`character(1)`)\cr
23+
#' Returns \eqn{\hat{\beta}} if equal to `ratio` (default) and \eqn{|1-\hat{\beta}|}
24+
#' if `diff`.
25+
#' With `diff`, the output score can be minimized and for example be used for
26+
#' tuning purposes.
27+
#' This parameter takes effect only if `se` is `FALSE`.
1628
#'
1729
#' @references
1830
#' `r format_bib("vanhouwelingen_2000")`
@@ -24,16 +36,23 @@ MeasureSurvCalibrationBeta = R6Class("MeasureSurvCalibrationBeta",
2436
inherit = MeasureSurv,
2537
public = list(
2638
#' @description Creates a new instance of this [R6][R6::R6Class] class.
27-
initialize = function() {
39+
#' @param method defines which output score to return, see "Parameter
40+
#' details" section.
41+
initialize = function(method = "ratio") {
42+
assert_choice(method, choices = c("ratio", "diff"))
43+
2844
ps = ps(
29-
se = p_lgl(default = FALSE)
45+
se = p_lgl(default = FALSE),
46+
method = p_fct(c("ratio", "diff"), default = "ratio")
3047
)
31-
ps$values$se = FALSE
48+
ps$values = list(se = FALSE, method = method)
49+
range = if (method == "ratio") c(-Inf, Inf) else c(0, Inf)
50+
minimize = ifelse(method == "ratio", FALSE, TRUE)
3251

3352
super$initialize(
3453
id = "surv.calib_beta",
35-
range = c(-Inf, Inf),
36-
minimize = FALSE,
54+
range = range,
55+
minimize = minimize,
3756
predict_type = "lp",
3857
label = "Van Houwelingen's Beta",
3958
man = "mlr3proba::mlr_measures_surv.calib_beta",
@@ -44,16 +63,24 @@ MeasureSurvCalibrationBeta = R6Class("MeasureSurvCalibrationBeta",
4463

4564
private = list(
4665
.score = function(prediction, ...) {
47-
4866
df = data.frame(truth = prediction$truth, lp = prediction$lp)
4967
fit = try(summary(survival::coxph(truth ~ lp, data = df)), silent = TRUE)
68+
5069
if (class(fit)[1] == "try-error") {
5170
return(NA)
5271
} else {
53-
if (self$param_set$values$se) {
54-
return(fit$coefficients[3])
72+
ps = self$param_set$values
73+
74+
if (ps$se) {
75+
return(fit$coefficients[,"se(coef)"])
5576
} else {
56-
return(fit$coefficients[1])
77+
out = fit$coefficients[,"coef"]
78+
79+
if (ps$method == "diff") {
80+
out = abs(1 - out)
81+
}
82+
83+
return(out)
5784
}
5885
}
5986
}

R/MeasureSurvChamblessAUC.R

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
#' @template surv_measure
22
#' @templateVar title Chambless and Diao's AUC
33
#' @templateVar fullname MeasureSurvChamblessAUC
4+
#' @template measure_survAUC
5+
#' @template param_integrated
6+
#' @template param_times
47
#'
58
#' @description
69
#' Calls [survAUC::AUC.cd()].
710
#'
811
#' Assumes Cox PH model specification.
912
#'
10-
#' @template param_integrated
11-
#' @template param_times
12-
#' @template measure_survAUC
13-
#'
1413
#' @references
1514
#' `r format_bib("chambless_2006")`
1615
#'
@@ -24,8 +23,8 @@ MeasureSurvChamblessAUC = R6Class("MeasureSurvChamblessAUC",
2423
#' Creates a new instance of this [R6][R6::R6Class] class.
2524
initialize = function() {
2625
ps = ps(
27-
times = p_uty(),
28-
integrated = p_lgl(default = TRUE)
26+
integrated = p_lgl(default = TRUE),
27+
times = p_uty()
2928
)
3029
ps$values$integrated = TRUE
3130

0 commit comments

Comments
 (0)