Skip to content

Commit ed977c7

Browse files
committed
Add print method for sdmTMB_cv() #319
1 parent 796eb08 commit ed977c7

File tree

5 files changed

+32
-2
lines changed

5 files changed

+32
-2
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Type: Package
22
Package: sdmTMB
33
Title: Spatial and Spatiotemporal SPDE-Based GLMMs with 'TMB'
4-
Version: 0.6.0.9009
4+
Version: 0.6.0.9010
55
Authors@R: c(
66
person(c("Sean", "C."), "Anderson", , "sean@seananderson.ca",
77
role = c("aut", "cre"),

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ S3method(nobs,sdmTMB)
1414
S3method(plot,sdmTMBmesh)
1515
S3method(predict,sdmTMB)
1616
S3method(print,sdmTMB)
17+
S3method(print,sdmTMB_cv)
1718
S3method(ranef,sdmTMB)
1819
S3method(residuals,sdmTMB)
1920
S3method(simulate,sdmTMB)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# sdmTMB (development version)
22

3+
* Add print method for `sdmTMB_cv()` output. #319
4+
35
* Add progress bar to `simulate.sdmTMB()`. #346
46

57
* Add AUC and TSS examples to cross validation vignette. #268

R/cross-val.R

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ sdmTMB_cv <- function(
453453
pdHess <- vapply(out, `[[`, "pdHess", FUN.VALUE = logical(1L))
454454
max_grad <- vapply(out, `[[`, "max_gradient", FUN.VALUE = numeric(1L))
455455
converged <- all(pdHess)
456-
list(
456+
out <- list(
457457
data = data,
458458
models = models,
459459
fold_loglik = fold_cv_ll,
@@ -462,9 +462,35 @@ sdmTMB_cv <- function(
462462
pdHess = pdHess,
463463
max_gradients = max_grad
464464
)
465+
`class<-`(out, "sdmTMB_cv")
465466
}
466467

467468
log_sum_exp <- function(x) {
468469
max_x <- max(x)
469470
max_x + log(sum(exp(x - max_x)))
470471
}
472+
473+
#' @export
474+
#' @import methods
475+
print.sdmTMB_cv <- function(x, ...) {
476+
nmods <- length(x$models)
477+
nconverged <- sum(x$converged)
478+
cat(paste0("Cross validation of sdmTMB models with ", nmods, " folds.\n"))
479+
cat("\n")
480+
cat("Summary of the first fold model fit:\n")
481+
cat("\n")
482+
print(x$models[[1]])
483+
cat("\n")
484+
cat("Access the rest of the models in a list element named `models`.\n")
485+
cat("E.g. `object$models[[2]]` for the 2nd fold model fit.\n")
486+
cat("\n")
487+
cat(paste0(nconverged, " out of ", nmods, " models are consistent with convergence.\n"))
488+
cat("Figure out which folds these are in the `converged` list element.\n")
489+
cat("\n")
490+
cat(paste0("Out-of-sample log likelihood for each fold: ", paste(round(x$fold_loglik, 2), collapse = ", "), ".\n"))
491+
cat("Access these values in the `fold_loglik` list element.\n")
492+
cat("\n")
493+
cat("Sum of out-of-sample log likelihoods:", round(x$sum_loglik, 2), "\n")
494+
cat("More positive values imply better out-of-sample prediction.\n")
495+
cat("Access this value in the `sum_loglik` list element.\n")
496+
}

tests/testthat/test-cross-validation.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ test_that("Basic cross validation works", {
1212
data = d, mesh = spde,
1313
family = tweedie(link = "log"), time = "year", k_folds = 2
1414
)
15+
print(x)
1516
expect_equal(class(x$sum_loglik), "numeric")
1617
expect_equal(x$sum_loglik, sum(x$data$cv_loglik))
1718
expect_equal(x$sum_loglik, sum(x$fold_loglik))

0 commit comments

Comments
 (0)