Skip to content

Commit 694300d

Browse files
committed
add mirai in tests
1 parent 7e0ba83 commit 694300d

File tree

4 files changed

+70
-5
lines changed

4 files changed

+70
-5
lines changed

R/LearnerSurvKaplan.R

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,50 @@ LearnerSurvKaplan = R6Class("LearnerSurvKaplan",
1919
id = "surv.kaplan",
2020
predict_types = c("crank", "distr"),
2121
feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
22-
properties = "missings",
22+
properties = c("missings", "importance", "selected_features"),
2323
packages = c("survival", "distr6"),
2424
label = "Kaplan-Meier Estimator",
2525
man = "mlr3proba::mlr_learners_surv.kaplan"
2626
)
27+
},
28+
29+
#' @description
30+
#' All features have a score of `0` for this learner.
31+
#' #' This method exists solely for compatibility with the `mlr3` ecosystem,
32+
#' as this learner is used as a fallback for other survival learners that
33+
#' require an `importance()` method.
34+
#'
35+
#' @return Named `numeric()`.
36+
importance = function() {
37+
if (is.null(self$model)) {
38+
stopf("No model stored")
39+
}
40+
41+
fn = self$model$features
42+
named_vector(fn, 0)
43+
},
44+
45+
#' @description
46+
#' Selected features are always the empty set for this learner.
47+
#' This method is implemented only for compatibility with the `mlr3` API,
48+
#' as this learner does not perform feature selection.
49+
#'
50+
#' @return `character(0)`.
51+
selected_features = function() {
52+
character()
2753
}
2854
),
2955

3056
private = list(
3157
.train = function(task) {
32-
invoke(survival::survfit, formula = task$formula(1), data = task$data())
58+
list(model = invoke(survival::survfit, formula = task$formula(1),
59+
data = task$data(cols = task$target_names)),
60+
features = task$feature_names) # keep for importance
3361
},
3462

3563
.predict = function(task) {
36-
times = self$model$time
37-
surv = matrix(rep(self$model$surv, task$nrow), ncol = length(times),
64+
times = self$model$model$time
65+
surv = matrix(rep(self$model$model$surv, task$nrow), ncol = length(times),
3866
nrow = task$nrow, byrow = TRUE)
3967

4068
.surv_return(times = times, surv = surv)

man/mlr_learners_surv.kaplan.Rd

Lines changed: 33 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/setup.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
lg = lgr::get_logger("mlr3")
22
old_threshold = lg$threshold
33
lg$set_threshold("warn")
4+
5+
mirai::daemons(0, .compute = "mlr3_encapsulation")
6+
mirai::daemons(1, .compute = "mlr3_encapsulation")

tests/testthat/teardown.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
lg$set_threshold(old_threshold)
2+
3+
mirai::daemons(0, .compute = "mlr3_encapsulation")

0 commit comments

Comments
 (0)