Skip to content

Commit e49ee6d

Browse files
chore(scripts): add script to generate clone trajectory data
Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
1 parent 82f67aa commit e49ee6d

File tree

1 file changed

+289
-0
lines changed

1 file changed

+289
-0
lines changed

scripts/clone/clone_gen.py

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
#!/usr/bin/env python
2+
import logging
3+
import os
4+
from pathlib import Path
5+
6+
import numpy as np
7+
import rich_click as click
8+
from anndata import AnnData
9+
from beartype import beartype
10+
from beartype.typing import Dict, List, Optional
11+
from rich.console import Console
12+
from rich.logging import RichHandler
13+
from rich.theme import Theme
14+
15+
from pyrovelocity.io.datasets import larry_mono, larry_neu
16+
from pyrovelocity.plots._trajectory import get_clone_trajectory
17+
18+
# Include positional arguments in the help output rendered by rich-click.
click.rich_click.SHOW_ARGUMENTS = True
19+
# Have rich-click render command help text (the command docstrings) as Markdown.
click.rich_click.USE_MARKDOWN = True
20+
21+
22+
def configure_logging(logger_name: str = "clone_gen") -> logging.Logger:
    """Install a themed RichHandler via ``logging.basicConfig`` and return a logger.

    The level is read from the ``LOG_LEVEL`` environment variable
    (case-insensitive). Any value other than DEBUG, INFO, WARNING, ERROR or
    CRITICAL, including an unset variable, resolves to INFO.

    Args:
        logger_name: Name of the logger to fetch and return.

    Returns:
        The logger registered under ``logger_name``, set to the resolved level.
    """
    # Resolve the level first; everything below only consumes it.
    requested = os.getenv("LOG_LEVEL", "INFO").upper()
    accepted = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
    level = requested if requested in accepted else "INFO"

    # Per-level colours applied by the rich console.
    level_styles = {
        "logging.level.info": "dim cyan",
        "logging.level.warning": "magenta",
        "logging.level.error": "bold red",
        "logging.level.debug": "green",
    }
    handler = RichHandler(
        console=Console(theme=Theme(level_styles)),
        rich_tracebacks=True,
        show_time=True,
        show_level=True,
        show_path=False,
        markup=True,
        log_time_format="[%X]",
    )

    logging.basicConfig(
        level=level,
        format="%(message)s",
        datefmt="[%X]",
        handlers=[handler],
    )

    named_logger = logging.getLogger(logger_name)
    named_logger.setLevel(level)
    return named_logger
57+
58+
59+
# Module-level logger used by the functions and commands in this script;
# note that logging is configured as a side effect of importing the module.
logger = configure_logging()
60+
61+
62+
@beartype
def generate_clone_trajectory(
    adata: AnnData,
    average_start_point: bool = True,
    times: Optional[List[int]] = None,
    clone_num: Optional[int] = None,
    fix_nans: bool = True,
) -> AnnData:
    """Generate clone trajectory data from AnnData object.

    Thin wrapper around ``get_clone_trajectory`` that optionally zeroes out
    NaN entries in the resulting ``obsm["clone_vector_emb"]`` array.

    Args:
        adata: The input AnnData object
        average_start_point: Whether to average the start point
            (forwarded to ``get_clone_trajectory``)
        times: List of time points to consider; ``None`` (the default)
            means ``[2, 4, 6]``
        clone_num: Maximum number of clones to process
            (forwarded to ``get_clone_trajectory``)
        fix_nans: Whether to replace NaN values with zeros

    Returns:
        AnnData object with clone trajectory information
    """
    # Build the default per call: a list literal in the signature would be a
    # single object shared by every call and exposed to the callee.
    if times is None:
        times = [2, 4, 6]

    logger.info(f"Generating clone trajectory for dataset with {adata.n_obs} cells")
    adata_clone = get_clone_trajectory(
        adata,
        average_start_point=average_start_point,
        times=times,
        clone_num=clone_num,
    )

    if fix_nans and "clone_vector_emb" in adata_clone.obsm:
        # Compute the mask once and reuse it for both counting and patching.
        nan_mask = np.isnan(adata_clone.obsm["clone_vector_emb"])
        nan_count = nan_mask.sum()
        if nan_count > 0:
            logger.info(f"Fixing {nan_count} NaN values in clone_vector_emb")
            adata_clone.obsm["clone_vector_emb"][nan_mask] = 0

    return adata_clone
97+
98+
99+
def _compute_and_write_clone_trajectory(
    label: str,
    loader,
    source_path: Optional[str],
    destination: Path,
) -> AnnData:
    """Load one lineage dataset, compute its clone trajectory and write it to disk.

    Args:
        label: Human-readable lineage name used in log messages.
        loader: Dataset loader; called with ``source_path`` when that is
            truthy and with no arguments otherwise.
        source_path: Optional custom location of the input dataset.
        destination: File the clone trajectory is written to.

    Returns:
        The clone trajectory AnnData that was written.
    """
    logger.info(f"Loading {label} lineage data from {'custom path' if source_path else 'default path'}")
    lineage_adata = loader(source_path) if source_path else loader()
    clone_adata = generate_clone_trajectory(lineage_adata)
    logger.info(f"Writing {label} clone trajectory to {destination}")
    clone_adata.write_h5ad(destination)
    return clone_adata


@beartype
def generate_all_clone_trajectories(
    output_dir: Path,
    mono_path: Optional[str] = None,
    neu_path: Optional[str] = None,
    output_names: Optional[Dict[str, str]] = None,
) -> Dict[str, Path]:
    """Pre-compute and cache clone trajectories for different lineage datasets.

    Writes three h5ad files into ``output_dir``: the monocyte trajectory, the
    neutrophil trajectory, and their concatenation.

    Args:
        output_dir: Directory to save generated trajectory files (created,
            including parents, if it does not exist)
        mono_path: Optional custom path for mono dataset
        neu_path: Optional custom path for neu dataset
        output_names: Optional custom output filenames keyed by ``"mono"``,
            ``"neu"`` and ``"multilineage"``; keys that are omitted keep
            their default filename

    Returns:
        Dictionary mapping dataset names to file paths
    """
    default_names = {
        "mono": "larry_mono_clone_trajectory.h5ad",
        "neu": "larry_neu_clone_trajectory.h5ad",
        "multilineage": "larry_multilineage_clone_trajectory.h5ad",
    }
    # Overlay caller-supplied names on the defaults so a partial mapping works
    # and every filename is known before any expensive computation starts.
    file_names = {**default_names, **(output_names or {})}

    output_dir.mkdir(parents=True, exist_ok=True)
    mono_clone_path = output_dir / file_names["mono"]
    neu_clone_path = output_dir / file_names["neu"]
    multi_clone_path = output_dir / file_names["multilineage"]

    mono_clone = _compute_and_write_clone_trajectory(
        "monocyte", larry_mono, mono_path, mono_clone_path
    )
    neu_clone = _compute_and_write_clone_trajectory(
        "neutrophil", larry_neu, neu_path, neu_clone_path
    )

    logger.info("Creating concatenated multilineage clone trajectory")
    # NOTE(review): AnnData.concatenate appears to be deprecated upstream in
    # favour of anndata.concat. Left unchanged here because the two may differ
    # in the obs names / batch annotation they produce -- confirm before migrating.
    multi_clone = mono_clone.concatenate(neu_clone)
    logger.info(f"Writing multilineage clone trajectory to {multi_clone_path}")
    multi_clone.write_h5ad(multi_clone_path)

    logger.info("All clone trajectories generated successfully")

    return {
        "mono": mono_clone_path,
        "neu": neu_clone_path,
        "multilineage": multi_clone_path,
    }
153+
154+
155+
@click.group(
    context_settings={"help_option_names": ["-h", "--help"]},
    invoke_without_command=True,
)
@click.pass_context
def cli(ctx):
    """
    # clone_gen
    _**clone_gen**_ generates pre-computed clone trajectory files for PyroVelocity.

    This tool downloads LARRY dataset samples and computes clone trajectories that
    can be later used directly in the plot_lineage_fate_correlation function.

    Pass -h or --help to each command group listed below for detailed help.
    """
    # The docstring above doubles as the group's help text, so it is kept verbatim.
    # When a subcommand was named there is nothing to do at group level.
    if ctx.invoked_subcommand is not None:
        return

    # Bare invocation: print the group's help instead of doing nothing.
    help_text = ctx.get_help()
    click.echo(help_text)
172+
173+
174+
# Example loader logged after a successful run. Written at column 0 with
# conventional indentation so the emitted text can be pasted as Python.
_LOADER_EXAMPLE = """
@beartype
def larry_mono_clone_trajectory(
    file_path: str | Path = "data/external/larry_mono_clone_trajectory.h5ad",
) -> anndata._core.anndata.AnnData:
    \"\"\"
    Pre-computed clone trajectory data for the LARRY monocyte lineage.

    This contains the output of get_clone_trajectory applied to the larry_mono dataset.

    Returns:
        AnnData object with clone trajectory information
    \"\"\"
    url = "https://storage.googleapis.com/pyrovelocity/data/larry_mono_clone_trajectory.h5ad"
    adata = sc.read(file_path, backup_url=url, sparse=True, cache=True)
    return adata
"""


@cli.command("generate")
@click.option(
    "-o",
    "--output-dir",
    "output_dir",
    default="data/external",
    help="Output directory for the generated trajectories.",
    show_default=True,
    # file_okay=False: an existing regular file is rejected as a usage error
    # up front instead of failing later when the directory is created.
    type=click.Path(file_okay=False),
)
@click.option(
    "--mono-path",
    "mono_path",
    default=None,
    help="Optional custom path for larry_mono dataset.",
    type=click.Path(exists=False),
)
@click.option(
    "--neu-path",
    "neu_path",
    default=None,
    help="Optional custom path for larry_neu dataset.",
    type=click.Path(exists=False),
)
def generate_trajectories(output_dir, mono_path, neu_path):
    """
    # clone_gen generate

    Generate pre-computed clone trajectories for the LARRY datasets.

    This command:
    1. Downloads the larry_mono and larry_neu datasets if needed
    2. Computes clone trajectories using get_clone_trajectory
    3. Creates a concatenated multilineage trajectory
    4. Saves all trajectories to h5ad files

    These pre-computed trajectories can then be used with plot_lineage_fate_correlation
    to generate consistent visualizations without redundant computation.
    """
    # The docstring above is the command's help text; implementation notes
    # therefore live in comments.
    # Args: output_dir is the destination directory (str from click);
    # mono_path / neu_path are optional dataset locations, None for defaults.
    output_dir_path = Path(output_dir)
    result_paths = generate_all_clone_trajectories(
        output_dir=output_dir_path,
        mono_path=mono_path,
        neu_path=neu_path,
    )

    logger.info("Clone trajectories generated and saved to:")
    for name, path in result_paths.items():
        logger.info(f"  - {name}: {path}")

    logger.info("\nYou can now create functions in pyrovelocity.io.datasets to load these files:")
    logger.info(_LOADER_EXAMPLE)
242+
243+
244+
@cli.command("examine")
@click.argument(
    "trajectory_path",
    type=click.Path(exists=True),
)
def examine_trajectory(trajectory_path):
    """
    # clone_gen examine

    Examine a generated clone trajectory file and print information about its contents.

    ## arguments
    - `TRAJECTORY_PATH`: Path to the clone trajectory file to examine
    """
    # The docstring above is the command's help text; implementation notes
    # therefore live in comments.
    # Local import: scanpy is only needed by this command.
    import scanpy as sc

    try:
        adata = sc.read(trajectory_path)
        logger.info(f"Successfully loaded file: {trajectory_path}")
        logger.info(f"AnnData object with n_obs × n_vars = {adata.n_obs} × {adata.n_vars}")

        if "state_info" in adata.obs:
            # Vectorised count rather than builtin sum() iterating the Series.
            centroid_count = int((adata.obs["state_info"] == "Centroid").sum())
            logger.info(f"Contains {centroid_count} centroid cells")

        if "clone_vector_emb" in adata.obsm:
            logger.info("Contains clone_vector_emb in obsm")
            nan_count = np.isnan(adata.obsm["clone_vector_emb"]).sum()
            if nan_count > 0:
                logger.warning(f"Contains {nan_count} NaN values in clone_vector_emb")
            else:
                logger.info("No NaN values found in clone_vector_emb")
        else:
            # Reported but not fatal: the key listing below is still useful.
            logger.error("Missing clone_vector_emb in obsm")

        logger.info("\nAvailable keys:")
        logger.info(f"  obs keys: {list(adata.obs.keys())}")
        logger.info(f"  var keys: {list(adata.var.keys())}")
        logger.info(f"  obsm keys: {list(adata.obsm.keys())}")

    except Exception as e:
        logger.error(f"Error examining trajectory file: {e}")
        # Exit non-zero so shells and CI can tell the examination failed;
        # previously the error was logged and the process still exited 0.
        raise SystemExit(1) from e
286+
287+
288+
# Run the click command group when the file is executed as a script.
if __name__ == "__main__":
    cli()

0 commit comments

Comments
 (0)