Skip to content

Add cache_stan_model parameter to relevant functions and update docum… #222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 21, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export(sccomp_remove_unwanted_variation)
export(sccomp_replicate)
export(sccomp_test)
export(sccomp_theme)
export(sccomp_stan_models_cache_dir)
export(simulate_data)
import(dplyr)
import(ggplot2)
Expand Down
11 changes: 9 additions & 2 deletions R/model_fitting.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ fit_model = function(
seed , pars = c("beta", "alpha", "prec_coeff","prec_sd"), output_samples = NULL, chains=NULL, max_sampling_iterations = 20000,
output_directory = "sccomp_draws_files",
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir,
...
)
{
Expand Down Expand Up @@ -82,7 +83,7 @@ fit_model = function(
dir.create(output_directory, showWarnings = FALSE)

# Fit
mod = load_model(model_name, threads = cores)
mod = load_model(model_name, threads = cores, cache_dir = cache_stan_model)

# Avoid 0 proportions
if(data_for_model$is_proportion && min(data_for_model$y_proportion)==0){
Expand Down Expand Up @@ -144,6 +145,7 @@ fit_model = function(
psis_resample = FALSE,
verbose = verbose,
sig_figs = sig_figs,
cache_stan_model = cache_stan_model,
show_exceptions = FALSE,
...
)
Expand Down Expand Up @@ -203,6 +205,10 @@ load_model <- function(name, cache_dir = sccomp_stan_models_cache_dir, force=FAL
# }, error = function(e) {
# Try to load the model from cache

# Handle cache directory - always add version to ensure version isolation
sccomp_version <- as.character(packageVersion("sccomp"))
cache_dir <- file.path(cache_dir, sccomp_version)

# RDS compiled model
cache_dir |> dir.create(showWarnings = FALSE, recursive = TRUE)
cache_file <- file.path(cache_dir, paste0(name, ".rds"))
Expand Down Expand Up @@ -335,6 +341,7 @@ vb_iterative = function(model,
verbose = TRUE,
psis_resample = FALSE,
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir,
...) {
res = NULL
i = 0
Expand Down Expand Up @@ -384,7 +391,7 @@ vb_iterative = function(model,
error = function(e) {
if(e$message |> str_detect("The Stan file used to create the `CmdStanModel` object does not exist\\.")) {
clear_stan_model_cache()
model <<- load_model(model_name, force=TRUE, threads = cores)
model <<- load_model(model_name, force=TRUE, threads = cores, cache_dir = cache_stan_model)
}
else writeLines(sprintf("Further attempt with Variational Bayes: %s", e))

Expand Down
23 changes: 20 additions & 3 deletions R/sccomp_estimate.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
#' @param max_sampling_iterations Integer to limit the maximum number of iterations for large datasets.
#' @param pass_fit Logical, whether to include the Stan fit as an attribute in the output.
#' @param sig_figs Number of significant figures to use for Stan model output. Default is 9.
#' @param cache_stan_model A character string specifying the cache directory for compiled Stan models.
#' The sccomp version will be automatically appended to ensure version isolation.
#' Default is `sccomp_stan_models_cache_dir` which points to `~/.sccomp_models`.
#' @param .count **DEPRECATED**. Use `abundance` instead.
#' @param approximate_posterior_inference **DEPRECATED**. Use `inference_method` instead.
#' @param variational_inference **DEPRECATED**. Use `inference_method` instead.
Expand Down Expand Up @@ -157,6 +160,7 @@ sccomp_estimate <- function(.data,
max_sampling_iterations = 20000,
pass_fit = TRUE,
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir,
...,

# DEPRECATED
Expand Down Expand Up @@ -248,6 +252,7 @@ sccomp_estimate.Seurat <- function(.data,
max_sampling_iterations = 20000,
pass_fit = TRUE,
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir,
...,

# DEPRECATED
Expand Down Expand Up @@ -300,6 +305,7 @@ sccomp_estimate.Seurat <- function(.data,
max_sampling_iterations = max_sampling_iterations,
pass_fit = pass_fit,
sig_figs = sig_figs,
cache_stan_model = cache_stan_model,
...,
.count = !!.count,
approximate_posterior_inference = approximate_posterior_inference,
Expand Down Expand Up @@ -341,6 +347,7 @@ sccomp_estimate.SingleCellExperiment <- function(.data,
max_sampling_iterations = 20000,
pass_fit = TRUE,
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir,
...,

# DEPRECATED
Expand Down Expand Up @@ -394,6 +401,7 @@ sccomp_estimate.SingleCellExperiment <- function(.data,
max_sampling_iterations = max_sampling_iterations,
pass_fit = pass_fit,
sig_figs = sig_figs,
cache_stan_model = cache_stan_model,
...,
.count = !!.count,
approximate_posterior_inference = approximate_posterior_inference,
Expand Down Expand Up @@ -435,6 +443,7 @@ sccomp_estimate.DFrame <- function(.data,
max_sampling_iterations = 20000,
pass_fit = TRUE,
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir,
...,

# DEPRECATED
Expand Down Expand Up @@ -477,7 +486,8 @@ sccomp_estimate.DFrame <- function(.data,
use_data = use_data,
mcmc_seed = mcmc_seed,
max_sampling_iterations = max_sampling_iterations,
pass_fit = pass_fit,
pass_fit = pass_fit,
cache_stan_model = cache_stan_model,
...,
.count = !!.count,
approximate_posterior_inference = approximate_posterior_inference,
Expand Down Expand Up @@ -522,6 +532,7 @@ sccomp_estimate.data.frame <- function(.data,
max_sampling_iterations = 20000,
pass_fit = TRUE,
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir,
...,

# DEPRECATED
Expand Down Expand Up @@ -642,6 +653,7 @@ sccomp_estimate.data.frame <- function(.data,
max_sampling_iterations = max_sampling_iterations,
pass_fit = pass_fit,
sig_figs = sig_figs,
cache_stan_model = cache_stan_model,
...
)

Expand Down Expand Up @@ -672,6 +684,7 @@ sccomp_estimate.data.frame <- function(.data,
max_sampling_iterations = max_sampling_iterations,
pass_fit = pass_fit,
sig_figs = sig_figs,
cache_stan_model = cache_stan_model,
...
)

Expand Down Expand Up @@ -718,6 +731,7 @@ sccomp_glm_data_frame_raw = function(.data,
max_sampling_iterations = 20000,
pass_fit = TRUE,
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir,
...) {

# See https://community.rstudio.com/t/how-to-make-complete-nesting-work-with-quosures-and-tidyeval/16473
Expand Down Expand Up @@ -794,6 +808,7 @@ sccomp_glm_data_frame_raw = function(.data,
max_sampling_iterations = max_sampling_iterations,
pass_fit = pass_fit,
sig_figs = sig_figs,
cache_stan_model = cache_stan_model,
...
)
}
Expand Down Expand Up @@ -827,8 +842,9 @@ sccomp_glm_data_frame_counts = function(.data,
cores = 4,
mcmc_seed = sample_seed(),
max_sampling_iterations = 20000,
pass_fit = TRUE,
sig_figs = 9,
pass_fit = TRUE,
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir,
...) {

# Prepare column same enquo
Expand Down Expand Up @@ -1004,6 +1020,7 @@ sccomp_glm_data_frame_counts = function(.data,
"log_lik"
),
sig_figs = sig_figs,
cache_stan_model = cache_stan_model,
...
)

Expand Down
9 changes: 8 additions & 1 deletion R/sccomp_remove_outliers.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
#' @param max_sampling_iterations Integer, limits the maximum number of iterations in case a large dataset is used, to limit computation time.
#' @param enable_loo Logical, whether to enable model comparison using the R package LOO. This is useful for comparing fits between models, similar to ANOVA.
#' @param sig_figs Number of significant figures to use for Stan model output. Default is 9.
#' @param cache_stan_model A character string specifying the cache directory for compiled Stan models.
#' The sccomp version will be automatically appended to ensure version isolation.
#' Default is `sccomp_stan_models_cache_dir` which points to `~/.sccomp_models`.
#' @param approximate_posterior_inference DEPRECATED, use the `variational_inference` argument.
#' @param variational_inference DEPRECATED Logical, whether to use variational Bayes for posterior inference. It is faster and convenient. Setting this argument to `FALSE` runs full Bayesian (Hamiltonian Monte Carlo) inference, which is slower but the gold standard.
#' @param ... Additional arguments passed to the `cmdstanr::sample` function.
Expand Down Expand Up @@ -94,6 +97,7 @@ sccomp_remove_outliers <- function(.estimate,
max_sampling_iterations = 20000,
enable_loo = FALSE,
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir,

# DEPRECATED
approximate_posterior_inference = NULL,
Expand Down Expand Up @@ -125,6 +129,7 @@ sccomp_remove_outliers.sccomp_tbl = function(.estimate,
max_sampling_iterations = 20000,
enable_loo = FALSE,
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir,

# DEPRECATED
approximate_posterior_inference = NULL,
Expand Down Expand Up @@ -180,7 +185,7 @@ sccomp_remove_outliers.sccomp_tbl = function(.estimate,
random_effect_elements = .estimate |> attr("formula_composition") |> parse_formula_random_effect()

# Load model
mod_rng = load_model("glm_multi_beta_binomial_generate_data", threads = cores)
mod_rng = load_model("glm_multi_beta_binomial_generate_data", threads = cores, cache_dir = cache_stan_model)

rng = mod_rng |> sample_safe(
generate_quantities_fx,
Expand Down Expand Up @@ -323,6 +328,7 @@ sccomp_remove_outliers.sccomp_tbl = function(.estimate,
max_sampling_iterations = max_sampling_iterations,
pars = c("beta", "alpha", "prec_coeff", "prec_sd", "alpha_normalised", "random_effect", "random_effect_2"),
sig_figs = sig_figs,
cache_stan_model = cache_stan_model,
...
)

Expand Down Expand Up @@ -453,6 +459,7 @@ sccomp_remove_outliers.sccomp_tbl = function(.estimate,
seed = mcmc_seed,
max_sampling_iterations = max_sampling_iterations,
pars = c("beta", "alpha", "prec_coeff","prec_sd", "alpha_normalised", "random_effect", "random_effect_2", "log_lik"),
cache_stan_model = cache_stan_model,
...
)

Expand Down
17 changes: 12 additions & 5 deletions R/sccomp_replicate.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
#' @param formula_variability A formula. The formula describing the model for differential variability, for example ~treatment. In most cases, if differential variability is of interest, the formula should only include the factor of interest, as a large amount of data is needed to define variability depending on each factor. This formula can be a sub-formula of your estimated model; in this case all other factors will be factored out.
#' @param number_of_draws An integer. How many copies of the data you want to draw from the model's joint posterior distribution.
#' @param mcmc_seed An integer. Used for Markov-chain Monte Carlo reproducibility. By default a random number is sampled from 1 to 999999. This itself can be controlled by set.seed()
#' @param cache_stan_model A character string specifying the cache directory for compiled Stan models.
#' The sccomp version will be automatically appended to ensure version isolation.
#' Default is `sccomp_stan_models_cache_dir` which points to `~/.sccomp_models`.
#'
#' @return A tibble `tbl` with cell_group-wise statistics
#'
Expand Down Expand Up @@ -48,7 +51,8 @@ sccomp_replicate <- function(fit,
formula_composition = NULL,
formula_variability = NULL,
number_of_draws = 1,
mcmc_seed = sample_seed()) {
mcmc_seed = sample_seed(),
cache_stan_model = sccomp_stan_models_cache_dir) {

# Run the function
check_and_install_cmdstanr()
Expand All @@ -62,7 +66,8 @@ sccomp_replicate.sccomp_tbl = function(fit,
formula_composition = NULL,
formula_variability = NULL,
number_of_draws = 1,
mcmc_seed = sample_seed()){
mcmc_seed = sample_seed(),
cache_stan_model = sccomp_stan_models_cache_dir){

.sample = attr(fit, ".sample")
.cell_group = attr(fit, ".cell_group")
Expand All @@ -79,7 +84,8 @@ sccomp_replicate.sccomp_tbl = function(fit,
formula_composition = formula_composition,
formula_variability = formula_variability,
number_of_draws = number_of_draws,
mcmc_seed = mcmc_seed
mcmc_seed = mcmc_seed,
cache_stan_model = cache_stan_model
)

model_input = attr(fit, "model_input")
Expand Down Expand Up @@ -453,7 +459,8 @@ replicate_data = function(.data,
new_data = NULL,
number_of_draws = 1,
mcmc_seed = sample(1e5, 1),
cores = detectCores()){
cores = detectCores(),
cache_stan_model = sccomp_stan_models_cache_dir){

# Extract required components from .data
.sample = attr(.data, ".sample")
Expand Down Expand Up @@ -529,7 +536,7 @@ replicate_data = function(.data,
number_of_draws = min(number_of_draws, number_of_draws_in_the_fit)

# Load model
mod_rng = load_model("glm_multi_beta_binomial_generate_data", threads = cores)
mod_rng = load_model("glm_multi_beta_binomial_generate_data", threads = cores, cache_dir = cache_stan_model)

draws_matrix <- attr(.data, "fit")$draws(format = "matrix")

Expand Down
11 changes: 8 additions & 3 deletions R/simulate_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
#' @param mcmc_seed An integer. Used for Markov-chain Monte Carlo reproducibility. By default a random number is sampled from 1 to 999999. This itself can be controlled by set.seed().
#' @param cores Integer, the number of cores to be used for parallel calculations.
#' @param sig_figs Number of significant figures to use for Stan model output. Default is 9.
#' @param cache_stan_model A character string specifying the cache directory for compiled Stan models.
#' The sccomp version will be automatically appended to ensure version isolation.
#' Default is `sccomp_stan_models_cache_dir` which points to `~/.sccomp_models`.
#'
#' @return A tibble (`tbl`) with the following columns:
#' \itemize{
Expand Down Expand Up @@ -79,7 +82,8 @@ simulate_data <- function(.data,
number_of_draws = 1,
mcmc_seed = sample_seed(),
cores = detectCores(),
sig_figs = 9) {
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir) {

# Run the function
check_and_install_cmdstanr()
Expand Down Expand Up @@ -107,7 +111,8 @@ simulate_data.tbl = function(.data,
number_of_draws = 1,
mcmc_seed = sample_seed(),
cores = detectCores(),
sig_figs = 9) {
sig_figs = 9,
cache_stan_model = sccomp_stan_models_cache_dir) {


.sample = enquo(.sample)
Expand Down Expand Up @@ -144,7 +149,7 @@ simulate_data.tbl = function(.data,
# [1] 5.6260004 -0.6940178
# prec_sd = 0.816423129

mod_rng = load_model("glm_multi_beta_binomial_simulate_data", threads = cores)
mod_rng = load_model("glm_multi_beta_binomial_simulate_data", threads = cores, cache_dir = cache_stan_model)

fit = mod_rng |> sample_safe(
generate_quantities_fx,
Expand Down
4 changes: 2 additions & 2 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ get_sccomp_cache_dir <- function() {
file.path(path.expand("~"), ".sccomp_models", packageVersion("sccomp"))
}

# Define global variable
sccomp_stan_models_cache_dir = get_sccomp_cache_dir()
# Define global variable holding the base cache directory (without version; load_model always appends the sccomp version)
sccomp_stan_models_cache_dir = file.path(path.expand("~"), ".sccomp_models")

#' Add attribute to object
#'
Expand Down
5 changes: 5 additions & 0 deletions man/sccomp_estimate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions man/sccomp_remove_outliers.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading