@@ -19,22 +19,50 @@ LearnerSurvKaplan = R6Class("LearnerSurvKaplan",
19
19
id = " surv.kaplan" ,
20
20
predict_types = c(" crank" , " distr" ),
21
21
feature_types = c(" logical" , " integer" , " numeric" , " character" , " factor" , " ordered" ),
22
- properties = " missings" ,
22
+ properties = c( " missings" , " importance " , " selected_features " ) ,
23
23
packages = c(" survival" , " distr6" ),
24
24
label = " Kaplan-Meier Estimator" ,
25
25
man = " mlr3proba::mlr_learners_surv.kaplan"
26
26
)
27
+ },
28
+
29
+ # ' @description
30
+ # ' All features have a score of `0` for this learner.
31
+ # ' #' This method exists solely for compatibility with the `mlr3` ecosystem,
32
+ # ' as this learner is used as a fallback for other survival learners that
33
+ # ' require an `importance()` method.
34
+ # '
35
+ # ' @return Named `numeric()`.
36
+ importance = function () {
37
+ if (is.null(self $ model )) {
38
+ stopf(" No model stored" )
39
+ }
40
+
41
+ fn = self $ model $ features
42
+ named_vector(fn , 0 )
43
+ },
44
+
45
+ # ' @description
46
+ # ' Selected features are always the empty set for this learner.
47
+ # ' This method is implemented only for compatibility with the `mlr3` API,
48
+ # ' as this learner does not perform feature selection.
49
+ # '
50
+ # ' @return `character(0)`.
51
+ selected_features = function () {
52
+ character ()
27
53
}
28
54
),
29
55
30
56
private = list (
31
57
.train = function (task ) {
32
- invoke(survival :: survfit , formula = task $ formula(1 ), data = task $ data())
58
+ list (model = invoke(survival :: survfit , formula = task $ formula(1 ),
59
+ data = task $ data(cols = task $ target_names )),
60
+ features = task $ feature_names ) # keep for importance
33
61
},
34
62
35
63
.predict = function (task ) {
36
- times = self $ model $ time
37
- surv = matrix (rep(self $ model $ surv , task $ nrow ), ncol = length(times ),
64
+ times = self $ model $ model $ time
65
+ surv = matrix (rep(self $ model $ model $ surv , task $ nrow ), ncol = length(times ),
38
66
nrow = task $ nrow , byrow = TRUE )
39
67
40
68
.surv_return(times = times , surv = surv )
0 commit comments