Skip to content

Commit be96ee3

Browse files
authored
Merge pull request #265 from mlr-org/keep_model
Support hyperparameter 'model' in surv.rpart
2 parents e4145cd + 0bd1d3f commit be96ee3

File tree

6 files changed

+17
-3
lines changed

6 files changed

+17
-3
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: mlr3proba
22
Title: Probabilistic Supervised Learning for 'mlr3'
3-
Version: 0.4.6
3+
Version: 0.4.7
44
Authors@R:
55
c(person(given = "Raphael",
66
family = "Sonabend",

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# mlr3proba 0.4.7
2+
3+
* Fix bug in {rpart} where model was being discarded when set to be kept. Parameter `model` now called `keep_model`.
4+
15
# mlr3proba 0.4.6
26

37
* Patch for upstream breakages

R/LearnerSurvRpart.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#'
77
#' @description
88
#' Parameter `xval` is set to 0 in order to save some computation time.
9+
#' Parameter `model` has been renamed to `keep_model`.
910
#'
1011
#' @references
1112
#' `r format_bib("breiman_1984")`
@@ -28,7 +29,8 @@ LearnerSurvRpart = R6Class("LearnerSurvRpart",
2829
usesurrogate = p_int(default = 2L, lower = 0L, upper = 2L, tags = "train"),
2930
surrogatestyle = p_int(default = 0L, lower = 0L, upper = 1L, tags = "train"),
3031
xval = p_int(default = 10L, lower = 0L, tags = "train"),
31-
cost = p_uty(tags = "train")
32+
cost = p_uty(tags = "train"),
33+
keep_model = p_lgl(default = FALSE, tags = "train")
3234
)
3335
ps$values = list(xval = 0L)
3436

@@ -69,6 +71,7 @@ LearnerSurvRpart = R6Class("LearnerSurvRpart",
6971
private = list(
7072
.train = function(task) {
7173
pv = self$param_set$get_values(tags = "train")
74+
names(pv) = replace(names(pv), names(pv) == "keep_model", "model")
7275
if ("weights" %in% task$properties) {
7376
pv = insert_named(pv, list(weights = task$weights$weight))
7477
}

man/mlr_learners_surv.coxph.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_learners_surv.rpart.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_mlr_learners_surv_rpart.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,9 @@ test_that("importance/selected", {
1616
expect_silent(learner$selected_features())
1717
expect_silent(learner$importance())
1818
})
19+
20+
test_that("keep_model", {
21+
learner = lrn("surv.rpart", keep_model = TRUE)
22+
learner$train(tsk("rats"))
23+
expect_false(is.null(learner$model$model))
24+
})

0 commit comments

Comments
 (0)