Skip to content

Commit 948c1aa

Browse files
committed
Merge branch 'main' into add_rcll
2 parents f654b2b + be96ee3 commit 948c1aa

File tree

5 files changed

+13
-2
lines changed

5 files changed

+13
-2
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# mlr3proba 0.4.7
22

33
* Add right-censored log loss
4+
* Fix bug in {rpart} where model was being discarded when set to be kept. Parameter `model` now called `keep_model`.
45

56
# mlr3proba 0.4.6
67

R/LearnerSurvRpart.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#'
77
#' @description
88
#' Parameter `xval` is set to 0 in order to save some computation time.
9+
#' Parameter `model` has been renamed to `keep_model`.
910
#'
1011
#' @references
1112
#' `r format_bib("breiman_1984")`
@@ -28,7 +29,8 @@ LearnerSurvRpart = R6Class("LearnerSurvRpart",
2829
usesurrogate = p_int(default = 2L, lower = 0L, upper = 2L, tags = "train"),
2930
surrogatestyle = p_int(default = 0L, lower = 0L, upper = 1L, tags = "train"),
3031
xval = p_int(default = 10L, lower = 0L, tags = "train"),
31-
cost = p_uty(tags = "train")
32+
cost = p_uty(tags = "train"),
33+
keep_model = p_lgl(default = FALSE, tags = "train")
3234
)
3335
ps$values = list(xval = 0L)
3436

@@ -69,6 +71,7 @@ LearnerSurvRpart = R6Class("LearnerSurvRpart",
6971
private = list(
7072
.train = function(task) {
7173
pv = self$param_set$get_values(tags = "train")
74+
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
7275
if ("weights" %in% task$properties) {
7376
pv = insert_named(pv, list(weights = task$weights$weight))
7477
}

man/mlr_learners_surv.coxph.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_learners_surv.rpart.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_mlr_learners_surv_rpart.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,9 @@ test_that("importance/selected", {
1616
expect_silent(learner$selected_features())
1717
expect_silent(learner$importance())
1818
})
19+
20+
test_that("keep_model", {
21+
learner = lrn("surv.rpart", keep_model = TRUE)
22+
learner$train(tsk("rats"))
23+
expect_false(is.null(learner$model$model))
24+
})

0 commit comments

Comments
 (0)