Skip to content

Commit 7c9d291

Browse files
authored
Merge pull request #2 from ArthurLeroy/new_version_multivariate_inputs
Major update of the package regarding multidimensional inputs
2 parents 42b9e0e + d528b73 commit 7c9d291

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+3013
-1593
lines changed

DESCRIPTION

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
Package: MagmaClustR
22
Title: Clustering and Prediction using Multi-Task Gaussian Processes with
33
Common Mean
4-
Version: 1.0.1
4+
Version: 1.1.0
55
Authors@R: c(
6-
person("Arthur", "Leroy", , "arthur.leroy.pro@gmail.com", role = c("aut", "cre"),
6+
person("Arthur", "Leroy", , "arthur.leroy.pro@gmail.com",
7+
role = c("aut", "cre"),
78
comment = c(ORCID = "0000-0003-0806-8934")),
9+
person("Pierre", "Latouche", , "pierre.latouche@gmail.com", role = "aut"),
810
person("Pierre", "Pathé", , "pathepierre@gmail.com", role = "ctb"),
9-
person("Pierre", "Latouche", , "pierre.latouche@gmail.com", role = "aut")
11+
person("Alexia", "Grenouilla", , "grenouil@insa-toulouse.fr", role = "ctb"),
12+
person("Hugo", "Lelievre", , "lelievre@insa-toulouse.fr", role = "ctb")
1013
)
1114
Description: An implementation for the multi-task Gaussian processes with common
1215
mean framework. Two main algorithms, called 'Magma' and 'MagmaClust',
@@ -36,6 +39,8 @@ Imports:
3639
magrittr,
3740
methods,
3841
mvtnorm,
42+
plyr,
43+
purrr,
3944
Rcpp,
4045
rlang,
4146
stats,
@@ -45,6 +50,7 @@ Imports:
4550
Suggests:
4651
gganimate,
4752
gifski,
53+
gridExtra,
4854
knitr,
4955
plotly,
5056
png,
@@ -57,3 +63,5 @@ Encoding: UTF-8
5763
LazyData: true
5864
Roxygen: list(markdown = TRUE)
5965
RoxygenNote: 7.2.0
66+
Depends:
67+
R (>= 2.10)

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
export("%>%")
44
export(data_allocate_cluster)
5+
export(expand_grid_inputs)
56
export(hp)
67
export(hyperposterior)
78
export(hyperposterior_clust)
@@ -17,6 +18,8 @@ export(pred_gp)
1718
export(pred_magma)
1819
export(pred_magmaclust)
1920
export(proba_max_cluster)
21+
export(regularise_data)
22+
export(regularize_data)
2023
export(sample_gp)
2124
export(select_nb_cluster)
2225
export(simu_db)

NEWS.md

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,36 @@
11
# MagmaClustR (development version)
22

3+
# MagmaClustR 1.1.0
4+
5+
## Major
6+
* Provide 4 vignettes explaining in detail how the different features of MagmaClustR work on practical examples.
7+
* Implement expand_grid_inputs() to help create customised n-dimensional input
8+
grids on which to evaluate the GP.
9+
* Implement regularize_data() to project a dataset on a specific input grid,
10+
(possibly to control the size of the resulting covariance matrices and the associated running time).
11+
* Add an internal 'Reference' column to datasets, to provide an adequate identifier for multidimensional inputs.
12+
* Implement a new version of simu_db() to generate more realistic 2-D datasets.
13+
14+
## Minor
15+
* Round inputs to 6 significant digits to avoid numerical errors.
16+
* Generalise the creation of a grid in any dimension when 'grid_inputs' is not
17+
specified in the prediction functions.
318

419
# MagmaClustR 1.0.1
520

621
## Major
7-
* Remove the package 'optimr' dependency and switch to base 'optim()' function
8-
* Increase convergence tolerance in 'optim()', which was too slow
22+
* Remove the package 'optimr' dependency and switch to base 'optim()' function
23+
* Increase convergence tolerance in 'optim()', which was too slow
924

1025
## Minor
11-
* Fix the warnings about the absolute value function in the Cpp code
12-
* Remove error message in 'train_magmaclust()' when common_hp_k = FALSE
13-
* Change the default intervals for hyper-parameters in 'simu_db()'
14-
* Automatically remove rows with missing data
15-
* Change position of the 'grid_inputs' argument in prediction functions
16-
* Remove the internal functions from the index documentation
17-
* Fix 'ID' in hyperposterior() and hyperposterior_clust() when not character
26+
* Fix the warnings about the absolute value function in the Cpp code
27+
* Remove error message in 'train_magmaclust()' when common_hp_k = FALSE
28+
* Change the default intervals for hyper-parameters in 'simu_db()'
29+
* Automatically remove rows with missing data
30+
* Change position of the 'grid_inputs' argument in prediction functions
31+
* Remove the internal functions from the index documentation
32+
* Fix 'ID' in hyperposterior() and hyperposterior_clust() when not character
33+
1834

1935
# MagmaClustR 1.0.0
20-
* Initial release
36+
Initial release

R/data.R

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#' French swimmers performances data on 100m freestyle events
2+
#'
3+
#' A subset of data from reported performances of French swimmers during
4+
#' 100m freestyle competitions between 2002 and 2016. See
5+
#' https://link.springer.com/article/10.1007/s10994-022-06172-1 and
6+
#' https://www.mdpi.com/2076-3417/8/10/1766 for dedicated description and
7+
#' analysis.
8+
#'
9+
#' @format ## `swimmers`
10+
#' A data frame with 76,832 rows and 4 columns:
11+
#' \describe{
12+
#' \item{ID}{Identifying number associated with each swimmer}
13+
#' \item{Input}{Age in years}
14+
#' \item{Output}{Performance in seconds on a 100m freestyle event}
15+
#' \item{Gender}{Competition gender}
16+
#' }
17+
#' @source <https://ffn.extranat.fr/webffn/competitions.php?idact=nat>
18+
"swimmers"
19+
20+
#' Weight follow-up data of children in Singapore
21+
#'
22+
#' A subset of data from the GUSTO project (https://www.gusto.sg/) collecting
23+
#' the weight over time of several children in Singapore.
24+
#' See https://arxiv.org/abs/2011.07866 for dedicated description and
25+
#' analysis.
26+
#'
27+
#' @format ## `weight`
28+
#' A data frame with 3,629 rows and 4 columns:
29+
#' \describe{
30+
#' \item{ID}{Identifying number associated with each child}
31+
#' \item{sex}{Biological gender}
32+
#' \item{Input}{Age in months}
33+
#' \item{Output}{Weight in kilograms}
34+
#' }
35+
#' @source <https://gustodatavault.sg/>
36+
"weight"
37+

R/elbos.R

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,24 @@
1515
#'
1616
#' @examples
1717
#' TRUE
18-
elbo_clust_multi_GP <- function(hp, db, hyperpost, kern, pen_diag) {
18+
elbo_clust_multi_GP <- function(hp,
19+
db,
20+
hyperpost,
21+
kern,
22+
pen_diag) {
23+
1924
names_k <- hyperpost$mean %>% names()
20-
t_i <- db$Input
25+
t_i <- db$Reference
2126
y_i <- db$Output
2227
i <- unique(db$ID)
2328

24-
inv <- kern_to_inv(t_i, kern, hp, pen_diag)
29+
if("ID" %in% names(db)){
30+
inputs <- db %>% dplyr::select(-.data$Output, -.data$ID)
31+
} else{
32+
inputs <- db %>% dplyr::select(-.data$Output)
33+
}
34+
35+
inv <- kern_to_inv(inputs, kern, hp, pen_diag)
2536

2637
## classic Gaussian centred log likelihood
2738
LL_norm <- -dmnorm(y_i, rep(0, length(y_i)), inv, log = T)
@@ -35,12 +46,12 @@ elbo_clust_multi_GP <- function(hp, db, hyperpost, kern, pen_diag) {
3546
dplyr::filter(.data$ID == i) %>%
3647
dplyr::pull(k)
3748
mean_mu_k <- hyperpost$mean[[k]] %>%
38-
dplyr::filter(.data$Input %in% t_i) %>%
49+
dplyr::filter(.data$Reference %in% t_i) %>%
3950
dplyr::pull(.data$Output)
4051
corr1 <- corr1 + tau_i_k * mean_mu_k
4152
corr2 <- corr2 + tau_i_k *
4253
(mean_mu_k %*% t(mean_mu_k) +
43-
hyperpost$cov[[k]][as.character(t_i), as.character(t_i)])
54+
hyperpost$cov[[k]][as.character(t_i), as.character(t_i)])
4455
}
4556

4657
(LL_norm - y_i %*% inv %*% corr1 + 0.5 * sum(inv * corr2)) %>% return()
@@ -67,21 +78,22 @@ elbo_clust_multi_GP <- function(hp, db, hyperpost, kern, pen_diag) {
6778
#'
6879
#' @examples
6980
#' TRUE
70-
elbo_GP_mod_common_hp_k <- function(
71-
hp,
72-
db,
73-
mean,
74-
kern,
75-
post_cov,
76-
pen_diag
77-
) {
81+
elbo_GP_mod_common_hp_k <- function( hp,
82+
db,
83+
mean,
84+
kern,
85+
post_cov,
86+
pen_diag) {
7887

7988
list_ID_k <- names(db)
80-
# t_k = db[[1]] %>% dplyr::pull(.data$Input)
81-
t_k <- db[[1]] %>%
82-
dplyr::pull(.data$Input)
8389

84-
inv <- kern_to_inv(t_k, kern, hp, pen_diag)
90+
if("ID" %in% names(db)){
91+
inputs <- db[[1]] %>% dplyr::select(-.data$Output, -.data$ID)
92+
} else{
93+
inputs <- db[[1]] %>% dplyr::select(-.data$Output)
94+
}
95+
96+
inv <- kern_to_inv(inputs, kern, hp, pen_diag)
8597

8698
LL_norm <- 0
8799
cor_term <- 0
@@ -114,7 +126,12 @@ elbo_GP_mod_common_hp_k <- function(
114126
#'
115127
#' @examples
116128
#' TRUE
117-
elbo_clust_multi_GP_common_hp_i <- function(hp, db, hyperpost, kern, pen_diag) {
129+
elbo_clust_multi_GP_common_hp_i <- function(hp,
130+
db,
131+
hyperpost,
132+
kern,
133+
pen_diag) {
134+
118135
names_k <- hyperpost$mean %>% names()
119136

120137
sum_i <- 0
@@ -123,7 +140,7 @@ elbo_clust_multi_GP_common_hp_i <- function(hp, db, hyperpost, kern, pen_diag) {
123140
## Extract the i-th specific reference Input
124141
input_i <- db %>%
125142
dplyr::filter(.data$ID == i) %>%
126-
dplyr::pull(.data$Input)
143+
dplyr::pull(.data$Reference)
127144
## Extract the i-th specific inputs (reference + covariates)
128145
inputs_i <- db %>%
129146
dplyr::filter(.data$ID == i) %>%
@@ -148,7 +165,7 @@ elbo_clust_multi_GP_common_hp_i <- function(hp, db, hyperpost, kern, pen_diag) {
148165
dplyr::filter(.data$ID == i) %>%
149166
dplyr::pull(k)
150167
mean_mu_k <- hyperpost$mean[[k]] %>%
151-
dplyr::filter(.data$Input %in% input_i) %>%
168+
dplyr::filter(.data$Reference %in% input_i) %>%
152169
dplyr::pull(.data$Output)
153170
corr1 <- corr1 + tau_i_k * mean_mu_k
154171
corr2 <- corr2 + tau_i_k *
@@ -216,7 +233,7 @@ elbo_monitoring_VEM <- function(hp_k,
216233
floop2 <- function(i) {
217234
t_i <- db %>%
218235
dplyr::filter(.data$ID == i) %>%
219-
dplyr::pull(.data$Input)
236+
dplyr::pull(.data$Reference)
220237

221238
elbo_clust_multi_GP(
222239
hp_i[hp_i$ID == i, ],
@@ -257,6 +274,7 @@ elbo_monitoring_VEM <- function(hp_k,
257274

258275
return(sum_tau + det)
259276
}
277+
260278
sum_corr_k <- sapply(names(m_k), floop3) %>% sum()
261279

262280
return(-sum_ll_k - sum_ll_i + sum_corr_k)

R/em-magma.R

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,14 @@ e_step <- function(db,
2929
hp_0,
3030
hp_i,
3131
pen_diag) {
32-
all_input <- unique(db$Input) %>% sort()
32+
## Extract the union of all reference inputs provided in the training data
33+
all_inputs <- db %>%
34+
dplyr::select(-.data$ID, -.data$Output) %>%
35+
unique() %>%
36+
dplyr::arrange(.data$Reference)
3337

3438
## Compute all the inverse covariance matrices
35-
inv_0 <- kern_to_inv(all_input, kern_0, hp_0, pen_diag)
39+
inv_0 <- kern_to_inv(all_inputs, kern_0, hp_0, pen_diag)
3640
list_inv_i <- list_kern_to_inv(db, kern_i, hp_i, pen_diag)
3741
## Create a named list of Output values for all individuals
3842
list_output_i <- base::split(db$Output, list(db$ID))
@@ -50,8 +54,12 @@ e_step <- function(db,
5054

5155
post_cov <- post_inv %>%
5256
chol_inv_jitter(pen_diag = pen_diag) %>%
53-
`rownames<-`(all_input) %>%
54-
`colnames<-`(all_input)
57+
`rownames<-`(all_inputs %>%
58+
dplyr::pull(.data$Reference)
59+
) %>%
60+
`colnames<-`(all_inputs %>%
61+
dplyr::pull(.data$Reference)
62+
)
5563
##############################################
5664

5765
## Update the posterior mean ##
@@ -71,9 +79,8 @@ e_step <- function(db,
7179
##############################################
7280

7381
## Format the mean parameter of the hyper-posterior distribution
74-
tib_mean <- tibble::tibble(
75-
"Input" = all_input,
76-
"Output" = post_mean
82+
tib_mean <- tibble::tibble(all_inputs,
83+
"Output" = post_mean
7784
)
7885
list(
7986
"mean" = tib_mean,
@@ -117,8 +124,17 @@ e_step <- function(db,
117124
#'
118125
#' @examples
119126
#' TRUE
120-
m_step <- function(db, m_0, kern_0, kern_i, old_hp_0, old_hp_i,
121-
post_mean, post_cov, common_hp, pen_diag) {
127+
m_step <- function(db,
128+
m_0,
129+
kern_0,
130+
kern_i,
131+
old_hp_0,
132+
old_hp_i,
133+
post_mean,
134+
post_cov,
135+
common_hp,
136+
pen_diag) {
137+
122138
list_ID <- unique(db$ID)
123139
list_hp_0 <- old_hp_0 %>% names()
124140
list_hp_i <- old_hp_i %>%
@@ -182,10 +198,10 @@ m_step <- function(db, m_0, kern_0, kern_i, old_hp_0, old_hp_i,
182198
## Extract the i-th specific inputs
183199
input_i <- db %>%
184200
dplyr::filter(.data$ID == i) %>%
185-
dplyr::pull(.data$Input)
201+
dplyr::pull(.data$Reference)
186202
## Extract the mean values associated with the i-th specific inputs
187203
post_mean_i <- post_mean %>%
188-
dplyr::filter(.data$Input %in% input_i) %>%
204+
dplyr::filter(.data$Reference %in% input_i) %>%
189205
dplyr::pull(.data$Output)
190206
## Extract the covariance values associated with the i-th specific inputs
191207
post_cov_i <- post_cov[as.character(input_i), as.character(input_i)]
@@ -214,7 +230,10 @@ m_step <- function(db, m_0, kern_0, kern_i, old_hp_0, old_hp_i,
214230
tibble::as_tibble_row() %>%
215231
return()
216232
}
217-
new_hp_i <- sapply(list_ID, floop, simplify = FALSE, USE.NAMES = TRUE) %>%
233+
new_hp_i <- sapply(list_ID,
234+
floop,
235+
simplify = FALSE,
236+
USE.NAMES = TRUE) %>%
218237
tibble::enframe(name = "ID") %>%
219238
tidyr::unnest(cols = .data$value)
220239
}

0 commit comments

Comments
 (0)