|
| 1 | +#!/usr/bin/env python |
| 2 | +import logging |
| 3 | +import os |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import rich_click as click |
| 8 | +from anndata import AnnData |
| 9 | +from beartype import beartype |
| 10 | +from beartype.typing import Dict, List, Optional |
| 11 | +from rich.console import Console |
| 12 | +from rich.logging import RichHandler |
| 13 | +from rich.theme import Theme |
| 14 | + |
| 15 | +from pyrovelocity.io.datasets import larry_mono, larry_neu |
| 16 | +from pyrovelocity.plots._trajectory import get_clone_trajectory |
| 17 | + |
| 18 | +click.rich_click.SHOW_ARGUMENTS = True |
| 19 | +click.rich_click.USE_MARKDOWN = True |
| 20 | + |
| 21 | + |
def configure_logging(logger_name: str = "clone_gen") -> logging.Logger:
    """Set up Rich-backed logging and return the named logger.

    The level is read from the ``LOG_LEVEL`` environment variable
    (case-insensitive). An unset or unrecognised value falls back to ``INFO``.

    Args:
        logger_name: Name of the logger to fetch and return.

    Returns:
        The logger registered under ``logger_name``, set to the resolved level.
    """
    # Per-level colours used by the Rich console when rendering records.
    level_styles = {
        "logging.level.info": "dim cyan",
        "logging.level.warning": "magenta",
        "logging.level.error": "bold red",
        "logging.level.debug": "green",
    }
    handler = RichHandler(
        console=Console(theme=Theme(level_styles)),
        rich_tracebacks=True,
        show_time=True,
        show_level=True,
        show_path=False,
        markup=True,
        log_time_format="[%X]",
    )

    known_levels = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
    requested = os.getenv("LOG_LEVEL", "INFO").upper()
    level = requested if requested in known_levels else "INFO"

    # Root configuration: the Rich handler does the formatting, so the
    # logging format string only carries the message itself.
    logging.basicConfig(
        level=level,
        format="%(message)s",
        datefmt="[%X]",
        handlers=[handler],
    )

    named_logger = logging.getLogger(logger_name)
    named_logger.setLevel(level)
    return named_logger
| 57 | + |
| 58 | + |
| 59 | +logger = configure_logging() |
| 60 | + |
| 61 | + |
@beartype
def generate_clone_trajectory(
    adata: AnnData,
    average_start_point: bool = True,
    times: Optional[List[int]] = None,
    clone_num: Optional[int] = None,
    fix_nans: bool = True,
) -> AnnData:
    """Generate clone trajectory data from AnnData object.

    Thin wrapper around ``get_clone_trajectory`` that optionally zeroes NaN
    entries in the resulting ``obsm["clone_vector_emb"]`` array.

    Args:
        adata: The input AnnData object
        average_start_point: Whether to average the start point
            (forwarded unchanged to ``get_clone_trajectory``)
        times: List of time points to consider; ``None`` (the default)
            means ``[2, 4, 6]``
        clone_num: Maximum number of clones to process
            (forwarded unchanged to ``get_clone_trajectory``)
        fix_nans: Whether to replace NaN values with zeros

    Returns:
        AnnData object with clone trajectory information
    """
    # A ``None`` sentinel replaces the former mutable list default, which was
    # a single list object shared by every call that relied on the default.
    if times is None:
        times = [2, 4, 6]

    logger.info(f"Generating clone trajectory for dataset with {adata.n_obs} cells")
    adata_clone = get_clone_trajectory(
        adata,
        average_start_point=average_start_point,
        times=times,
        clone_num=clone_num,
    )

    if fix_nans and "clone_vector_emb" in adata_clone.obsm:
        # Compute the NaN mask once and reuse it for counting and assignment.
        nan_mask = np.isnan(adata_clone.obsm["clone_vector_emb"])
        nan_count = nan_mask.sum()
        if nan_count > 0:
            logger.info(f"Fixing {nan_count} NaN values in clone_vector_emb")
            adata_clone.obsm["clone_vector_emb"][nan_mask] = 0

    return adata_clone
| 97 | + |
| 98 | + |
@beartype
def generate_all_clone_trajectories(
    output_dir: Path,
    mono_path: Optional[str] = None,
    neu_path: Optional[str] = None,
    output_names: Optional[Dict[str, str]] = None,
) -> Dict[str, Path]:
    """Pre-compute and cache clone trajectories for different lineage datasets.

    Loads the monocyte and neutrophil datasets, computes a clone trajectory
    for each, concatenates the two into a multilineage trajectory, and writes
    all three to ``output_dir`` as h5ad files.

    Args:
        output_dir: Directory to save generated trajectory files
            (created, including parents, if it does not exist)
        mono_path: Optional custom path for mono dataset
        neu_path: Optional custom path for neu dataset
        output_names: Optional custom output filenames keyed by ``"mono"``,
            ``"neu"`` and ``"multilineage"``; missing keys use the defaults

    Returns:
        Dictionary mapping dataset names to file paths
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    default_names = {
        "mono": "larry_mono_clone_trajectory.h5ad",
        "neu": "larry_neu_clone_trajectory.h5ad",
        "multilineage": "larry_multilineage_clone_trajectory.h5ad",
    }
    # Overlay caller-supplied names on the defaults so that overriding only
    # some of the filenames no longer raises KeyError for the others.
    output_names = {**default_names, **(output_names or {})}

    logger.info(f"Loading monocyte lineage data from {'custom path' if mono_path else 'default path'}")
    mono_adata = larry_mono(mono_path) if mono_path else larry_mono()
    mono_clone = generate_clone_trajectory(mono_adata)
    mono_clone_path = output_dir / output_names["mono"]
    logger.info(f"Writing monocyte clone trajectory to {mono_clone_path}")
    mono_clone.write_h5ad(mono_clone_path)

    logger.info(f"Loading neutrophil lineage data from {'custom path' if neu_path else 'default path'}")
    neu_adata = larry_neu(neu_path) if neu_path else larry_neu()
    neu_clone = generate_clone_trajectory(neu_adata)
    neu_clone_path = output_dir / output_names["neu"]
    logger.info(f"Writing neutrophil clone trajectory to {neu_clone_path}")
    neu_clone.write_h5ad(neu_clone_path)

    logger.info("Creating concatenated multilineage clone trajectory")
    # NOTE(review): AnnData.concatenate appears to be deprecated in recent
    # anndata releases in favour of anndata.concat. It is kept here because
    # the two differ in the obs names / batch column they produce -- confirm
    # before switching.
    multi_clone = mono_clone.concatenate(neu_clone)
    multi_clone_path = output_dir / output_names["multilineage"]
    logger.info(f"Writing multilineage clone trajectory to {multi_clone_path}")
    multi_clone.write_h5ad(multi_clone_path)

    logger.info("All clone trajectories generated successfully")

    return {
        "mono": mono_clone_path,
        "neu": neu_clone_path,
        "multilineage": multi_clone_path,
    }
| 153 | + |
| 154 | + |
@click.group(
    context_settings={"help_option_names": ["-h", "--help"]},
    invoke_without_command=True,
)
@click.pass_context
def cli(ctx):
    """
    # clone_gen
    _**clone_gen**_ generates pre-computed clone trajectory files for PyroVelocity.

    This tool downloads LARRY dataset samples and computes clone trajectories that
    can be later used directly in the plot_lineage_fate_correlation function.

    Pass -h or --help to each command group listed below for detailed help.
    """
    # A subcommand was given: click dispatches to it, nothing to do here.
    if ctx.invoked_subcommand is not None:
        return
    # Bare invocation: show the group help instead of silently doing nothing.
    click.echo(ctx.get_help())
| 172 | + |
| 173 | + |
@cli.command("generate")
@click.option(
    "-o",
    "--output-dir",
    "output_dir",
    type=click.Path(),
    default="data/external",
    show_default=True,
    help="Output directory for the generated trajectories.",
)
@click.option(
    "--mono-path",
    "mono_path",
    type=click.Path(exists=False),
    default=None,
    help="Optional custom path for larry_mono dataset.",
)
@click.option(
    "--neu-path",
    "neu_path",
    type=click.Path(exists=False),
    default=None,
    help="Optional custom path for larry_neu dataset.",
)
def generate_trajectories(output_dir, mono_path, neu_path):
    """
    # clone_gen generate

    Generate pre-computed clone trajectories for the LARRY datasets.

    This command:
    1. Downloads the larry_mono and larry_neu datasets if needed
    2. Computes clone trajectories using get_clone_trajectory
    3. Creates a concatenated multilineage trajectory
    4. Saves all trajectories to h5ad files

    These pre-computed trajectories can then be used with plot_lineage_fate_correlation
    to generate consistent visualizations without redundant computation.
    """
    # Example loader printed after generation, as a template for wiring the
    # produced files into pyrovelocity.io.datasets.
    loader_snippet = """
@beartype
def larry_mono_clone_trajectory(
    file_path: str | Path = "data/external/larry_mono_clone_trajectory.h5ad",
) -> anndata._core.anndata.AnnData:
    \"\"\"
    Pre-computed clone trajectory data for the LARRY monocyte lineage.

    This contains the output of get_clone_trajectory applied to the larry_mono dataset.

    Returns:
        AnnData object with clone trajectory information
    \"\"\"
    url = "https://storage.googleapis.com/pyrovelocity/data/larry_mono_clone_trajectory.h5ad"
    adata = sc.read(file_path, backup_url=url, sparse=True, cache=True)
    return adata
    """

    written = generate_all_clone_trajectories(
        output_dir=Path(output_dir),
        mono_path=mono_path,
        neu_path=neu_path,
    )

    logger.info("Clone trajectories generated and saved to:")
    for dataset_name in written:
        logger.info(f" - {dataset_name}: {written[dataset_name]}")

    logger.info("\nYou can now create functions in pyrovelocity.io.datasets to load these files:")
    logger.info(loader_snippet)
| 242 | + |
| 243 | + |
@cli.command("examine")
@click.argument(
    "trajectory_path",
    type=click.Path(exists=True),
)
def examine_trajectory(trajectory_path):
    """
    # clone_gen examine

    Examine a generated clone trajectory file and print information about its contents.

    ## arguments
    - `TRAJECTORY_PATH`: Path to the clone trajectory file to examine
    """
    # Imported lazily so the other subcommands do not require scanpy at startup.
    import scanpy as sc

    try:
        adata = sc.read(trajectory_path)
        logger.info(f"Successfully loaded file: {trajectory_path}")
        logger.info(f"AnnData object with n_obs × n_vars = {adata.n_obs} × {adata.n_vars}")

        if "state_info" in adata.obs:
            # Vectorized count instead of summing the boolean Series element
            # by element with the built-in ``sum``.
            centroid_count = (adata.obs["state_info"] == "Centroid").sum()
            logger.info(f"Contains {centroid_count} centroid cells")

        if "clone_vector_emb" in adata.obsm:
            logger.info("Contains clone_vector_emb in obsm")
            nan_count = np.isnan(adata.obsm["clone_vector_emb"]).sum()
            if nan_count > 0:
                logger.warning(f"Contains {nan_count} NaN values in clone_vector_emb")
            else:
                logger.info("No NaN values found in clone_vector_emb")
        else:
            logger.error("Missing clone_vector_emb in obsm")

        logger.info("\nAvailable keys:")
        logger.info(f" obs keys: {list(adata.obs.keys())}")
        logger.info(f" var keys: {list(adata.var.keys())}")
        logger.info(f" obsm keys: {list(adata.obsm.keys())}")

    except Exception as e:
        # Top-level CLI boundary: report the failure, then exit non-zero so
        # shells and pipelines can detect it. Previously the error was logged
        # but the command still exited with status 0.
        logger.error(f"Error examining trajectory file: {e}")
        raise SystemExit(1) from e
| 286 | + |
| 287 | + |
| 288 | +if __name__ == "__main__": |
| 289 | + cli() |
0 commit comments