The Cell Observatory Platform is a comprehensive framework for training and evaluating machine learning models on biological image and video datasets. Built with PyTorch, scaled with Ray, sharded with DeepSpeed, and configured with Hydra, it provides a modular architecture that is easy to customize and extend.
Docker images
Our prebuilt image ships with Python, PyTorch, and all required packages installed for you:
docker pull ghcr.io/cell-observatory/cell_observatory_platform:develop_torch_cuda_12_8
git clone --recurse-submodules https://github.com/cell-observatory/cell_observatory_platform.git
To update to the latest version later:
git pull --recurse-submodules
Note: If you want to run a local version of the image, see the Dockerfile.
You will need to create a Supabase and a W&B account to use the platform. The Supabase instance can be found at Cell Observatory Database, and the W&B workspace can be found at Cell Observatory Dashboard.
Once you have created your Supabase and W&B accounts, you'll need to add your API keys in the environment variables as described below.
Rename the .env.example file to .env; it will be automatically loaded into the container and is gitignored. The Supabase-related environment variables enable database functionality, and the W&B API key enables logging functionality. The REPO_NAME, DATA_DIR, and STORAGE_SERVER_DIR environment variables are leveraged in the configs/paths configuration files to ensure that jobs run and save outputs as expected.
SUPABASE_USER=REPLACE_ME_WITH_YOUR_SUPABASE_USERNAME
SUPABASE_PASS=REPLACE_ME_WITH_YOUR_SUPABASE_PASSWORD
SUPABASE_STAGING_ID=REPLACE_ME_WITH_YOUR_SUPABASE_STAGING_ID
SUPABASE_PROD_ID=REPLACE_ME_WITH_YOUR_SUPABASE_PROD_ID
WANDB_API_KEY=REPLACE_ME_WITH_YOUR_WANDB_API_KEY
SUPABASE_STAGING_URI="postgresql://${SUPABASE_USER}.${SUPABASE_STAGING_ID}:${SUPABASE_PASS}@aws-0-us-east-1.pooler.supabase.com:5432/postgres"
SUPABASE_PROD_URI="postgresql://${SUPABASE_USER}.${SUPABASE_PROD_ID}:${SUPABASE_PASS}@aws-0-us-east-1.pooler.supabase.com:5432/postgres"
REPO_NAME=cell_observatory_platform
REPO_DIR=REPLACE_ME_WITH_YOUR_ROOT_REPO_DIR
DATA_DIR=REPLACE_ME_WITH_YOUR_ROOT_DATA_DIR_WHERE_DATA_WILL_BE_SAVED
STORAGE_SERVER_DIR=REPLACE_ME_WITH_YOUR_STORAGE_SERVER_DIR_WHERE_DATA_SERVER_IS_MOUNTED
PYTHONPATH=REPLACE_ME_WITH_YOUR_ROOT_REPO_DIR
To run the Docker image, cd to the repository directory or replace $(pwd) with your local path to the repository.
docker run --network host -u 1000 --privileged -v $(pwd):/workspace/cell_observatory_platform -w /workspace/cell_observatory_platform --env PYTHONUNBUFFERED=1 --pull missing -it --rm --ipc host --gpus all ghcr.io/cell-observatory/cell_observatory_platform:develop_torch_cuda_12_8 bash
Running an image on a cluster typically requires an Apptainer version of the image, which can be generated for your architecture (amd64 or arm64) by one of:
apptainer pull --arch amd64 --force develop_torch_cuda_12_8.sif docker://ghcr.io/cell-observatory/cell_observatory_platform:develop_torch_cuda_12_8
apptainer pull --arch arm64 --force develop_torch_cuda_12_8.sif docker://ghcr.io/cell-observatory/cell_observatory_platform:develop_torch_cuda_12_8
All jobs are launched through our manager.py script, which facilitates cluster resource allocation and Ray cluster setup. You may decide whether to run jobs locally or on a cluster by setting the launcher_type variable in configs/clusters/*.yaml. We show how to run jobs locally and on SLURM or LSF clusters below. Other cluster configurations can be added by extending the manager.py script and adding the necessary files in the cluster folder to support allocating cluster resources (see cluster/ray_slurm_cluster.sh and cluster/ray_lsf_cluster.sh for examples).
Example job configs are located in the configs/experiments folder. For local jobs, you can use our existing configs/paths/local.yaml and configs/clusters/local.yaml configurations. Then edit the handful of lines below to specify your local directory structure and resource configuration before launching the job.
In general, if you want to extend existing functionality (to support a new database type, cluster type, dataloader type, etc.), you just need to implement the necessary module (e.g. data/databases/my_database.py, data/datasets/my_dataset.py, or clusters/my_cluster.sh to implement a new database, dataloader, or cluster configuration, respectively) and add a new Hydra configuration file that you reference under the defaults block in your run config.
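For example, a new dataset module might look like the minimal sketch below. The file location, class name, and return convention are illustrative assumptions, not the platform's actual dataset interface; adapt them to the base classes in data/datasets/ and select the matching YAML in your defaults block.

# data/datasets/my_dataset.py (hypothetical example)
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Minimal map-style dataset sketch: reads an index CSV of file paths and yields tensors."""

    def __init__(self, records_csv: str, input_shape=(32, 128, 128, 128, 2)):
        self.records = pd.read_csv(records_csv)  # CSV of records built by your database class
        self.input_shape = input_shape

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        row = self.records.iloc[idx]
        # Replace with your actual reader (e.g. zarr/tiff loading keyed on row["path"])
        volume = np.zeros(self.input_shape, dtype=np.float32)
        return {"data_tensor": torch.from_numpy(volume), "metainfo": {"path": row["path"]}}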
experiment_name: test_cell_observatory_platform
wandb_project: test_cell_observatory_platform
paths:
# base output directory for logs, checkpoints, etc.
outdir: ${paths.data_path}/pretrained_models/${experiment_name}
resume_checkpointdir: null
pretrained_checkpointdir: null
clusters:
batch_size: 2 # total batch size
worker_nodes: 1 # number of worker nodes
gpus_per_worker: 1 # number of gpus per worker node
cpus_per_gpu: 4 # number of cpu cores per gpu
mem_per_cpu: 16000 # ram per cpu core
Run the local job using the manager.py
script, which will pick up the Hydra config and launch the Ray job:
# Set config and then run with `manager.py`:
python cluster/manager.py --config-name=configs/test_pretrain_4d_mae_local.yaml
To launch multiple training jobs with manager.py, set the run_type variable to multi_run and define a runs list of the training jobs you want to run (see configs/benchmarks/abc/benchmark_training_4d.yaml for an example). Note that each run config needs to specify a base configuration, which each job can then override as needed. We also support hyperparameter sweeps via Ray Tune, in which case you should set run_type to tune and specify the parameters you want to sweep in the tune config module. To use Hydra's native sweep functionality or to run single jobs, set run_type to single_run.
Running a job on a cluster is very similar to the local setup. You need to override the defaults in your my_run_config.yaml file. We recommend creating clusters/my_cluster_configuration.yaml and path/my_path_configurations.yaml to match your directory structure and resource/cluster configurations. We provide examples below for SLURM and LSF in configs/paths/*.yaml and configs/clusters/*.yaml.
defaults:
- clusters: abc_a100 # configs/clusters/abc_a100.yaml
- paths: abc # configs/paths/abc.yaml
defaults:
- clusters: janelia_h100 # configs/clusters/janelia_h100.yaml
- paths: janelia # configs/paths/janelia.yaml
Here's what each configuration subdirectory handles:
configs/models/
- Defines complete model specifications (e.g., JEPA and MAE model architectures).
configs/hooks/
- Defines training hooks for logging, checkpointing, and other custom training behaviors.
configs/datasets/
- Defines dataset classes and parameters.
configs/losses/
- Defines loss functions used in training.
configs/transforms/
- Defines data augmentation and preprocessing pipelines and parameters.
configs/optimizers/
- Defines optimizer configurations (e.g., AdamW, SGD).
configs/schedulers/
- Defines learning rate schedulers (e.g., StepLR, CosineAnnealing).
configs/checkpoint/
- Defines checkpointing configurations for saving and loading model states.
configs/logging/
- Defines logging configurations for training and evaluation metrics.
configs/deepspeed/
- Defines DeepSpeed configurations for distributed training.
configs/clusters/
- Defines cluster configurations for distributed training (e.g., Ray-on-SLURM).
configs/loggers/
- Defines logging configurations for experiment tracking (e.g., WandB, TensorBoard).
configs/trainer/
- Defines the training loop and configurations for training models.
configs/evaluation/
- Defines evaluation configurations for assessing model performance on validation/test datasets.
configs/optimizations/
- Defines configurations for enabling model performance optimization flags/functionality (e.g. torch.compile, activation checkpointing, and flags such as PYTORCH_CUDA_ALLOC_CONF).
configs/benchmarks/
- Defines run configurations for benchmarking model throughput, data loading throughput, etc.
configs/experiments/
- Defines run configurations for previous experiments we have run.
configs/tune/
- Defines Ray Tune configurations for running parameter sweeps.
We use Hydra for managing experiment configurations. Hydra allows you to construct experiments by composing modular YAML files.
Use the defaults:
list to select the base YAML configurations for your experiment:
defaults:
- models: jepa
- datasets: pretrain
- transforms: transforms
- hooks: hooks
- _self_ # load this file’s overrides last
Hydra handles overrides, allowing precise experiment adjustments. Scalars and lists are replaced outright, whereas dictionaries are merged recursively (only specified keys change).
Example overrides in your main config file:
clusters:
batch_size: 128 # total batch size
datasets:
input_shape: [32, 128, 128, 128, 2]
patch_shape: [4, 16, 16, 16]
split: 0.1 # train/val split
models:
_target_: models.jepa.JEPA
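The same merge semantics can be illustrated in isolation with OmegaConf, the library Hydra builds on. This is a standalone sketch, not platform code; the config keys simply mirror the example above.

from omegaconf import OmegaConf

base = OmegaConf.create({
    "clusters": {"batch_size": 2, "gpus_per_worker": 1},
    "datasets": {"input_shape": [64, 64, 64]},
})
override = OmegaConf.create({
    "clusters": {"batch_size": 128},                      # dict: merged key-by-key
    "datasets": {"input_shape": [32, 128, 128, 128, 2]},  # list: replaced outright
})
merged = OmegaConf.merge(base, override)
print(merged.clusters.gpus_per_worker)    # 1  (untouched keys survive the merge)
print(merged.clusters.batch_size)         # 128
print(list(merged.datasets.input_shape))  # [32, 128, 128, 128, 2]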
Add a new YAML under the proper group, e.g. configs/models/backbones/my_new_backbone.yaml. To support a brand-new model or dataset, just drop in your new small YAML and reference it in your defaults: block.
defaults:
- models/backbones: my_new_backbone
- _self_
The end-to-end flow of our data pipeline is as follows: Database (Supabase) -> index DataFrame -> dataset -> dataloader -> preprocessor -> model.
We first build a hypercubes table (CSV + JSON config) from Supabase, filter it (based on server path, ROI/tile/HPF, occupancy), and then choose one of three input stacks:
- PyTorch (classic Dataset/DataLoader)
- DALI (CPU external_source → GPU pipeline)
- Ray Data (custom Datasource → streaming iterator)
All three paths normalize batches to {"data_tensor": <Tensor>, "metainfo": <dict>}, and then a Preprocessor enforces dtype/device, applies optional on-device transforms, and (optionally) generates masks for self-supervised training.
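For orientation, the sketch below shows the batch convention a model step can rely on after the preprocessor. The key names follow the convention above; the tensor shape and metainfo fields are illustrative assumptions.

import torch

# Illustrative normalized batch (shape and metainfo fields are examples only)
batch = {
    "data_tensor": torch.zeros(2, 8, 64, 64, 64, 2),  # e.g. (B, T, Z, Y, X, C)
    "metainfo": {"layout": "TZYXC", "original_shape": (8, 64, 64, 64, 2)},
}

def model_step(batch: dict, model: torch.nn.Module) -> torch.Tensor:
    # The model consumes the stacked tensor; masks (if any) live in metainfo
    return model(batch["data_tensor"])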
ImageList:
A tensor container that holds images of varying sizes as a single padded tensor, storing original image sizes and layout information (CZYX vs. ZYXC, etc.), and providing methods for standardization and batch operations.
BaseDataElement:
A base class that provides dict-like and tensor-like operations for data containers, separating metainfo (image metadata) from data fields (annotations/predictions).
DataSample:
Inherits from BaseDataElement; contains an ImageList object for data tensors and provides the from_dict and to_dict methods for serialization and deserialization.
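The serialization pattern looks roughly like the self-contained sketch below. This is not the platform's implementation; it only illustrates the from_dict/to_dict round trip described above.

import torch

class TinyDataSample:
    """Toy stand-in for DataSample: separates the data tensor from metainfo."""

    def __init__(self, data_tensor: torch.Tensor, metainfo: dict):
        self.data_tensor = data_tensor
        self.metainfo = metainfo

    @classmethod
    def from_dict(cls, d: dict) -> "TinyDataSample":
        return cls(d["data_tensor"], d.get("metainfo", {}))

    def to_dict(self) -> dict:
        return {"data_tensor": self.data_tensor, "metainfo": self.metainfo}

sample = TinyDataSample.from_dict({
    "data_tensor": torch.zeros(2, 32, 128, 128),  # CZYX volume
    "metainfo": {"layout": "CZYX", "original_shape": (2, 32, 128, 128)},
})
assert torch.equal(sample.to_dict()["data_tensor"], sample.data_tensor)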
SupabaseDatabase:
Database class that constructs or queries a view of TxZxYxXxC hypercubes, fetches it via ConnectorX (Arrow path), and saves both a CSV of records and a JSON of configs. It supports:
- Server-aware existence filters (/clusterfs, /groups, /aws) and optional override of server_folder.
- Row limiting / ROI or tile selection / HPF filters, and an occupancy filter (based on a minimum occupancy ratio) that parses per-channel occupancy data for each sample and drops low-occupancy samples.
Users who want to implement their own database class only need to implement a corresponding class (with corresponding .yaml files) that generates a local CSV of records with file path information for their training samples.
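A minimal custom database class might look like the following sketch. The file location, class name, and method names are illustrative assumptions; the only requirement described above is producing a local CSV of records with file paths.

# data/databases/my_database.py (hypothetical example)
from pathlib import Path
import pandas as pd

class FolderDatabase:
    """Indexes image files under a root directory into a local records CSV."""

    def __init__(self, root_dir: str, output_csv: str):
        self.root_dir = Path(root_dir)
        self.output_csv = Path(output_csv)

    def build_records(self) -> pd.DataFrame:
        paths = sorted(str(p) for p in self.root_dir.rglob("*")
                       if p.suffix in {".zarr", ".tif", ".tiff"})
        return pd.DataFrame({"path": paths})

    def save(self) -> None:
        self.build_records().to_csv(self.output_csv, index=False)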
Our get_dataloader method in data/dataloaders.py is the entrypoint that instantiates the selected dataset stack and optional transforms, and builds a dataloader. We support the following dataloaders (a sketch of the collate convention used by the PyTorch path follows the list):
- PyTorch: PreTrainDataset implements a standard DataLoader(..., num_workers, prefetch_factor, pin_memory, ...) and a user-defined collate function.
- DALI: PretrainDatasetDali builds a DALI pipeline of the type CPU external source -> GPU ops, then a DALIGenericIterator provides data loading functionality.
- Ray: PretrainDatasourceRay builds a Dataset from a custom Datasource, applies optional transforms, and creates an iterator with _iter_batches.
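As referenced above, a user-defined collate function for the PyTorch path might look like this sketch. The function name and the DataLoader keyword values are assumptions; only the batch keys follow the convention described earlier.

import torch
from torch.utils.data import DataLoader

def collate_hypercubes(samples: list) -> dict:
    # Stack per-sample tensors into one batch tensor and gather metainfo into a list
    data = torch.stack([s["data_tensor"] for s in samples], dim=0)
    metainfo = [s["metainfo"] for s in samples]
    return {"data_tensor": data, "metainfo": metainfo}

# loader = DataLoader(my_dataset, batch_size=2, num_workers=4, prefetch_factor=2,
#                     pin_memory=True, collate_fn=collate_hypercubes)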
Our Preprocessors
class provides a unified interface to standardize outputs from different data loader libraries and to perform any operations needed before the model step. We support the following Preprocessors:
TorchPreprocessor
Normalizes PyTorch batches to a common interface:
- Stacks lists, validates dtype, and optionally checks for NaN/Inf.
- If with_masking, calls mask_generator(batch_size) and adds masks, context_masks, target_masks, original_patch_indices, and channels_to_mask to the data_sample metainfo.
DaliPreprocessor
Similar to TorchPreprocessor, but the input comes as a tuple from DALIGenericIterator; the tensor is already on the GPU.
RayPreprocessor
Similar to TorchPreprocessor; also allows for applying transformations.
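The core preprocessing responsibilities (dtype/device enforcement plus optional mask generation) can be sketched as follows. The class and attribute names here are illustrative, not the platform's actual Preprocessor API.

import torch

class SimplePreprocessor:
    """Toy preprocessor: moves the batch tensor to the target device/dtype and optionally adds masks."""

    def __init__(self, device: str = "cpu", dtype: torch.dtype = torch.float32, mask_generator=None):
        self.device, self.dtype, self.mask_generator = device, dtype, mask_generator

    def __call__(self, batch: dict) -> dict:
        data = batch["data_tensor"].to(device=self.device, dtype=self.dtype, non_blocking=True)
        batch["data_tensor"] = data
        if self.mask_generator is not None:  # corresponds to with_masking=True
            batch["metainfo"]["masks"] = self.mask_generator(data.shape[0])
        return batch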
Generates per-batch, patch-level masks for self-supervised learning over 3D/4D inputs with explicit time and space awareness. We currently support the following mask generation modes:
- BLOCKED / BLOCKED_TIME_ONLY / BLOCKED_SPACE_ONLY: samples block sizes from specified scales and creates spatial/temporal blocks.
- RANDOM / RANDOM_SPACE_ONLY: MAE-style noise -> sort -> split mask creation pipeline that splits patch-level masks into context and target sets (see the sketch below); in SPACE_ONLY mode, the same spatial mask is repeated across time.
- BLOCKED_PATTERNED: downsamples time according to a provided pattern and builds masks deterministically.
The MaskGenerator object is owned by the Preprocessor; when with_masking=True, the preprocessor attaches all mask tensors to the metainfo for the model step.
Transforms are methods that return a transformed tensor. Where to apply: if the transform is cheap, you can attach it to the dataset by adding the necessary config to configs/datasets/my_dataset/transforms.yaml (it will be included inside __getitem__ in Torch or in get_dataset_ray with Ray Data). If it is heavy, consider applying it on device in the Preprocessor to avoid extra CPU cost. For DALI, both CPU and on-device operations should be specified as part of the DALI pipeline.
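A cheap, dataset-side transform of this form might look like the following sketch (the function name and parameters are illustrative):

import torch

def random_intensity_scale(volume: torch.Tensor, low: float = 0.9, high: float = 1.1) -> torch.Tensor:
    """Scale voxel intensities by a single random factor drawn from [low, high]."""
    scale = torch.empty(1).uniform_(low, high)
    return volume * scale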
DatasetEvaluator:
Abstract base class defining three key methods: reset() for initialization, process() for accumulating predictions during evaluation, and evaluate() for computing final metrics and returning results as dictionaries.
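A concrete evaluator following this three-method contract might look like the sketch below; the metric, class name, and process() signature are illustrative assumptions.

import torch

class MSEEvaluator:
    """Accumulates squared error across batches and reports mean squared error."""

    def reset(self) -> None:
        self.sum_sq_err, self.count = 0.0, 0

    def process(self, predictions: torch.Tensor, targets: torch.Tensor) -> None:
        self.sum_sq_err += torch.sum((predictions - targets) ** 2).item()
        self.count += predictions.numel()

    def evaluate(self) -> dict:
        return {"mse": self.sum_sq_err / max(self.count, 1)}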
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at:
Copyright 2025 Cell Observatory.