Skip to content

Commit dd74674

Browse files
committed
add property weights to KM
1 parent a92d1a0 commit dd74674

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

R/LearnerSurvKaplan.R

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ LearnerSurvKaplan = R6Class("LearnerSurvKaplan",
1919
id = "surv.kaplan",
2020
predict_types = c("crank", "distr"),
2121
feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
22-
properties = c("missings", "importance", "selected_features"),
22+
properties = c("missings", "weights", "importance", "selected_features"),
2323
packages = c("survival", "distr6"),
2424
label = "Kaplan-Meier Estimator",
2525
man = "mlr3proba::mlr_learners_surv.kaplan"
@@ -59,9 +59,15 @@ LearnerSurvKaplan = R6Class("LearnerSurvKaplan",
5959

6060
private = list(
6161
.train = function(task) {
62-
list(model = invoke(survival::survfit, formula = task$formula(1),
63-
data = task$data(cols = task$target_names)),
64-
features = task$feature_names) # keep for importance
62+
fit = invoke(
63+
survival::survfit,
64+
formula = task$formula(1),
65+
data = task$data(),
66+
.args = list(weights = private$.get_weights(task))
67+
)
68+
69+
# keep features for importance
70+
list(model = fit, features = task$feature_names)
6571
},
6672

6773
.predict = function(task) {

tests/testthat/test_mlr_learners_surv_kaplan.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ test_that("importance/selected", {
1919
learner = lrn("surv.kaplan")
2020
expect_error(learner$importance(), "No model stored")
2121
expect_error(learner$selected_features(), "No model stored")
22-
learner$train(tsk("rats"))
22+
23+
task = tsk("rats")
24+
learner$train(task)
2325
expect_character(learner$selected_features(), len = 0)
24-
expect_named(learner$importance())
26+
expect_named(learner$importance(), expected = task$feature_names)
27+
expect_true(all(learner$importance() == 0))
2528
})

0 commit comments

Comments
 (0)