Skip to content

Commit 734c3be

Browse files
authored
Merge pull request #295 from stan-dev/remove-library-loo
Don't library loo in tests
2 parents 807e0ea + 787ef2c commit 734c3be

21 files changed

+3062
-954
lines changed

tests/testthat/test_0_helpers.R

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
library(loo)
2-
31
LLarr <- example_loglik_array()
42
LLmat <- example_loglik_matrix()
53

@@ -24,18 +22,26 @@ test_that("reshaping functions result in correct dimensions", {
2422
})
2523

2624
test_that("reshaping functions throw correct errors", {
27-
expect_error(llmatrix_to_array(LLmat, chain_id = rep(1:2, times = c(400, 600))),
28-
regexp = "Not all chains have same number of iterations",
29-
fixed = TRUE)
30-
expect_error(llmatrix_to_array(LLmat, chain_id = rep(1:2, each = 400)),
31-
regexp = "Number of rows in matrix not equal to length(chain_id)",
32-
fixed = TRUE)
33-
expect_error(llmatrix_to_array(LLmat, chain_id = rep(2:3, each = 500)),
34-
regexp = "max(chain_id) not equal to the number of chains",
35-
fixed = TRUE)
36-
expect_error(llmatrix_to_array(LLmat, chain_id = rnorm(1000)),
37-
regexp = "all(chain_id == as.integer(chain_id)) is not TRUE",
38-
fixed = TRUE)
25+
expect_error(
26+
llmatrix_to_array(LLmat, chain_id = rep(1:2, times = c(400, 600))),
27+
regexp = "Not all chains have same number of iterations",
28+
fixed = TRUE
29+
)
30+
expect_error(
31+
llmatrix_to_array(LLmat, chain_id = rep(1:2, each = 400)),
32+
regexp = "Number of rows in matrix not equal to length(chain_id)",
33+
fixed = TRUE
34+
)
35+
expect_error(
36+
llmatrix_to_array(LLmat, chain_id = rep(2:3, each = 500)),
37+
regexp = "max(chain_id) not equal to the number of chains",
38+
fixed = TRUE
39+
)
40+
expect_error(
41+
llmatrix_to_array(LLmat, chain_id = rnorm(1000)),
42+
regexp = "all(chain_id == as.integer(chain_id)) is not TRUE",
43+
fixed = TRUE
44+
)
3945
})
4046

4147
test_that("colLogMeanExps(x) = log(colMeans(exp(x))) ", {
@@ -54,9 +60,14 @@ test_that("validating log-lik objects and functions works", {
5460
})
5561

5662
test_that("nlist works", {
57-
a <- 1; b <- 2; c <- 3;
63+
a <- 1
64+
b <- 2
65+
c <- 3
5866
nlist_val <- list(nlist(a, b, c), nlist(a, b, c = "tornado"))
59-
nlist_ans <- list(list(a = 1, b = 2, c = 3), list(a = 1, b = 2, c = "tornado"))
67+
nlist_ans <- list(
68+
list(a = 1, b = 2, c = 3),
69+
list(a = 1, b = 2, c = "tornado")
70+
)
6071
expect_equal(nlist_val, nlist_ans)
6172
expect_equal(nlist(a = 1, b = 2, c = 3), list(a = 1, b = 2, c = 3))
6273
})
@@ -69,6 +80,5 @@ test_that("loo_cores works", {
6980

7081
options(loo.cores = 2)
7182
expect_warning(expect_equal(loo_cores(10), 2), "deprecated")
72-
options(loo.cores=NULL)
83+
options(loo.cores = NULL)
7384
})
74-

tests/testthat/test_E_loo.R

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
library(loo)
2-
31
LLarr <- example_loglik_array()
42
LLmat <- example_loglik_matrix()
53
LLvec <- LLmat[, 1]
@@ -17,15 +15,54 @@ log_rats <- -LLmat
1715
E_test_mean <- E_loo(x, psis_mat, type = "mean", log_ratios = log_rats)
1816
E_test_var <- E_loo(x, psis_mat, type = "var", log_ratios = log_rats)
1917
E_test_sd <- E_loo(x, psis_mat, type = "sd", log_ratios = log_rats)
20-
E_test_quant <- E_loo(x, psis_mat, type = "quantile", probs = 0.5, log_ratios = log_rats)
21-
E_test_quant2 <- E_loo(x, psis_mat, type = "quantile", probs = c(0.1, 0.9), log_ratios = log_rats)
18+
E_test_quant <- E_loo(
19+
x,
20+
psis_mat,
21+
type = "quantile",
22+
probs = 0.5,
23+
log_ratios = log_rats
24+
)
25+
E_test_quant2 <- E_loo(
26+
x,
27+
psis_mat,
28+
type = "quantile",
29+
probs = c(0.1, 0.9),
30+
log_ratios = log_rats
31+
)
2232

2333
# vector method
24-
E_test_mean_vec <- E_loo(x[, 1], psis_vec, type = "mean", log_ratios = log_rats[,1])
25-
E_test_var_vec <- E_loo(x[, 1], psis_vec, type = "var", log_ratios = log_rats[,1])
26-
E_test_sd_vec <- E_loo(x[, 1], psis_vec, type = "sd", log_ratios = log_rats[,1])
27-
E_test_quant_vec <- E_loo(x[, 1], psis_vec, type = "quant", probs = 0.5, log_ratios = log_rats[,1])
28-
E_test_quant_vec2 <- E_loo(x[, 1], psis_vec, type = "quant", probs = c(0.1, 0.5, 0.9), log_ratios = log_rats[,1])
34+
E_test_mean_vec <- E_loo(
35+
x[, 1],
36+
psis_vec,
37+
type = "mean",
38+
log_ratios = log_rats[, 1]
39+
)
40+
E_test_var_vec <- E_loo(
41+
x[, 1],
42+
psis_vec,
43+
type = "var",
44+
log_ratios = log_rats[, 1]
45+
)
46+
E_test_sd_vec <- E_loo(
47+
x[, 1],
48+
psis_vec,
49+
type = "sd",
50+
log_ratios = log_rats[, 1]
51+
)
52+
E_test_quant_vec <- E_loo(
53+
x[, 1],
54+
psis_vec,
55+
type = "quant",
56+
probs = 0.5,
57+
log_ratios = log_rats[, 1]
58+
)
59+
E_test_quant_vec2 <- E_loo(
60+
x[, 1],
61+
psis_vec,
62+
type = "quant",
63+
probs = c(0.1, 0.5, 0.9),
64+
log_ratios = log_rats[, 1]
65+
)
2966

3067
# E_loo_khat
3168
khat <- loo:::E_loo_khat.matrix(x, psis_mat, log_rats)
@@ -114,11 +151,11 @@ test_that("E_loo throws correct errors and warnings", {
114151
# warnings
115152
expect_no_warning(E_loo.matrix(x, psis_mat))
116153
# no warnings if x is constant, binary, NA, NaN, Inf
117-
expect_no_warning(E_loo.matrix(x*0, psis_mat))
118-
expect_no_warning(E_loo.matrix(0+(x>0), psis_mat))
119-
expect_no_warning(E_loo.matrix(x+NA, psis_mat))
120-
expect_no_warning(E_loo.matrix(x*NaN, psis_mat))
121-
expect_no_warning(E_loo.matrix(x*Inf, psis_mat))
154+
expect_no_warning(E_loo.matrix(x * 0, psis_mat))
155+
expect_no_warning(E_loo.matrix(0 + (x > 0), psis_mat))
156+
expect_no_warning(E_loo.matrix(x + NA, psis_mat))
157+
expect_no_warning(E_loo.matrix(x * NaN, psis_mat))
158+
expect_no_warning(E_loo.matrix(x * Inf, psis_mat))
122159
expect_no_warning(E_test <- E_loo.default(x[, 1], psis_vec))
123160
expect_length(E_test$pareto_k, 1)
124161

@@ -161,7 +198,6 @@ test_that("weighted quantiles work", {
161198
quantile(xx, probs, names = FALSE)
162199
}
163200

164-
165201
set.seed(123)
166202
pr <- seq(0.025, 0.975, 0.025)
167203

tests/testthat/test_compare.R

Lines changed: 75 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
library(loo)
21
set.seed(123)
32

43
LLarr <- example_loglik_array()
@@ -12,45 +11,63 @@ test_that("loo_compare throws appropriate errors", {
1211
w4 <- suppressWarnings(waic(LLarr[,, -(1:2)]))
1312

1413
expect_error(loo_compare(2, 3), "must be a list if not a 'loo' object")
15-
expect_error(loo_compare(w1, w2, x = list(w1, w2)),
16-
"If 'x' is a list then '...' should not be specified")
17-
expect_error(loo_compare(w1, list(1,2,3)), "class 'loo'")
14+
expect_error(
15+
loo_compare(w1, w2, x = list(w1, w2)),
16+
"If 'x' is a list then '...' should not be specified"
17+
)
18+
expect_error(loo_compare(w1, list(1, 2, 3)), "class 'loo'")
1819
expect_error(loo_compare(w1), "requires at least two models")
1920
expect_error(loo_compare(x = list(w1)), "requires at least two models")
2021
expect_error(loo_compare(w1, w3), "same number of data points")
2122
expect_error(loo_compare(w1, w2, w3), "same number of data points")
2223
})
2324

2425
test_that("loo_compare throws appropriate warnings", {
25-
w3 <- w1; w4 <- w2
26+
w3 <- w1
27+
w4 <- w2
2628
class(w3) <- class(w4) <- c("kfold", "loo")
2729
attr(w3, "K") <- 2
2830
attr(w4, "K") <- 3
29-
expect_warning(loo_compare(w3, w4), "Not all kfold objects have the same K value")
31+
expect_warning(
32+
loo_compare(w3, w4),
33+
"Not all kfold objects have the same K value"
34+
)
3035

3136
class(w4) <- c("psis_loo", "loo")
3237
attr(w4, "K") <- NULL
3338
expect_warning(loo_compare(w3, w4), "Comparing LOO-CV to K-fold-CV")
3439

35-
w3 <- w1; w4 <- w2
40+
w3 <- w1
41+
w4 <- w2
3642
attr(w3, "yhash") <- "a"
3743
attr(w4, "yhash") <- "b"
3844
expect_warning(loo_compare(w3, w4), "Not all models have the same y variable")
3945

4046
set.seed(123)
41-
w_list <- lapply(1:25, function(x) suppressWarnings(waic(LLarr + rnorm(1, 0, 0.1))))
42-
expect_warning(loo_compare(w_list),
43-
"Difference in performance potentially due to chance")
44-
45-
w_list_short <- lapply(1:4, function(x) suppressWarnings(waic(LLarr + rnorm(1, 0, 0.1))))
47+
w_list <- lapply(1:25, function(x) {
48+
suppressWarnings(waic(LLarr + rnorm(1, 0, 0.1)))
49+
})
50+
expect_warning(
51+
loo_compare(w_list),
52+
"Difference in performance potentially due to chance"
53+
)
54+
55+
w_list_short <- lapply(1:4, function(x) {
56+
suppressWarnings(waic(LLarr + rnorm(1, 0, 0.1)))
57+
})
4658
expect_no_warning(loo_compare(w_list_short))
4759
})
4860

4961

50-
5162
comp_colnames <- c(
52-
"elpd_diff", "se_diff", "elpd_waic", "se_elpd_waic",
53-
"p_waic", "se_p_waic", "waic", "se_waic"
63+
"elpd_diff",
64+
"se_diff",
65+
"elpd_waic",
66+
"se_elpd_waic",
67+
"p_waic",
68+
"se_p_waic",
69+
"waic",
70+
"se_waic"
5471
)
5572

5673
test_that("loo_compare returns expected results (2 models)", {
@@ -59,15 +76,15 @@ test_that("loo_compare returns expected results (2 models)", {
5976
expect_equal(colnames(comp1), comp_colnames)
6077
expect_equal(rownames(comp1), c("model1", "model2"))
6178
expect_output(print(comp1), "elpd_diff")
62-
expect_equal(comp1[1:2,1], c(0, 0), ignore_attr = TRUE)
63-
expect_equal(comp1[1:2,2], c(0, 0), ignore_attr = TRUE)
79+
expect_equal(comp1[1:2, 1], c(0, 0), ignore_attr = TRUE)
80+
expect_equal(comp1[1:2, 2], c(0, 0), ignore_attr = TRUE)
6481

6582
comp2 <- loo_compare(w1, w2)
6683
expect_s3_class(comp2, "compare.loo")
6784
expect_equal(colnames(comp2), comp_colnames)
68-
85+
6986
expect_snapshot_value(comp2, style = "serialize")
70-
87+
7188
# specifying objects via ... and via arg x gives equal results
7289
expect_equal(comp2, loo_compare(x = list(w1, w2)))
7390
})
@@ -79,7 +96,7 @@ test_that("loo_compare returns expected result (3 models)", {
7996

8097
expect_equal(colnames(comp1), comp_colnames)
8198
expect_equal(rownames(comp1), c("model1", "model2", "model3"))
82-
expect_equal(comp1[1,1], 0)
99+
expect_equal(comp1[1, 1], 0)
83100
expect_s3_class(comp1, "compare.loo")
84101
expect_s3_class(comp1, "matrix")
85102

@@ -119,34 +136,53 @@ test_that("compare returns expected result (3 models)", {
119136
expect_equal(
120137
colnames(comp1),
121138
c(
122-
"elpd_diff", "se_diff", "elpd_waic", "se_elpd_waic",
123-
"p_waic", "se_p_waic", "waic", "se_waic"
124-
))
139+
"elpd_diff",
140+
"se_diff",
141+
"elpd_waic",
142+
"se_elpd_waic",
143+
"p_waic",
144+
"se_p_waic",
145+
"waic",
146+
"se_waic"
147+
)
148+
)
125149
expect_equal(rownames(comp1), c("w1", "w2", "w3"))
126-
expect_equal(comp1[1,1], 0)
150+
expect_equal(comp1[1, 1], 0)
127151
expect_s3_class(comp1, "compare.loo")
128152
expect_s3_class(comp1, "matrix")
129153
expect_snapshot_value(comp1, style = "serialize")
130154

131155
# specifying objects via '...' gives equivalent results (equal
132156
# except rownames) to using 'x' argument
133-
expect_warning(comp_via_list <- loo::compare(x = list(w1, w2, w3)), "Deprecated")
157+
expect_warning(
158+
comp_via_list <- loo::compare(x = list(w1, w2, w3)),
159+
"Deprecated"
160+
)
134161
expect_equal(comp1, comp_via_list, ignore_attr = TRUE)
135162
})
136163

137164
test_that("compare throws appropriate errors", {
138-
expect_error(suppressWarnings(loo::compare(w1, w2, x = list(w1, w2))),
139-
"should not be specified")
140-
expect_error(suppressWarnings(loo::compare(x = 2)),
141-
"must be a list")
142-
expect_error(suppressWarnings(loo::compare(x = list(2))),
143-
"should have class 'loo'")
144-
expect_error(suppressWarnings(loo::compare(x = list(w1))),
145-
"requires at least two models")
146-
147-
w3 <- suppressWarnings(waic(LLarr2[,,-1]))
148-
expect_error(suppressWarnings(loo::compare(x = list(w1, w3))),
149-
"same number of data points")
150-
expect_error(suppressWarnings(loo::compare(x = list(w1, w2, w3))),
151-
"same number of data points")
165+
expect_error(
166+
suppressWarnings(loo::compare(w1, w2, x = list(w1, w2))),
167+
"should not be specified"
168+
)
169+
expect_error(suppressWarnings(loo::compare(x = 2)), "must be a list")
170+
expect_error(
171+
suppressWarnings(loo::compare(x = list(2))),
172+
"should have class 'loo'"
173+
)
174+
expect_error(
175+
suppressWarnings(loo::compare(x = list(w1))),
176+
"requires at least two models"
177+
)
178+
179+
w3 <- suppressWarnings(waic(LLarr2[,, -1]))
180+
expect_error(
181+
suppressWarnings(loo::compare(x = list(w1, w3))),
182+
"same number of data points"
183+
)
184+
expect_error(
185+
suppressWarnings(loo::compare(x = list(w1, w2, w3))),
186+
"same number of data points"
187+
)
152188
})

tests/testthat/test_deprecated_extractors.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
library(loo)
21
options(mc.cores = 1)
32
set.seed(123)
43

tests/testthat/test_extract_log_lik.R

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
library(loo)
2-
31
test_that("extract_log_lik throws appropriate errors", {
42
x1 <- rnorm(100)
53
expect_error(extract_log_lik(x1), regexp = "Not a stanfit object")

tests/testthat/test_gpdfit.R

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
1-
library(loo)
2-
31
test_that("gpdfit returns correct result", {
42
set.seed(123)
53
x <- rexp(100)
6-
gpdfit_val_old <- unlist(gpdfit(x, wip=FALSE, min_grid_pts = 80))
4+
gpdfit_val_old <- unlist(gpdfit(x, wip = FALSE, min_grid_pts = 80))
75
expect_snapshot_value(gpdfit_val_old, style = "serialize")
86

9-
gpdfit_val_wip <- unlist(gpdfit(x, wip=TRUE, min_grid_pts = 80))
7+
gpdfit_val_wip <- unlist(gpdfit(x, wip = TRUE, min_grid_pts = 80))
108
expect_snapshot_value(gpdfit_val_wip, style = "serialize")
119

12-
gpdfit_val_wip_default_grid <- unlist(gpdfit(x, wip=TRUE))
10+
gpdfit_val_wip_default_grid <- unlist(gpdfit(x, wip = TRUE))
1311
expect_snapshot_value(gpdfit_val_wip_default_grid, style = "serialize")
1412
})
1513

1614
test_that("qgpd returns the correct result ", {
1715
probs <- seq(from = 0, to = 1, by = 0.25)
1816
q1 <- qgpd(probs, k = 1, sigma = 1)
19-
expect_equal(q1, c(0, 1/3, 1, 3, Inf))
17+
expect_equal(q1, c(0, 1 / 3, 1, 3, Inf))
2018

2119
q2 <- qgpd(probs, k = 1, sigma = 0)
2220
expect_true(all(is.nan(q2)))

0 commit comments

Comments
 (0)