1
- library(loo )
2
1
set.seed(123 )
3
2
4
3
LLarr <- example_loglik_array()
@@ -12,45 +11,63 @@ test_that("loo_compare throws appropriate errors", {
12
11
w4 <- suppressWarnings(waic(LLarr [,, - (1 : 2 )]))
13
12
14
13
expect_error(loo_compare(2 , 3 ), " must be a list if not a 'loo' object" )
15
- expect_error(loo_compare(w1 , w2 , x = list (w1 , w2 )),
16
- " If 'x' is a list then '...' should not be specified" )
17
- expect_error(loo_compare(w1 , list (1 ,2 ,3 )), " class 'loo'" )
14
+ expect_error(
15
+ loo_compare(w1 , w2 , x = list (w1 , w2 )),
16
+ " If 'x' is a list then '...' should not be specified"
17
+ )
18
+ expect_error(loo_compare(w1 , list (1 , 2 , 3 )), " class 'loo'" )
18
19
expect_error(loo_compare(w1 ), " requires at least two models" )
19
20
expect_error(loo_compare(x = list (w1 )), " requires at least two models" )
20
21
expect_error(loo_compare(w1 , w3 ), " same number of data points" )
21
22
expect_error(loo_compare(w1 , w2 , w3 ), " same number of data points" )
22
23
})
23
24
24
25
test_that(" loo_compare throws appropriate warnings" , {
25
- w3 <- w1 ; w4 <- w2
26
+ w3 <- w1
27
+ w4 <- w2
26
28
class(w3 ) <- class(w4 ) <- c(" kfold" , " loo" )
27
29
attr(w3 , " K" ) <- 2
28
30
attr(w4 , " K" ) <- 3
29
- expect_warning(loo_compare(w3 , w4 ), " Not all kfold objects have the same K value" )
31
+ expect_warning(
32
+ loo_compare(w3 , w4 ),
33
+ " Not all kfold objects have the same K value"
34
+ )
30
35
31
36
class(w4 ) <- c(" psis_loo" , " loo" )
32
37
attr(w4 , " K" ) <- NULL
33
38
expect_warning(loo_compare(w3 , w4 ), " Comparing LOO-CV to K-fold-CV" )
34
39
35
- w3 <- w1 ; w4 <- w2
40
+ w3 <- w1
41
+ w4 <- w2
36
42
attr(w3 , " yhash" ) <- " a"
37
43
attr(w4 , " yhash" ) <- " b"
38
44
expect_warning(loo_compare(w3 , w4 ), " Not all models have the same y variable" )
39
45
40
46
set.seed(123 )
41
- w_list <- lapply(1 : 25 , function (x ) suppressWarnings(waic(LLarr + rnorm(1 , 0 , 0.1 ))))
42
- expect_warning(loo_compare(w_list ),
43
- " Difference in performance potentially due to chance" )
44
-
45
- w_list_short <- lapply(1 : 4 , function (x ) suppressWarnings(waic(LLarr + rnorm(1 , 0 , 0.1 ))))
47
+ w_list <- lapply(1 : 25 , function (x ) {
48
+ suppressWarnings(waic(LLarr + rnorm(1 , 0 , 0.1 )))
49
+ })
50
+ expect_warning(
51
+ loo_compare(w_list ),
52
+ " Difference in performance potentially due to chance"
53
+ )
54
+
55
+ w_list_short <- lapply(1 : 4 , function (x ) {
56
+ suppressWarnings(waic(LLarr + rnorm(1 , 0 , 0.1 )))
57
+ })
46
58
expect_no_warning(loo_compare(w_list_short ))
47
59
})
48
60
49
61
50
-
51
62
comp_colnames <- c(
52
- " elpd_diff" , " se_diff" , " elpd_waic" , " se_elpd_waic" ,
53
- " p_waic" , " se_p_waic" , " waic" , " se_waic"
63
+ " elpd_diff" ,
64
+ " se_diff" ,
65
+ " elpd_waic" ,
66
+ " se_elpd_waic" ,
67
+ " p_waic" ,
68
+ " se_p_waic" ,
69
+ " waic" ,
70
+ " se_waic"
54
71
)
55
72
56
73
test_that(" loo_compare returns expected results (2 models)" , {
@@ -59,15 +76,15 @@ test_that("loo_compare returns expected results (2 models)", {
59
76
expect_equal(colnames(comp1 ), comp_colnames )
60
77
expect_equal(rownames(comp1 ), c(" model1" , " model2" ))
61
78
expect_output(print(comp1 ), " elpd_diff" )
62
- expect_equal(comp1 [1 : 2 ,1 ], c(0 , 0 ), ignore_attr = TRUE )
63
- expect_equal(comp1 [1 : 2 ,2 ], c(0 , 0 ), ignore_attr = TRUE )
79
+ expect_equal(comp1 [1 : 2 , 1 ], c(0 , 0 ), ignore_attr = TRUE )
80
+ expect_equal(comp1 [1 : 2 , 2 ], c(0 , 0 ), ignore_attr = TRUE )
64
81
65
82
comp2 <- loo_compare(w1 , w2 )
66
83
expect_s3_class(comp2 , " compare.loo" )
67
84
expect_equal(colnames(comp2 ), comp_colnames )
68
-
85
+
69
86
expect_snapshot_value(comp2 , style = " serialize" )
70
-
87
+
71
88
# specifying objects via ... and via arg x gives equal results
72
89
expect_equal(comp2 , loo_compare(x = list (w1 , w2 )))
73
90
})
@@ -79,7 +96,7 @@ test_that("loo_compare returns expected result (3 models)", {
79
96
80
97
expect_equal(colnames(comp1 ), comp_colnames )
81
98
expect_equal(rownames(comp1 ), c(" model1" , " model2" , " model3" ))
82
- expect_equal(comp1 [1 ,1 ], 0 )
99
+ expect_equal(comp1 [1 , 1 ], 0 )
83
100
expect_s3_class(comp1 , " compare.loo" )
84
101
expect_s3_class(comp1 , " matrix" )
85
102
@@ -119,34 +136,53 @@ test_that("compare returns expected result (3 models)", {
119
136
expect_equal(
120
137
colnames(comp1 ),
121
138
c(
122
- " elpd_diff" , " se_diff" , " elpd_waic" , " se_elpd_waic" ,
123
- " p_waic" , " se_p_waic" , " waic" , " se_waic"
124
- ))
139
+ " elpd_diff" ,
140
+ " se_diff" ,
141
+ " elpd_waic" ,
142
+ " se_elpd_waic" ,
143
+ " p_waic" ,
144
+ " se_p_waic" ,
145
+ " waic" ,
146
+ " se_waic"
147
+ )
148
+ )
125
149
expect_equal(rownames(comp1 ), c(" w1" , " w2" , " w3" ))
126
- expect_equal(comp1 [1 ,1 ], 0 )
150
+ expect_equal(comp1 [1 , 1 ], 0 )
127
151
expect_s3_class(comp1 , " compare.loo" )
128
152
expect_s3_class(comp1 , " matrix" )
129
153
expect_snapshot_value(comp1 , style = " serialize" )
130
154
131
155
# specifying objects via '...' gives equivalent results (equal
132
156
# except rownames) to using 'x' argument
133
- expect_warning(comp_via_list <- loo :: compare(x = list (w1 , w2 , w3 )), " Deprecated" )
157
+ expect_warning(
158
+ comp_via_list <- loo :: compare(x = list (w1 , w2 , w3 )),
159
+ " Deprecated"
160
+ )
134
161
expect_equal(comp1 , comp_via_list , ignore_attr = TRUE )
135
162
})
136
163
137
164
test_that(" compare throws appropriate errors" , {
138
- expect_error(suppressWarnings(loo :: compare(w1 , w2 , x = list (w1 , w2 ))),
139
- " should not be specified" )
140
- expect_error(suppressWarnings(loo :: compare(x = 2 )),
141
- " must be a list" )
142
- expect_error(suppressWarnings(loo :: compare(x = list (2 ))),
143
- " should have class 'loo'" )
144
- expect_error(suppressWarnings(loo :: compare(x = list (w1 ))),
145
- " requires at least two models" )
146
-
147
- w3 <- suppressWarnings(waic(LLarr2 [,,- 1 ]))
148
- expect_error(suppressWarnings(loo :: compare(x = list (w1 , w3 ))),
149
- " same number of data points" )
150
- expect_error(suppressWarnings(loo :: compare(x = list (w1 , w2 , w3 ))),
151
- " same number of data points" )
165
+ expect_error(
166
+ suppressWarnings(loo :: compare(w1 , w2 , x = list (w1 , w2 ))),
167
+ " should not be specified"
168
+ )
169
+ expect_error(suppressWarnings(loo :: compare(x = 2 )), " must be a list" )
170
+ expect_error(
171
+ suppressWarnings(loo :: compare(x = list (2 ))),
172
+ " should have class 'loo'"
173
+ )
174
+ expect_error(
175
+ suppressWarnings(loo :: compare(x = list (w1 ))),
176
+ " requires at least two models"
177
+ )
178
+
179
+ w3 <- suppressWarnings(waic(LLarr2 [,, - 1 ]))
180
+ expect_error(
181
+ suppressWarnings(loo :: compare(x = list (w1 , w3 ))),
182
+ " same number of data points"
183
+ )
184
+ expect_error(
185
+ suppressWarnings(loo :: compare(x = list (w1 , w2 , w3 ))),
186
+ " same number of data points"
187
+ )
152
188
})
0 commit comments