@@ -19,7 +19,7 @@ LearnerSurvKaplan = R6Class("LearnerSurvKaplan",
19
19
id = " surv.kaplan" ,
20
20
predict_types = c(" crank" , " distr" ),
21
21
feature_types = c(" logical" , " integer" , " numeric" , " character" , " factor" , " ordered" ),
22
- properties = c(" missings" , " importance" , " selected_features" ),
22
+ properties = c(" missings" , " weights " , " importance" , " selected_features" ),
23
23
packages = c(" survival" , " distr6" ),
24
24
label = " Kaplan-Meier Estimator" ,
25
25
man = " mlr3proba::mlr_learners_surv.kaplan"
@@ -59,9 +59,15 @@ LearnerSurvKaplan = R6Class("LearnerSurvKaplan",
59
59
60
60
private = list (
61
61
.train = function (task ) {
62
- list (model = invoke(survival :: survfit , formula = task $ formula(1 ),
63
- data = task $ data(cols = task $ target_names )),
64
- features = task $ feature_names ) # keep for importance
62
+ fit = invoke(
63
+ survival :: survfit ,
64
+ formula = task $ formula(1 ),
65
+ data = task $ data(),
66
+ .args = list (weights = private $ .get_weights(task ))
67
+ )
68
+
69
+ # keep features for importance
70
+ list (model = fit , features = task $ feature_names )
65
71
},
66
72
67
73
.predict = function (task ) {
0 commit comments