Skip to content

Commit 733b2e1

Browse files
authored
Merge pull request #394 from m-muecke/linting
refactor: use inherits, toString, explicit integer minor formatting
2 parents 2cd4b28 + d221991 commit 733b2e1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+250
-242
lines changed

DESCRIPTION

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: mlr3proba
22
Title: Probabilistic Supervised Learning for 'mlr3'
3-
Version: 0.6.3
3+
Version: 0.6.4
44
Authors@R:
55
c(person(given = "Raphael",
66
family = "Sonabend",
@@ -30,11 +30,16 @@ Authors@R:
3030
role = c("cre", "aut"),
3131
email = "bblodfon@gmail.com",
3232
comment = c(ORCID = "0000-0002-3609-8674")),
33-
person(given = "Lukas",
33+
person(given = "Lukas",
3434
family = "Burk",
35-
email = "github@quantenbrot.de",
35+
email = "github@quantenbrot.de",
3636
role = "ctb",
37-
comment = c(ORCID = "0000-0001-7528-3795")))
37+
comment = c(ORCID = "0000-0001-7528-3795")),
38+
person(given = "Maximilian",
39+
family = "Muecke",
40+
email = "muecke.maximilian@gmail.com",
41+
role = "ctb",
42+
comment = c(ORCID = "0009-0000-9432-9795")))
3843
Description: Provides extensions for probabilistic supervised learning for
3944
'mlr3'. This includes extending the regression task to probabilistic
4045
and interval regression, adding a survival task, and other specialized
@@ -86,7 +91,7 @@ Encoding: UTF-8
8691
LazyData: true
8792
NeedsCompilation: no
8893
Roxygen: list(markdown = TRUE, r6 = TRUE)
89-
RoxygenNote: 7.3.1
94+
RoxygenNote: 7.3.2
9095
Collate:
9196
'LearnerDens.R'
9297
'aaa.R'

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# mlr3proba 0.6.4
2+
3+
* Add useR! 2024 tutorial
4+
* Lots of refactoring, improve code quality (thanks to @m-muecke)
5+
16
# mlr3proba 0.6.3
27

38
* Add new tasks from `survival` package: `veteran`, `pbc`, `mgus`, `gbsg`

R/LearnerDensHistogram.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ LearnerDensHistogram = R6::R6Class("LearnerDensHistogram",
2727
private = list(
2828
.train = function(task) {
2929
pars = self$param_set$get_values(tags = "train")
30-
fit = invoke(.histogram, dat = task$data()[[1]], .args = pars)
30+
fit = invoke(.histogram, dat = task$data()[[1L]], .args = pars)
3131
set_class(list(distr = fit$distr, hist = fit$hist), "dens.hist")
3232
},
3333

3434
.predict = function(task) {
35-
newdata = task$data()[[1]]
35+
newdata = task$data()[[1L]]
3636
list(pdf = self$model$distr$pdf(newdata), cdf = self$model$distr$cdf(newdata),
3737
distr = self$model$distr)
3838
}

R/LearnerDensKDE.R

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ LearnerDensKDE = R6::R6Class("LearnerDensKDE",
1818
#' Creates a new instance of this [R6][R6::R6Class] class.
1919
initialize = function() {
2020
ps = ps(
21-
kernel = p_fct(levels = subset(distr6::listKernels(), select = "ShortName")[[1]],
21+
kernel = p_fct(levels = subset(distr6::listKernels(), select = "ShortName")[[1L]],
2222
default = "Epan", tags = "train"),
2323
bandwidth = p_dbl(lower = 0, tags = "train", special_vals = list("silver"))
2424
)
@@ -46,28 +46,30 @@ LearnerDensKDE = R6::R6Class("LearnerDensKDE",
4646
self$param_set$values$kernel == "Epan"
4747
}
4848

49-
data = task$data()[[1]]
49+
data = task$data()[[1L]]
5050

5151
kernel = get(as.character(subset(
5252
distr6::listKernels(),
5353
ShortName == self$param_set$values$kernel,
5454
ClassName)))$new()
5555

5656

57-
bw = ifelse(self$param_set$values$bandwidth == "silver",
57+
bw = if (isTRUE(self$param_set$values$bandwidth == "silver")) {
5858
0.9 * min(sd(data), stats::IQR(data, na.rm = TRUE) / 1.349, na.rm = TRUE) *
59-
length(data)^-0.2,
60-
self$param_set$values$bandwidth)
59+
length(data)^-0.2
60+
} else {
61+
self$param_set$values$bandwidth
62+
}
6163

6264
pdf = function(x) {} # nolint
6365

6466
body(pdf) = substitute({
65-
if (length(x) == 1) {
67+
if (length(x) == 1L) {
6668
return(1 / (rows * bw) * sum(kernel$pdf((x - train) / bw)))
6769
} else {
6870
x = matrix(x, nrow = length(x), ncol = rows)
6971
train_mat = matrix(train, nrow = nrow(x), ncol = rows, byrow = TRUE)
70-
return(1 / (rows * bw) * colSums(apply((x - train_mat) / bw, 1, kernel$pdf)))
72+
return(1 / (rows * bw) * colSums(apply((x - train_mat) / bw, 1L, kernel$pdf)))
7173
}
7274
}, list(
7375
rows = task$nrow,
@@ -83,7 +85,7 @@ LearnerDensKDE = R6::R6Class("LearnerDensKDE",
8385
},
8486

8587
.predict = function(task) {
86-
list(pdf = self$model$pdf(task$data()[[1]]),
88+
list(pdf = self$model$pdf(task$data()[[1L]]),
8789
distr = self$model)
8890
}
8991
)

R/LearnerSurvCoxPH.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ LearnerSurvCoxPH = R6Class("LearnerSurvCoxPH",
5555
stop(sprintf(
5656
"Learner %s on task %s failed to predict: Missing values in new data (line(s) %s)\n",
5757
self$id, task$id,
58-
paste0(which(!complete.cases(newdata)), collapse = ", ")))
58+
toString(which(!complete.cases(newdata)))))
5959
}
6060

6161
pv = self$param_set$get_values(tags = "predict")
6262

6363
# Get predicted values
64-
fit = mlr3misc::invoke(survival::survfit, formula = self$model, newdata = newdata,
64+
fit = invoke(survival::survfit, formula = self$model, newdata = newdata,
6565
se.fit = FALSE, .args = pv)
6666

6767
lp = predict(self$model, type = "lp", newdata = newdata)

R/MeasureRegrLogloss.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ MeasureRegrLogloss = R6::R6Class("MeasureRegrLogloss",
4545
if (inherits(distr, c("Matdist", "Arrdist"))) {
4646
pdf = diag(distr$pdf(truth))
4747
} else {
48-
pdf = as.numeric(distr$pdf(data = matrix(truth, nrow = 1)))
48+
pdf = as.numeric(distr$pdf(data = matrix(truth, nrow = 1L)))
4949
}
5050

5151
pdf[pdf == 0] = self$param_set$values$eps

R/MeasureSurvAUC.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ MeasureSurvAUC = R6Class("MeasureSurvAUC",
1616
#' Creates a new instance of this [R6][R6::R6Class] class.
1717
initialize = function(id, properties = character(), label = NA_character_,
1818
man = NA_character_, param_set = ps()) {
19-
if (class(self)[[1]] == "MeasureSurvAUC") {
19+
if (class(self)[[1L]] == "MeasureSurvAUC") {
2020
stop("This is an abstract class that should not be constructed directly.")
2121
}
2222

@@ -48,15 +48,15 @@ MeasureSurvAUC = R6Class("MeasureSurvAUC",
4848
}
4949

5050
args$times = ps$times
51-
if (length(args$times) == 0) {
52-
args$times = sort(unique(prediction$truth[, 1]))
51+
if (length(args$times) == 0L) {
52+
args$times = sort(unique(prediction$truth[, 1L]))
5353
}
5454

5555
if ("Surv.rsp.new" %in% names(formals(FUN))) {
5656
args$Surv.rsp.new = prediction$truth # nolint
5757
}
5858

59-
auc = mlr3misc::invoke(FUN, lpnew = prediction$lp, .args = args)
59+
auc = invoke(FUN, lpnew = prediction$lp, .args = args)
6060

6161
if (is.null(ps$integrated) || !ps$integrated || grepl("tnr|tpr", self$id)) {
6262
auc

R/MeasureSurvCalibrationAlpha.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
5656
)
5757
ps$values = list(eps = 1e-3, se = FALSE, method = method, truncate = Inf)
5858
range = if (method == "ratio") c(-Inf, Inf) else c(0, Inf)
59-
minimize = ifelse(method == "ratio", FALSE, TRUE)
59+
minimize = method != "ratio"
6060

6161
super$initialize(
6262
id = "surv.calib_alpha",
@@ -73,8 +73,8 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
7373
private = list(
7474
.score = function(prediction, ...) {
7575
truth = prediction$truth
76-
all_times = truth[, 1] # both event times and censoring times
77-
status = truth[, 2]
76+
all_times = truth[, 1L] # both event times and censoring times
77+
status = truth[, 2L]
7878
deaths = sum(status)
7979

8080
ps = self$param_set$values
@@ -86,7 +86,7 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
8686
# Bypass distr6 construction if underlying distr represented by array
8787
if (inherits(distr, "array")) {
8888
surv = distr
89-
if (length(dim(surv)) == 3) {
89+
if (length(dim(surv)) == 3L) {
9090
# survival 3d array, extract median
9191
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
9292
}
@@ -103,7 +103,7 @@ MeasureSurvCalibrationAlpha = R6Class("MeasureSurvCalibrationAlpha",
103103
} else {
104104
if (inherits(distr, "VectorDistribution")) {
105105
cumhaz = as.numeric(
106-
distr$cumHazard(data = matrix(all_times, nrow = 1))
106+
distr$cumHazard(data = matrix(all_times, nrow = 1L))
107107
)
108108
} else {
109109
cumhaz = diag(as.matrix(distr$cumHazard(all_times)))

R/MeasureSurvCalibrationBeta.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ MeasureSurvCalibrationBeta = R6Class("MeasureSurvCalibrationBeta",
4747
)
4848
ps$values = list(se = FALSE, method = method)
4949
range = if (method == "ratio") c(-Inf, Inf) else c(0, Inf)
50-
minimize = ifelse(method == "ratio", FALSE, TRUE)
50+
minimize = method != "ratio"
5151

5252
super$initialize(
5353
id = "surv.calib_beta",
@@ -66,15 +66,15 @@ MeasureSurvCalibrationBeta = R6Class("MeasureSurvCalibrationBeta",
6666
df = data.frame(truth = prediction$truth, lp = prediction$lp)
6767
fit = try(summary(survival::coxph(truth ~ lp, data = df)), silent = TRUE)
6868

69-
if (class(fit)[1] == "try-error") {
69+
if (inherits(fit, "try-error")) {
7070
return(NA)
7171
} else {
7272
ps = self$param_set$values
7373

7474
if (ps$se) {
75-
return(fit$coefficients[,"se(coef)"])
75+
return(fit$coefficients[, "se(coef)"])
7676
} else {
77-
out = fit$coefficients[,"coef"]
77+
out = fit$coefficients[, "coef"]
7878

7979
if (ps$method == "diff") {
8080
out = abs(1 - out)

R/MeasureSurvChamblessAUC.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ MeasureSurvChamblessAUC = R6Class("MeasureSurvChamblessAUC",
4646
ps = self$param_set$values
4747
if (!ps$integrated) {
4848
msg = "If `integrated=FALSE` then `times` should be a scalar numeric."
49-
assert_numeric(ps$times, len = 1, .var.name = msg)
49+
assert_numeric(ps$times, len = 1L, .var.name = msg)
5050
} else {
51-
if (!is.null(ps$times) && length(ps$times) == 1) {
51+
if (!is.null(ps$times) && length(ps$times) == 1L) {
5252
ps$integrated = FALSE
5353
}
5454
}

0 commit comments

Comments
 (0)