Skip to content

refactor: use inherits, toString, explicit integer minor formatting #394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.6.3
Version: 0.6.4
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down Expand Up @@ -30,11 +30,16 @@ Authors@R:
role = c("cre", "aut"),
email = "bblodfon@gmail.com",
comment = c(ORCID = "0000-0002-3609-8674")),
person(given = "Lukas",
person(given = "Lukas",
family = "Burk",
email = "github@quantenbrot.de",
email = "github@quantenbrot.de",
role = "ctb",
comment = c(ORCID = "0000-0001-7528-3795")))
comment = c(ORCID = "0000-0001-7528-3795")),
person(given = "Maximilian",
family = "Muecke",
email = "muecke.maximilian@gmail.com",
role = "ctb",
comment = c(ORCID = "0009-0000-9432-9795")))
Description: Provides extensions for probabilistic supervised learning for
'mlr3'. This includes extending the regression task to probabilistic
and interval regression, adding a survival task, and other specialized
Expand Down Expand Up @@ -86,7 +91,7 @@ Encoding: UTF-8
LazyData: true
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Collate:
'LearnerDens.R'
'aaa.R'
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# mlr3proba 0.6.4

* Add useR! 2024 tutorial
* Lots of refactoring, improve code quality (thanks to @m-muecke)

# mlr3proba 0.6.3

* Add new tasks from `survival` package: `veteran`, `pbc`, `mgus`, `gbsg`
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerDensHistogram.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ LearnerDensHistogram = R6::R6Class("LearnerDensHistogram",
private = list(
.train = function(task) {
pars = self$param_set$get_values(tags = "train")
fit = invoke(.histogram, dat = task$data()[[1]], .args = pars)
fit = invoke(.histogram, dat = task$data()[[1L]], .args = pars)
set_class(list(distr = fit$distr, hist = fit$hist), "dens.hist")
},

.predict = function(task) {
newdata = task$data()[[1]]
newdata = task$data()[[1L]]
list(pdf = self$model$distr$pdf(newdata), cdf = self$model$distr$cdf(newdata),
distr = self$model$distr)
}
Expand Down
18 changes: 10 additions & 8 deletions R/LearnerDensKDE.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ LearnerDensKDE = R6::R6Class("LearnerDensKDE",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
ps = ps(
kernel = p_fct(levels = subset(distr6::listKernels(), select = "ShortName")[[1]],
kernel = p_fct(levels = subset(distr6::listKernels(), select = "ShortName")[[1L]],
default = "Epan", tags = "train"),
bandwidth = p_dbl(lower = 0, tags = "train", special_vals = list("silver"))
)
Expand Down Expand Up @@ -46,28 +46,30 @@ LearnerDensKDE = R6::R6Class("LearnerDensKDE",
self$param_set$values$kernel == "Epan"
}

data = task$data()[[1]]
data = task$data()[[1L]]

kernel = get(as.character(subset(
distr6::listKernels(),
ShortName == self$param_set$values$kernel,
ClassName)))$new()


bw = ifelse(self$param_set$values$bandwidth == "silver",
bw = if (isTRUE(self$param_set$values$bandwidth == "silver")) {
0.9 * min(sd(data), stats::IQR(data, na.rm = TRUE) / 1.349, na.rm = TRUE) *
length(data)^-0.2,
self$param_set$values$bandwidth)
length(data)^-0.2
} else {
self$param_set$values$bandwidth
}

pdf = function(x) {} # nolint

body(pdf) = substitute({
if (length(x) == 1) {
if (length(x) == 1L) {
return(1 / (rows * bw) * sum(kernel$pdf((x - train) / bw)))
} else {
x = matrix(x, nrow = length(x), ncol = rows)
train_mat = matrix(train, nrow = nrow(x), ncol = rows, byrow = TRUE)
return(1 / (rows * bw) * colSums(apply((x - train_mat) / bw, 1, kernel$pdf)))
return(1 / (rows * bw) * colSums(apply((x - train_mat) / bw, 1L, kernel$pdf)))
}
}, list(
rows = task$nrow,
Expand All @@ -83,7 +85,7 @@ LearnerDensKDE = R6::R6Class("LearnerDensKDE",
},

.predict = function(task) {
list(pdf = self$model$pdf(task$data()[[1]]),
list(pdf = self$model$pdf(task$data()[[1L]]),
distr = self$model)
}
)
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerSurvCoxPH.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ LearnerSurvCoxPH = R6Class("LearnerSurvCoxPH",
stop(sprintf(
"Learner %s on task %s failed to predict: Missing values in new data (line(s) %s)\n",
self$id, task$id,
paste0(which(!complete.cases(newdata)), collapse = ", ")))
toString(which(!complete.cases(newdata)))))
}

pv = self$param_set$get_values(tags = "predict")

# Get predicted values
fit = mlr3misc::invoke(survival::survfit, formula = self$model, newdata = newdata,
fit = invoke(survival::survfit, formula = self$model, newdata = newdata,
se.fit = FALSE, .args = pv)

lp = predict(self$model, type = "lp", newdata = newdata)
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureRegrLogloss.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ MeasureRegrLogloss = R6::R6Class("MeasureRegrLogloss",
if (inherits(distr, c("Matdist", "Arrdist"))) {
pdf = diag(distr$pdf(truth))
} else {
pdf = as.numeric(distr$pdf(data = matrix(truth, nrow = 1)))
pdf = as.numeric(distr$pdf(data = matrix(truth, nrow = 1L)))
}

pdf[pdf == 0] = self$param_set$values$eps
Expand Down
8 changes: 4 additions & 4 deletions R/MeasureSurvAUC.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ MeasureSurvAUC = R6Class("MeasureSurvAUC",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id, properties = character(), label = NA_character_,
man = NA_character_, param_set = ps()) {
if (class(self)[[1]] == "MeasureSurvAUC") {
if (class(self)[[1L]] == "MeasureSurvAUC") {
stop("This is an abstract class that should not be constructed directly.")
}

Expand Down Expand Up @@ -48,15 +48,15 @@ MeasureSurvAUC = R6Class("MeasureSurvAUC",
}

args$times = ps$times
if (length(args$times) == 0) {
args$times = sort(unique(prediction$truth[, 1]))
if (length(args$times) == 0L) {
args$times = sort(unique(prediction$truth[, 1L]))
}

if ("Surv.rsp.new" %in% names(formals(FUN))) {
args$Surv.rsp.new = prediction$truth # nolint
}

auc = mlr3misc::invoke(FUN, lpnew = prediction$lp, .args = args)
auc = invoke(FUN, lpnew = prediction$lp, .args = args)

if (is.null(ps$integrated) || !ps$integrated || grepl("tnr|tpr", self$id)) {
auc
Expand Down
10 changes: 5 additions & 5 deletions R/MeasureSurvCalibrationAlpha.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
)
ps$values = list(eps = 1e-3, se = FALSE, method = method, truncate = Inf)
range = if (method == "ratio") c(-Inf, Inf) else c(0, Inf)
minimize = ifelse(method == "ratio", FALSE, TRUE)
minimize = method != "ratio"

super$initialize(
id = "surv.calib_alpha",
Expand All @@ -73,8 +73,8 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
private = list(
.score = function(prediction, ...) {
truth = prediction$truth
all_times = truth[, 1] # both event times and censoring times
status = truth[, 2]
all_times = truth[, 1L] # both event times and censoring times
status = truth[, 2L]
deaths = sum(status)

ps = self$param_set$values
Expand All @@ -86,7 +86,7 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
# Bypass distr6 construction if underlying distr represented by array
if (inherits(distr, "array")) {
surv = distr
if (length(dim(surv)) == 3) {
if (length(dim(surv)) == 3L) {
# survival 3d array, extract median
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
}
Expand All @@ -103,7 +103,7 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
} else {
if (inherits(distr, "VectorDistribution")) {
cumhaz = as.numeric(
distr$cumHazard(data = matrix(all_times, nrow = 1))
distr$cumHazard(data = matrix(all_times, nrow = 1L))
)
} else {
cumhaz = diag(as.matrix(distr$cumHazard(all_times)))
Expand Down
8 changes: 4 additions & 4 deletions R/MeasureSurvCalibrationBeta.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ MeasureSurvCalibrationBeta = R6Class("MeasureSurvCalibrationBeta",
)
ps$values = list(se = FALSE, method = method)
range = if (method == "ratio") c(-Inf, Inf) else c(0, Inf)
minimize = ifelse(method == "ratio", FALSE, TRUE)
minimize = method != "ratio"

super$initialize(
id = "surv.calib_beta",
Expand All @@ -66,15 +66,15 @@ MeasureSurvCalibrationBeta = R6Class("MeasureSurvCalibrationBeta",
df = data.frame(truth = prediction$truth, lp = prediction$lp)
fit = try(summary(survival::coxph(truth ~ lp, data = df)), silent = TRUE)

if (class(fit)[1] == "try-error") {
if (inherits(fit, "try-error")) {
return(NA)
} else {
ps = self$param_set$values

if (ps$se) {
return(fit$coefficients[,"se(coef)"])
return(fit$coefficients[, "se(coef)"])
} else {
out = fit$coefficients[,"coef"]
out = fit$coefficients[, "coef"]

if (ps$method == "diff") {
out = abs(1 - out)
Expand Down
4 changes: 2 additions & 2 deletions R/MeasureSurvChamblessAUC.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ MeasureSurvChamblessAUC = R6Class("MeasureSurvChamblessAUC",
ps = self$param_set$values
if (!ps$integrated) {
msg = "If `integrated=FALSE` then `times` should be a scalar numeric."
assert_numeric(ps$times, len = 1, .var.name = msg)
assert_numeric(ps$times, len = 1L, .var.name = msg)
} else {
if (!is.null(ps$times) && length(ps$times) == 1) {
if (!is.null(ps$times) && length(ps$times) == 1L) {
ps$integrated = FALSE
}
}
Expand Down
6 changes: 3 additions & 3 deletions R/MeasureSurvCindex.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ MeasureSurvCindex = R6Class("MeasureSurvCindex",
# calculate t_max (cutoff time horizon)
if (is.null(ps$t_max) && !is.null(ps$p_max)) {
truth = prediction$truth
unique_times = unique(sort(truth[,"time"]))
unique_times = unique(sort(truth[, "time"]))
surv = survival::survfit(truth ~ 1)
indx = which(1 - (surv$n.risk / surv$n) > ps$p_max)
if (length(indx) == 0) {
if (length(indx) == 0L) {
t_max = NULL # t_max calculated in `cindex()`
} else {
# first time point that surpasses the specified
# `p_max` proportion of censoring
t_max = surv$time[indx[1]]
t_max = surv$time[indx[1L]]
}
} else {
t_max = ps$t_max
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvDCalibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ MeasureSurvDCalibration = R6Class("MeasureSurvDCalibration",
#' @description Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
ps = ps(
B = p_int(1, default = 10),
B = p_int(1L, default = 10L),
chisq = p_lgl(default = FALSE),
truncate = p_dbl(lower = 0, upper = Inf, default = Inf)
)
Expand Down
4 changes: 2 additions & 2 deletions R/MeasureSurvGraf.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ MeasureSurvGraf = R6::R6Class("MeasureSurvGraf",
}
if (!ps$integrated) {
msg = "If `integrated=FALSE` then `times` should be a scalar numeric."
assert_numeric(ps$times, len = 1, .var.name = msg)
assert_numeric(ps$times, len = 1L, .var.name = msg)
} else {
if (!is.null(ps$times) && length(ps$times) == 1) {
if (!is.null(ps$times) && length(ps$times) == 1L) {
ps$integrated = FALSE
}
}
Expand Down
4 changes: 2 additions & 2 deletions R/MeasureSurvIntLogloss.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@ MeasureSurvIntLogloss = R6::R6Class("MeasureSurvIntLogloss",
}
if (!ps$integrated) {
msg = "If `integrated=FALSE` then `times` should be a scalar numeric."
assert_numeric(ps$times, len = 1, .var.name = msg)
assert_numeric(ps$times, len = 1L, .var.name = msg)
} else {
if (!is.null(ps$times) && length(ps$times) == 1) {
if (!is.null(ps$times) && length(ps$times) == 1L) {
ps$integrated = FALSE
}
}
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvLogloss.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ MeasureSurvLogloss = R6::R6Class("MeasureSurvLogloss",
ps = self$param_set$values

if (ps$se) {
ll = surv_logloss(prediction$truth, prediction$data$distr, ps$eps, ps$IPCW, train) #nolint
ll = surv_logloss(prediction$truth, prediction$data$distr, ps$eps, ps$IPCW, train) # nolint
sd(ll) / sqrt(length(ll))
} else {
mean(surv_logloss(prediction$truth, prediction$data$distr, ps$eps, ps$IPCW, train)) # nolint
Expand Down
8 changes: 4 additions & 4 deletions R/MeasureSurvRCLL.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL",
}
out = rep(-99L, length(prediction$row_ids))
truth = prediction$truth
event = truth[, 2] == 1
event_times = truth[event, 1]
cens_times = truth[!event, 1]
event = truth[, 2L] == 1
event_times = truth[event, 1L]
cens_times = truth[!event, 1L]

# Bypass distr6 construction if underlying distr represented by array
if (inherits(prediction$data$distr, "array")) {
surv = prediction$data$distr
if (length(dim(surv)) == 3) {
if (length(dim(surv)) == 3L) {
# survival 3d array, extract median
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
}
Expand Down
4 changes: 2 additions & 2 deletions R/MeasureSurvSchmid.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ MeasureSurvSchmid = R6::R6Class("MeasureSurvSchmid",
}
if (!ps$integrated) {
msg = "If `integrated=FALSE` then `times` should be a scalar numeric."
assert_numeric(ps$times, len = 1, .var.name = msg)
assert_numeric(ps$times, len = 1L, .var.name = msg)
} else {
if (!is.null(ps$times) && length(ps$times) == 1) {
if (!is.null(ps$times) && length(ps$times) == 1L) {
ps$integrated = FALSE
}
}
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvUnoAUC.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ MeasureSurvUnoAUC = R6Class("MeasureSurvUnoAUC",
msg = "If `integrated=FALSE` then `times` should be a scalar numeric."
assert_numeric(ps$times, len = 1, .var.name = msg)
} else {
if (!is.null(ps$times) && length(ps$times) == 1) {
if (!is.null(ps$times) && length(ps$times) == 1L) {
ps$integrated = FALSE
}
}
Expand Down
2 changes: 1 addition & 1 deletion R/PipeOpBreslow.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
#' task = tsk("rats")
#' part = partition(task, ratio = 0.8)
#' train_task = task$clone()$filter(part$train)
#' test_task = task$clone()$filter(part$test)
#' test_task = task$clone()$filter(part$test)
#'
#' learner = lrn("surv.coxph") # learner with lp predictions
#' b = po("breslowcompose", learner = learner, breslow.overwrite = TRUE)
Expand Down
10 changes: 5 additions & 5 deletions R/PipeOpCrankCompositor.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ PipeOpCrankCompositor = R6Class("PipeOpCrankCompositor",
param_set = ps(
method = p_fct(default = "sum_haz", levels = c("sum_haz", "mean", "median", "mode"),
tags = "predict"),
which = p_int(default = 1, lower = 1, tags = "predict"),
which = p_int(default = 1L, lower = 1L, tags = "predict"),
response = p_lgl(default = FALSE, tags = "predict"),
overwrite = p_lgl(default = FALSE, tags = "predict")
)
Expand All @@ -105,7 +105,7 @@ PipeOpCrankCompositor = R6Class("PipeOpCrankCompositor",

.predict = function(inputs) {

inpred = inputs[[1]]
inpred = inputs[[1L]]

response = self$param_set$values$response
b_response = !anyMissing(inpred$response)
Expand All @@ -120,7 +120,7 @@ PipeOpCrankCompositor = R6Class("PipeOpCrankCompositor",
} else {
assert("distr" %in% inpred$predict_types)
method = self$param_set$values$method
if (length(method) == 0) method = "sum_haz"
if (length(method) == 0L) method = "sum_haz"
if (method == "sum_haz") {
if (inherits(inpred$data$distr, "matrix") ||
!requireNamespace("survivalmodels", quietly = TRUE)) {
Expand All @@ -132,11 +132,11 @@ PipeOpCrankCompositor = R6Class("PipeOpCrankCompositor",
}
} else if (method == "mean") {
comp = try(inpred$distr$mean(), silent = TRUE)
if (class(comp)[1] == "try-error") {
if (inherits(comp, "try-error")) {
requireNamespace("cubature")
comp = try(inpred$distr$mean(cubature = TRUE), silent = TRUE)
}
if (class(comp)[1] == "try-error") {
if (inherits(comp, "try-error")) {
comp = numeric(length(inpred$crank))
}
} else {
Expand Down
Loading