|
#' @title Wrap a learner into a PipeOp with survival predictions estimated by the Breslow estimator
#' @name mlr_pipeops_compose_breslow_distr
#' @template param_pipelines
#' @description
#' Composes a survival distribution (`distr`) using the linear predictor
#' predictions (`lp`) from a given [LearnerSurv] during training and prediction,
#' utilizing the [breslow estimator][breslow]. The specified `learner` must be
#' capable of generating `lp`-type predictions (e.g., a Cox-type model).
#'
#' @section Dictionary:
#' This [PipeOp][mlr3pipelines::PipeOp] can be instantiated via the
#' [Dictionary][mlr3misc::Dictionary] [mlr_pipeops][mlr3pipelines::mlr_pipeops]
#' or with the associated sugar function [po()][mlr3pipelines::po]:
#' ```
#' PipeOpBreslow$new(learner)
#' mlr_pipeops$get("breslowcompose", learner)
#' po("breslowcompose", learner, breslow.overwrite = TRUE)
#' ```
#'
#' @section Input and Output Channels:
#' [PipeOpBreslow] is like a [LearnerSurv].
#' It has one input channel, named `input`, that takes a [TaskSurv] during
#' training and another [TaskSurv] during prediction.
#' [PipeOpBreslow] has one output channel, named `output`, producing `NULL`
#' during training and a [PredictionSurv] during prediction.
#'
#' @section State:
#' The `$state` slot stores the `times` and `status` survival target variables of
#' the train [TaskSurv] as well as the `lp` predictions on the train set.
#'
#' @section Parameters:
#' The parameters are:
#' * `breslow.overwrite` :: `logical(1)` \cr
#'    If `FALSE` (default) and the input `learner` lists `distr` among its
#'    predict types, the compositor does nothing and returns the input
#'    `learner`'s [PredictionSurv].
#'    If `TRUE`, or if the input `learner` doesn't have `distr` predictions,
#'    then the `distr` is overwritten with the `distr` composed from `lp` and
#'    the train set information using [breslow].
#'    This is useful for changing the prediction `distr` from one model form to
#'    another.
#' @seealso [pipeline_distrcompositor]
#' @export
#' @family survival compositors
#' @examples
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
#'   library(mlr3)
#'   library(mlr3pipelines)
#'   task = tsk("rats")
#'   part = partition(task, ratio = 0.8)
#'   train_task = task$clone()$filter(part$train)
#'   test_task = task$clone()$filter(part$test)
#'
#'   learner = lrn("surv.coxph") # learner with lp predictions
#'   b = po("breslowcompose", learner = learner, breslow.overwrite = TRUE)
#'
#'   b$train(list(train_task))
#'   p = b$predict(list(test_task))[[1L]]
#' }
#' }
PipeOpBreslow = R6Class("PipeOpBreslow",
  inherit = mlr3pipelines::PipeOp,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #' @param learner ([LearnerSurv])\cr
    #' Survival learner which must provide `lp`-type predictions
    #' @param id (`character(1)`)\cr
    #' Identifier of the resulting object. If `NULL` (default), it will be set
    #' as the `id` of the input `learner`.
    initialize = function(learner, id = NULL, param_vals = list()) {
      # fail early: the Breslow estimator is built from linear predictors
      assert_learner(learner, task_type = "surv")
      if ("lp" %nin% learner$predict_types) {
        stopf("Learner %s must provide lp predictions", learner$id)
      }

      # work on a private clone so the caller's learner is never modified;
      # the PipeOp id defaults to the id of the learner
      private$.learner = as_learner(learner, clone = TRUE)
      # NOTE(review): presumably clears the prefix so the learner's parameters
      # keep their plain names in the combined parameter set -- confirm
      private$.learner$param_set$set_id = ""
      id = id %??% private$.learner$id

      # define `breslow.overwrite` parameter: declared as `overwrite` in a
      # set with id "breslow", read back as `breslow.overwrite` in `.predict`
      private$.breslow_ps = ps(
        overwrite = p_lgl(default = FALSE, tags = c("predict", "required"))
      )
      private$.breslow_ps$values = list(overwrite = FALSE)
      private$.breslow_ps$set_id = "breslow"

      super$initialize(
        id = id,
        # alist() keeps both expressions unevaluated
        param_set = alist(private$.breslow_ps, private$.learner$param_set),
        param_vals = param_vals,
        input = data.table(name = "input", train = "TaskSurv", predict = "TaskSurv"),
        output = data.table(name = "output", train = "NULL", predict = "PredictionSurv"),
        packages = private$.learner$packages
      )
    }
  ),

  active = list(
    #' @field learner ([LearnerSurv])\cr
    #' The input survival learner. Read-only.
    learner = function(rhs) {
      assert_ro_binding(rhs)
      private$.learner
    }
  ),

  private = list(
    # Trains the wrapped learner on the input task and stores in `$state` the
    # training information that the Breslow estimator needs at predict time.
    # Returns `list(NULL)`, matching the `NULL` train output channel.
    .train = function(inputs) {
      task = inputs[[1]]
      learner = private$.learner

      # train learner
      learner$train(task)

      # predictions on the train set
      p = learner$predict(task)

      # Breslow works only with non-NA `lp` predictions
      if (anyMissing(p$lp)) {
        stopf("Missing lp predictions")
      }

      # keep the training data that Breslow estimator needs
      self$state = list(
        times = task$times(),
        status = task$status(),
        lp_train = p$lp
      )

      list(NULL)
    },

    # Predicts with the wrapped learner on the input task. Returns a list
    # holding either the learner's own prediction, or a new [PredictionSurv]
    # whose `distr` is composed by `breslow()` from the stored training
    # information and the test-set `lp`.
    .predict = function(inputs) {
      task = inputs[[1]]
      learner = private$.learner

      if (is.null(learner$model)) {
        stopf("Cannot predict, Learner '%s' has not been trained yet", learner$id)
      }

      # predictions on the test set
      p = learner$predict(task)

      pv = self$param_set$get_values(tags = "predict")
      overwrite = pv$breslow.overwrite

      # If learner predicts `distr` and overwrite is FALSE don't use breslow.
      # Both operands are length-1, so the scalar, short-circuiting `&&` is
      # the right operator for this `if` condition.
      if ("distr" %in% learner$predict_types && !overwrite) {
        pred = list(p)
      } else {
        # Breslow works only with non-NA `lp` predictions
        if (anyMissing(p$lp)) {
          stopf("Missing lp predictions!")
        }

        distr = breslow(
          times = self$state$times,
          status = self$state$status,
          lp_train = self$state$lp_train,
          lp_test = p$lp
        )

        # carry over everything from the learner's prediction except `distr`
        pred = list(PredictionSurv$new(
          row_ids = p$row_ids,
          truth = p$truth,
          crank = p$crank,
          lp = p$lp,
          distr = distr
        ))
      }

      pred
    },

    # parameter set holding the `overwrite` flag (set_id "breslow")
    .breslow_ps = NULL,
    # private clone of the learner passed to `initialize`
    .learner = NULL
  )
)
0 commit comments