|
| 1 | +#' @title Generate Predicted Antibody Response Curves (Median + 95% CI) |
| 2 | +#' @description |
| 3 | +#' Plots a median antibody response curve with a 95% credible interval |
| 4 | +#' ribbon, using MCMC samples from the posterior distribution. |
| 5 | +#' Optionally overlays observed data, |
| 6 | +#' applies logarithmic spacing on the y- and x-axes, |
| 7 | +#' and shows all individual |
| 8 | +#' sampled curves. |
| 9 | +#' |
| 10 | +#' @param sr_model An `sr_model` object (returned by [run_mod()]) containing |
| 11 | +#' samples from the posterior distribution of the model parameters. |
| 12 | +#' @param id The participant ID to plot; for example, "sees_npl_128". |
| 13 | +#' @param antigen_iso The antigen isotype to plot; for example, "HlyE_IgA" or |
| 14 | +#' "HlyE_IgG". |
| 15 | +#' @param dataset (Optional) A [dplyr::tbl_df] with observed antibody response |
| 16 | +#' data. |
| 17 | +#' Must contain: |
| 18 | +#' - `timeindays` |
| 19 | +#' - `value` |
| 20 | +#' - `id` |
| 21 | +#' - `antigen_iso` |
| 22 | +#' @param legend_obs Label for observed data in the legend. |
| 23 | +#' @param legend_median Label for the median prediction line. |
| 24 | +#' @param show_quantiles [logical]; if [TRUE] (default), plots the 2.5%, 50%, |
| 25 | +#' and 97.5% quantiles. |
| 26 | +#' @param log_y [logical]; if [TRUE], applies a [log10] transformation to |
| 27 | +#' the y-axis. |
| 28 | +#' @param log_x [logical]; if [TRUE], applies a [log10] transformation to the |
| 29 | +#' x-axis. |
| 30 | +#' @param show_all_curves [logical]; if [TRUE], overlays all |
| 31 | +#' individual sampled curves. |
| 32 | +#' @param alpha_samples Numeric; transparency level for individual |
| 33 | +#' curves (default = 0.3). |
| 34 | +#' @param xlim (Optional) A numeric vector of length 2 providing custom x-axis |
| 35 | +#' limits. |
| 36 | +#' @param ylab (Optional) A string for the y-axis label. If `NULL` (default), |
| 37 | +#' the label is automatically set to "ELISA units" or "ELISA units (log scale)" |
| 38 | +#' based on the `log_y` argument. |
| 39 | +#' |
| 40 | +#' @return A [ggplot2::ggplot] object displaying predicted antibody response |
| 41 | +#' curves with a median curve and a 95% credible interval band as default. |
| 42 | +#' @export |
| 43 | +#' |
| 44 | +#' @example inst/examples/examples-plot_predicted_curve.R |
| 45 | +plot_predicted_curve <- function(sr_model, |
| 46 | + id, |
| 47 | + antigen_iso, |
| 48 | + dataset = NULL, |
| 49 | + legend_obs = "Observed data", |
| 50 | + legend_median = "Median prediction", |
| 51 | + show_quantiles = TRUE, |
| 52 | + log_y = FALSE, |
| 53 | + log_x = FALSE, |
| 54 | + show_all_curves = FALSE, |
| 55 | + alpha_samples = 0.3, |
| 56 | + xlim = NULL, |
| 57 | + ylab = NULL) { |
| 58 | + |
| 59 | + # -------------------------------------------------------------------------- |
| 60 | + # 1) The 'sr_model' object is now the tibble itself |
| 61 | + df <- sr_model |
| 62 | + |
| 63 | + |
| 64 | + # -------------------------------------------------------------------------- |
| 65 | + # 2) Filter to the subject & antigen of interest: |
| 66 | + df_sub <- df |> |
| 67 | + dplyr::filter( |
| 68 | + .data$Subject == id, # e.g. "sees_npl_128" |
| 69 | + .data$Iso_type == antigen_iso # e.g. "HlyE_IgA" |
| 70 | + ) |
| 71 | + |
| 72 | + # -------------------------------------------------------------------------- |
| 73 | + # 3) Pivot to wide format: one row per iteration/chain |
| 74 | + param_medians_wide <- df_sub |> |
| 75 | + dplyr::select( |
| 76 | + all_of(c("Chain", |
| 77 | + "Iteration", |
| 78 | + "Iso_type", |
| 79 | + "Parameter", |
| 80 | + "value")) |
| 81 | + ) |> |
| 82 | + tidyr::pivot_wider( |
| 83 | + names_from = c("Parameter"), |
| 84 | + values_from = c("value") |
| 85 | + ) |> |
| 86 | + dplyr::arrange(.data$Chain, .data$Iteration) |> |
| 87 | + |
| 88 | + dplyr::mutate( |
| 89 | + antigen_iso = factor(.data$Iso_type), |
| 90 | + r = .data$shape |
| 91 | + ) |> |
| 92 | + dplyr::select(-c("Iso_type")) |
| 93 | + |
| 94 | + # Add sample_id if not present (to identify individual samples) |
| 95 | + if (!"sample_id" %in% names(param_medians_wide)) { |
| 96 | + param_medians_wide <- param_medians_wide |> |
| 97 | + dplyr::mutate(sample_id = dplyr::row_number()) |
| 98 | + } |
| 99 | + # Define time points for prediction |
| 100 | + tx2 <- seq(0, 1200, by = 5) |
| 101 | + |
| 102 | + |
| 103 | + ## --- Prepare data for Model 1 --- |
| 104 | + dt1 <- data.frame(t = tx2) |> |
| 105 | + dplyr::mutate(id = dplyr::row_number()) |> |
| 106 | + tidyr::pivot_wider(names_from = "id", |
| 107 | + values_from = "t", |
| 108 | + names_prefix = "time") |> |
| 109 | + dplyr::slice( |
| 110 | + rep(seq_len(dplyr::n()), each = nrow(param_medians_wide)) |
| 111 | + ) |
| 112 | + |
| 113 | + |
| 114 | + serocourse_all1 <- cbind(param_medians_wide, dt1) |> |
| 115 | + tidyr::pivot_longer(cols = dplyr::starts_with("time"), values_to = "t") |> |
| 116 | + dplyr::select(-c("name")) |> |
| 117 | + dplyr::rowwise() |> |
| 118 | + dplyr::mutate(res = ab(.data$t, |
| 119 | + .data$y0, |
| 120 | + .data$y1, |
| 121 | + .data$t1, |
| 122 | + .data$alpha, |
| 123 | + .data$shape)) |> |
| 124 | + dplyr::ungroup() |
| 125 | + |
| 126 | + # Determine Y-axis label |
| 127 | + if (is.null(ylab)) { |
| 128 | + if (log_y) { |
| 129 | + ylab <- "ELISA units (log scale)" |
| 130 | + } else { |
| 131 | + ylab <- "ELISA units" |
| 132 | + } |
| 133 | + } |
| 134 | + |
| 135 | + # Base ggplot object with legend at the bottom. |
| 136 | + p <- ggplot2::ggplot() + |
| 137 | + ggplot2::theme_minimal() + |
| 138 | + ggplot2::labs(x = "Days since fever onset", y = ylab) + |
| 139 | + ggplot2::theme(legend.position = "bottom") |
| 140 | + |
| 141 | + # If show_all_curves is TRUE, overlay all individual sampled curves. |
| 142 | + if (show_all_curves) { |
| 143 | + p <- p + |
| 144 | + ggplot2::geom_line(data = serocourse_all1, |
| 145 | + ggplot2::aes(x = .data$t, |
| 146 | + y = .data$res, |
| 147 | + group = .data$sample_id, |
| 148 | + color = "samples"), |
| 149 | + alpha = alpha_samples) |
| 150 | + } |
| 151 | + |
| 152 | + # --- Summarize & Plot Model 1 (Median + 95% Ribbon) --- |
| 153 | + if (show_quantiles) { |
| 154 | + sum1 <- serocourse_all1 |> |
| 155 | + dplyr::group_by(t) |> |
| 156 | + dplyr::summarise( |
| 157 | + res.med = stats::quantile(.data$res, probs = 0.50, na.rm = TRUE), |
| 158 | + res.low = stats::quantile(.data$res, probs = 0.025, na.rm = TRUE), |
| 159 | + res.high = stats::quantile(.data$res, probs = 0.975, na.rm = TRUE), |
| 160 | + .groups = "drop" |
| 161 | + ) |
| 162 | + |
| 163 | + p <- p + |
| 164 | + ggplot2::geom_ribbon(data = sum1, |
| 165 | + ggplot2::aes(x = .data$t, |
| 166 | + ymin = .data$res.low, |
| 167 | + ymax = .data$res.high, |
| 168 | + fill = "ci"), |
| 169 | + alpha = 0.2, inherit.aes = FALSE) + |
| 170 | + ggplot2::geom_line(data = sum1, |
| 171 | + ggplot2::aes(x = .data$t, |
| 172 | + y = .data$res.med, |
| 173 | + color = "median"), |
| 174 | + linewidth = 1, inherit.aes = FALSE) |
| 175 | + } |
| 176 | + |
| 177 | + # --- Overlay Observed Data (if provided) --- |
| 178 | + if (!is.null(dataset)) { |
| 179 | + observed_data <- dataset |> |
| 180 | + dplyr::rename(t = c("timeindays"), |
| 181 | + res = c("value")) |> |
| 182 | + dplyr::select(all_of(c("id", |
| 183 | + "t", |
| 184 | + "res", |
| 185 | + "antigen_iso"))) |> |
| 186 | + dplyr::mutate(id = as.factor(.data$id)) |
| 187 | + |
| 188 | + p <- p + |
| 189 | + ggplot2::geom_point(data = observed_data, |
| 190 | + ggplot2::aes(x = .data$t, |
| 191 | + y = .data$res, |
| 192 | + group = .data$id, |
| 193 | + color = "observed"), |
| 194 | + size = 2, show.legend = TRUE) + |
| 195 | + ggplot2::geom_line(data = observed_data, |
| 196 | + ggplot2::aes(x = .data$t, |
| 197 | + y = .data$res, |
| 198 | + group = .data$id, |
| 199 | + color = "observed"), |
| 200 | + linewidth = 1, show.legend = TRUE) |
| 201 | + } |
| 202 | + |
| 203 | + # --- Construct Unified Legend --- |
| 204 | + color_vals <- c("median" = "red") |
| 205 | + color_labels <- c("median" = legend_median) |
| 206 | + fill_vals <- c("ci" = "red") |
| 207 | + fill_labels <- c("ci" = "95% credible interval") |
| 208 | + |
| 209 | + if (show_all_curves) { |
| 210 | + color_vals["samples"] <- "gray" |
| 211 | + color_labels["samples"] <- "Posterior samples" |
| 212 | + } |
| 213 | + |
| 214 | + if (!is.null(dataset)) { |
| 215 | + color_vals["observed"] <- "blue" |
| 216 | + color_labels["observed"] <- legend_obs |
| 217 | + } |
| 218 | + |
| 219 | + p <- p + |
| 220 | + ggplot2::scale_color_manual( |
| 221 | + name = "Component", |
| 222 | + values = color_vals, |
| 223 | + labels = color_labels, |
| 224 | + guide = ggplot2::guide_legend(override.aes = list(shape = NA)) |
| 225 | + ) + |
| 226 | + ggplot2::scale_fill_manual( |
| 227 | + name = "Component", |
| 228 | + values = fill_vals, |
| 229 | + labels = fill_labels, |
| 230 | + guide = ggplot2::guide_legend(override.aes = list(color = NA)) |
| 231 | + ) |
| 232 | + |
| 233 | + # --- Optionally add log10 scales for y and/or x --- |
| 234 | + if (log_y) { |
| 235 | + p <- p + ggplot2::scale_y_log10() |
| 236 | + } |
| 237 | + if (log_x) { |
| 238 | + p <- p + |
| 239 | + ggplot2::scale_x_continuous( |
| 240 | + trans = scales::pseudo_log_trans(sigma = 1, base = 10) |
| 241 | + ) |
| 242 | + } |
| 243 | + |
| 244 | + # --- Set custom x-axis limits if provided --- |
| 245 | + if (!is.null(xlim)) { |
| 246 | + p <- p + ggplot2::coord_cartesian(xlim = xlim) |
| 247 | + } |
| 248 | + |
| 249 | + return(p) |
| 250 | +} |
0 commit comments