Skip to content

Commit 0c1f65b

Browse files
committed
update tests
1 parent d0551a6 commit 0c1f65b

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

tests/testthat/test_pipelines.R

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -115,44 +115,55 @@ test_that("survtoclassif", {
115115

116116
pipe = mlr3pipelines::ppl("survtoclassif", mlr3learners::LearnerClassifLogReg$new())
117117
expect_class(pipe, "Graph")
118+
pipe = mlr3pipelines::ppl("survtoclassif", mlr3learners::LearnerClassifLogReg$new(), graph_learner = TRUE)
119+
expect_class(pipe, "GraphLearner")
118120
pipe$train(task)
119121
p = pipe$predict(task)
120-
expect_prediction_surv(p[[1]])
122+
expect_prediction_surv(p)
121123

122124
cox = lrn("surv.coxph")
123125
cox$train(task) |> suppressWarnings()
124126
p2 = cox$predict(task)
125127

126-
expect_equal(p[[1]]$truth, p2$truth)
127-
expect_equal(p[[1]]$score(), p2$score(), tolerance = 0.1)
128+
expect_equal(p$truth, p2$truth)
129+
expect_equal(p$score(), p2$score(), tolerance = 0.1)
128130

129131
# Test with cut
130-
pipe = mlr3pipelines::ppl("survtoclassif", mlr3learners::LearnerClassifLogReg$new(), cut = c(10, 30, 50))
131-
expect_class(pipe, "Graph")
132+
pipe = mlr3pipelines::ppl("survtoclassif", mlr3learners::LearnerClassifLogReg$new(), cut = c(10, 30, 50), graph_learner = TRUE)
133+
expect_class(pipe, "GraphLearner")
132134
pipe$train(task) |> suppressWarnings()
133135
p = pipe$predict(task)
134-
expect_prediction_surv(p[[1]])
136+
expect_prediction_surv(p)
135137

136138
# Test with max_time
137139
t = task$data()[status == 1, min(time)]
138-
pipe = mlr3pipelines::ppl("survtoclassif", mlr3learners::LearnerClassifLogReg$new(), max_time = t)
139-
expect_class(pipe, "Graph")
140+
pipe = mlr3pipelines::ppl("survtoclassif", mlr3learners::LearnerClassifLogReg$new(), max_time = t, graph_learner = TRUE)
141+
expect_class(pipe, "GraphLearner")
140142
expect_error(pipe$train(task))
141143

142-
pipe = mlr3pipelines::ppl("survtoclassif", mlr3learners::LearnerClassifLogReg$new(), max_time = t + 1)
144+
pipe = mlr3pipelines::ppl("survtoclassif", mlr3learners::LearnerClassifLogReg$new(), max_time = t + 1, graph_learner = TRUE)
143145
pipe$train(task) |> suppressWarnings()
144146
p = pipe$predict(task)
145-
expect_prediction_surv(p[[1]])
147+
expect_prediction_surv(p)
146148

147149
# Test with rhs
148-
pipe = ppl("survtoclassif", learner = lrn("classif.log_reg"), rhs = "1")
150+
pipe = ppl("survtoclassif", learner = lrn("classif.log_reg"), rhs = "1", graph_learner = TRUE)
149151
pipe$train(task)
150152
pred = pipe$predict(task)
151153

152-
pipe = ppl("survtoclassif", learner = lrn("classif.featureless"))
154+
pipe = ppl("survtoclassif", learner = lrn("classif.featureless"), graph_learner = TRUE)
153155
pipe$train(task)
154156
pred2 = pipe$predict(task)
155157

156-
expect_equal(pred$trafopred_classifsurv.output$data$distr,
157-
pred2$trafopred_classifsurv.output$data$distr)
158+
expect_equal(pred$data$distr, pred2$data$distr)
159+
160+
pipe = ppl("survtoclassif", learner = lrn("classif.log_reg"), rhs = "rx + litter", graph_learner = TRUE)
161+
pipe$train(task)
162+
pred = pipe$predict(task)
163+
164+
pipe = ppl("survtoclassif", learner = lrn("classif.log_reg"), rhs = ".", graph_learner = TRUE)
165+
pipe$train(task)
166+
pred2 = pipe$predict(task) |> suppressWarnings()
167+
168+
expect_true(pred$score() < pred2$score())
158169
})

0 commit comments

Comments
 (0)