
Prompt-Guided Latent Diffusion with Predictive Class Conditioning for 3D Prostate MRI Generation

License: Apache-2.0 (LICENSE), plus a second license file, LICENSE.weights, which GitHub lists as Unknown.

Repository: grabkeem/CCELLA



Authors: Emerson P. Grabke, Masoom A. Haider*, Babak Taati*
arXiv: 2506.10230

*MAH and BT are co-senior authors on this work

The manuscript has been submitted to IEEE for possible publication.

Abstract

Objective: Latent diffusion models (LDMs) could alleviate the data scarcity challenges affecting machine learning development for medical imaging. However, medical LDM strategies typically rely on short-prompt text encoders, non-medical LDMs, or large data volumes. These strategies can limit performance and scientific accessibility. We propose a novel LDM conditioning approach to address these limitations.

Methods: We propose Class-Conditioned Efficient Large Language model Adapter (CCELLA), a novel dual-head conditioning approach that simultaneously conditions the LDM U-Net on free-text clinical reports and radiology classification. We also propose a data-efficient LDM framework centered around CCELLA and a proposed joint loss function. We first evaluate our method on 3D prostate MRI against the state of the art. We then augment the training dataset of a downstream classifier with synthetic images from our method.

Results: Our method achieves a 3D FID score of 0.025 on a size-limited 3D prostate MRI dataset, significantly outperforming a recent foundation model with an FID of 0.070. When training a classifier for prostate cancer prediction, adding synthetic images generated by our method during training improves classifier accuracy from 69% to 74%. Training a classifier solely on our method's synthetic images achieved performance comparable to training on real images alone.

Conclusion: We show that our method improved both synthetic image quality and downstream classifier performance using limited data and minimal human annotation.

Significance: The proposed CCELLA-centric framework enables LDM training conditioned on radiology reports and classes for high-quality medical image synthesis with limited data volume and human annotation, improving LDM performance and scientific accessibility.

Code

This repository contains the scripts needed to train the CCELLA LDM and the other LDM variants in this work. All scripts are expected to be run from the ccella folder; multi-GPU examples are given where available.

1. Starting Requirements

All data should already be resampled to the desired (fixed) image size. The JSON files ./datasets/train.json and ./datasets/test.json should contain non-overlapping entries, one per data sample. The "text" field of each entry holds the plaintext radiology report for that exam; if no report is available, "text" should be "" and "text_isnull" should be set to 1. The PI-RADS vector should be one-hot encoded as [neg, pos], i.e. [1, 0] if PI-RADS is 1 or 2, and [0, 1] otherwise.
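As an illustration of the entry rules above, the sketch below builds one JSON entry and checks that the two splits do not overlap. The "text" and "text_isnull" keys come from this README; the "image" and "pirads" key names, and setting "text_isnull" to 0 for non-empty reports, are hypothetical assumptions, so match them to what the configs in ./configs/ actually expect.

```python
from __future__ import annotations


def pirads_onehot(pirads: int) -> list[int]:
    """Encode a PI-RADS score as [neg, pos]: [1, 0] for PI-RADS 1-2, else [0, 1]."""
    if pirads not in (1, 2, 3, 4, 5):
        raise ValueError(f"PI-RADS must be 1-5, got {pirads}")
    return [1, 0] if pirads in (1, 2) else [0, 1]


def make_entry(image_path: str, report: str | None, pirads: int) -> dict:
    """Build one dataset entry.

    "text"/"text_isnull" follow the README; "image" and "pirads" are
    hypothetical key names used only for illustration.
    """
    text = report or ""
    return {
        "image": image_path,
        "text": text,
        "text_isnull": 1 if text == "" else 0,  # 0 for non-empty reports is an assumption
        "pirads": pirads_onehot(pirads),
    }


def check_disjoint(train: list[dict], test: list[dict], key: str = "image") -> None:
    """Raise if any sample (identified by `key`) appears in both splits."""
    overlap = {e[key] for e in train} & {e[key] for e in test}
    if overlap:
        raise ValueError(f"{len(overlap)} samples in both splits, e.g. {sorted(overlap)[:5]}")
```

For example, `make_entry("case001.nii.gz", "", 4)` yields an entry with `"text_isnull": 1` and `"pirads": [0, 1]`. The top-level layout of the JSON files is not described here, so the sketch stops before writing them.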

There should additionally be an Excel spreadsheet (.xlsx) with the following columns:

  • "folder" : patient ID and subdirectory
  • "PosPIRADS" : Positive or negative (or empty) based on the case PI-RADS grading
  • "PosISUP" : Positive or negative (or empty) based on the case ISUP grading
  • "Label": Positive or negative based on the PI-RADS and/or ISUP labels

Note that at least one of PosPIRADS or PosISUP must be non-empty for every row.
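A quick pre-flight check of the spreadsheet can catch rows that break the rule above before training starts. This is a minimal sketch that works on rows as plain dictionaries; the spreadsheet filename in the usage note is hypothetical, and the check only tests for presence of values, not for how "Label" was derived.

```python
from __future__ import annotations

import math

REQUIRED_COLUMNS = ("folder", "PosPIRADS", "PosISUP", "Label")


def _is_empty(value) -> bool:
    """Treat None, NaN (how pandas reads blank cells) and blank strings as empty."""
    if value is None:
        return True
    if isinstance(value, float) and math.isnan(value):
        return True
    return isinstance(value, str) and value.strip() == ""


def validate_rows(rows: list[dict]) -> list[str]:
    """Return a list of problems; an empty list means every row passes these checks."""
    problems = []
    for i, row in enumerate(rows):
        missing = [c for c in REQUIRED_COLUMNS if c not in row]
        if missing:
            problems.append(f"row {i}: missing columns {missing}")
            continue
        if _is_empty(row["PosPIRADS"]) and _is_empty(row["PosISUP"]):
            problems.append(f"row {i} ({row['folder']}): PosPIRADS and PosISUP are both empty")
        if _is_empty(row["Label"]):
            problems.append(f"row {i} ({row['folder']}): Label is empty")
    return problems
```

With pandas installed, the rows can come from `pandas.read_excel("labels.xlsx").to_dict("records")`, where `labels.xlsx` stands in for the real spreadsheet path.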

2. Data Preprocessing

Two scripts should be run in succession to generate the image and text embeddings for use in the LDM training pipeline. Example run:

export NUM_GPUS_PER_NODE=2
torchrun \
  --nproc_per_node=${NUM_GPUS_PER_NODE} \
  --nnodes=1 \
  --master_addr=localhost --master_port=1234 \
  -m scripts.data_processing.diff_model_create_training_data --model_def ./configs/CCELLA_def.json --model_config ./configs/CCELLA_config.json --env_config ./configs/CCELLA_env.json --num_gpus ${NUM_GPUS_PER_NODE}
python -m scripts.data_processing.gen_json_maisi_merged
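Because the second preprocessing step consumes the output of the first, it should only run if the first succeeds. The sketch below is one hypothetical way to chain the two commands from this README in Python; the helper names are illustrative, and it must be run from the ccella folder like the shell commands.

```python
from __future__ import annotations

import subprocess
import sys

# Config paths copied from the README commands.
CONFIGS = (
    "--model_def", "./configs/CCELLA_def.json",
    "--model_config", "./configs/CCELLA_config.json",
    "--env_config", "./configs/CCELLA_env.json",
)


def preprocessing_commands(num_gpus: int = 2, port: int = 1234) -> list[list[str]]:
    """Build the two preprocessing commands in the order they must run (hypothetical helper)."""
    embed = [
        "torchrun", f"--nproc_per_node={num_gpus}", "--nnodes=1",
        "--master_addr=localhost", f"--master_port={port}",
        "-m", "scripts.data_processing.diff_model_create_training_data",
        *CONFIGS, "--num_gpus", str(num_gpus),
    ]
    merge = [sys.executable, "-m", "scripts.data_processing.gen_json_maisi_merged"]
    return [embed, merge]


def run_in_order(commands: list[list[str]]) -> None:
    """Run each command; check=True raises CalledProcessError on failure, so later steps are skipped."""
    for cmd in commands:
        subprocess.run(cmd, check=True)
```

Calling `run_in_order(preprocessing_commands(num_gpus=2))` reproduces the example above.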

3. Model Training and Evaluation

LDMs can be trained using the following command:

export NUM_GPUS_PER_NODE=2
CUDA_VISIBLE_DEVICES=0,1 torchrun \
  --nproc_per_node=${NUM_GPUS_PER_NODE} \
  --nnodes=1 \
  --master_addr=localhost --master_port=1234 \
  -m scripts.diff_model_train_all --model_def ./configs/CCELLA_def.json --model_config ./configs/CCELLA_config.json --env_config ./configs/CCELLA_env.json --num_gpus ${NUM_GPUS_PER_NODE}

For evaluation:

CUDA_VISIBLE_DEVICES=0 torchrun \
  --nproc_per_node=1 \
  --nnodes=1 \
  --master_addr=localhost --master_port=1234 \
  -m scripts.evaluate_diffusion --model_def ./configs/CCELLA_def.json --model_config ./configs/CCELLA_config.json --env_config ./configs/CCELLA_env.json

JSON files for the model definition, model configuration, and environment configuration are provided in ./configs/.
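The abstract reports image quality as a 3D FID score. For reference, the Fréchet distance underlying FID compares the mean and covariance of feature embeddings from real and synthetic images: ||mu_r - mu_s||^2 + Tr(S_r + S_s - 2 (S_r S_s)^(1/2)). The NumPy/SciPy sketch below computes it from precomputed feature arrays. It assumes features have already been extracted by some 3D network; which feature extractor and implementation scripts.evaluate_diffusion uses is not described here.

```python
import numpy as np
from scipy import linalg


def frechet_distance(feats_real: np.ndarray, feats_synth: np.ndarray) -> float:
    """Fréchet distance between two (n_samples, n_features) feature arrays."""
    mu1, mu2 = feats_real.mean(axis=0), feats_synth.mean(axis=0)
    s1 = np.cov(feats_real, rowvar=False)
    s2 = np.cov(feats_synth, rowvar=False)
    # Matrix square root of the covariance product; numerical error can
    # introduce tiny imaginary parts, so keep only the real part.
    covmean = np.real(linalg.sqrtm(s1 @ s2))
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(s1 + s2 - 2.0 * covmean))
```

Identical feature sets give a distance of (numerically) zero; shifting every feature by a constant c adds roughly n_features * c^2, since only the means change.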

4. Synthetic Data Generation

To generate synthetic images (e.g. for use in a downstream classifier task):

CUDA_VISIBLE_DEVICES=0,1 torchrun \
  --nproc_per_node=2 \
  --nnodes=1 \
  --master_addr=localhost --master_port=1234 \
  -m scripts.synthetic_datagen --model_def ./configs/CCELLA_def.json --model_config ./configs/CCELLA_config.json --env_config ./configs/CCELLA_env.json --num_gpus 2

In addition to the images, the script writes one Excel sheet per GPU for each of the train and test splits, named "{output_dir}/synthsheet_{local_rank}_{train/test}.xlsx". These must be merged manually into a single spreadsheet at "{output_dir}/synthsheet_all.xlsx".
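The manual merge step can be scripted. The sketch below finds the per-rank sheets by the naming pattern given above and concatenates them with pandas. It assumes all per-rank sheets share the same columns and that simple row concatenation is what the downstream classifier expects; the function names are hypothetical. Reading and writing .xlsx with pandas also requires openpyxl.

```python
from __future__ import annotations

import re
from pathlib import Path

# Matches synthsheet_{local_rank}_{train/test}.xlsx but not synthsheet_all.xlsx.
SHEET_RE = re.compile(r"^synthsheet_(\d+)_(train|test)\.xlsx$")


def find_synth_sheets(output_dir: str | Path) -> list[Path]:
    """Return the per-rank sheets, sorted by split name and then numeric rank."""
    found = []
    for path in Path(output_dir).iterdir():
        m = SHEET_RE.match(path.name)
        if m:
            found.append((m.group(2), int(m.group(1)), path))
    return [p for _, _, p in sorted(found)]


def merge_synth_sheets(output_dir: str | Path) -> Path:
    """Concatenate all per-rank sheets into {output_dir}/synthsheet_all.xlsx."""
    import pandas as pd  # imported here so the file search works without pandas

    sheets = find_synth_sheets(output_dir)
    if not sheets:
        raise FileNotFoundError(f"no synthsheet_<rank>_<split>.xlsx files in {output_dir}")
    merged = pd.concat([pd.read_excel(p) for p in sheets], ignore_index=True)
    out = Path(output_dir) / "synthsheet_all.xlsx"
    merged.to_excel(out, index=False)
    return out
```

Because the pattern excludes synthsheet_all.xlsx, rerunning the merge does not fold the previous merged output back in.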

5. Downstream Classifier

Downstream EfficientNet classifiers can be trained and evaluated with the following commands:

export NUM_GPUS_PER_NODE=2
torchrun \
  --nproc_per_node=${NUM_GPUS_PER_NODE} \
  --nnodes=1 \
  --master_addr=localhost --master_port=1234 \
  -m scripts.downstream_classifier [-r] [-s [-l]]

The following flags change classifier behavior:

  • -r : Use real data if set
  • -s : Use synthetic data if set
  • -l : If using synthetic data, use CCELLA-generated synthetic labels as ground truth instead of clinical ground truth

Thus the following classifiers from the manuscript can be trained:

  • -r : Real data only
  • -s : Synthetic data with real labels
  • -s -l : Synthetic data with synthetic labels
  • -r -s : Real + Synthetic data (real labels)
  • -r -s -l : Real + Synthetic data (synthetic labels)

Classifiers are named according to the format: f"ds_enb0_r{int(args.real_data)}_s{int(args.synth_data)}_l{int(args.generated_label)}"
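The flags and naming format above can be mirrored with argparse. In this sketch the flag letters and the naming f-string come from this README, and the destination names are inferred from that f-string; the check that at least one of -r/-s is given and that -l is only used with -s is an added safeguard that the actual script may not perform.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Downstream EfficientNet classifier (sketch)")
    p.add_argument("-r", dest="real_data", action="store_true", help="train on real data")
    p.add_argument("-s", dest="synth_data", action="store_true", help="train on synthetic data")
    p.add_argument("-l", dest="generated_label", action="store_true",
                   help="with -s, use CCELLA-generated labels instead of clinical labels")
    return p


def parse(argv=None) -> argparse.Namespace:
    p = build_parser()
    args = p.parse_args(argv)
    # Added safeguards (assumption): p.error prints usage and raises SystemExit.
    if not (args.real_data or args.synth_data):
        p.error("at least one of -r or -s is required")
    if args.generated_label and not args.synth_data:
        p.error("-l only applies together with -s")
    return args


def run_name(args: argparse.Namespace) -> str:
    """Classifier name in the README's format."""
    return f"ds_enb0_r{int(args.real_data)}_s{int(args.synth_data)}_l{int(args.generated_label)}"
```

For example, `-r -s -l` maps to `ds_enb0_r1_s1_l1`, the real + synthetic classifier trained with synthetic labels.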

Acknowledgements

The code in this repository builds upon publicly released code from MONAI, MAISI, and ELLA.
