Skip to content

Commit 8919b0b

Browse files
test(cytotrace): add test module
Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
1 parent f1628e5 commit 8919b0b

File tree

1 file changed

+134
-2
lines changed

1 file changed

+134
-2
lines changed
Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,139 @@
11
"""Tests for `pyrovelocity.analysis.cytotrace` module."""
22

33

4+
import numpy as np
5+
import pandas as pd
6+
import pytest
7+
from anndata import AnnData
8+
from numpy.testing import assert_array_almost_equal
9+
from scipy.sparse import csr_matrix
10+
11+
from pyrovelocity.analysis import cytotrace
12+
13+
414
def test_load_cytotrace():
5-
pass
15+
print(cytotrace.__file__)
16+
17+
18+
@pytest.fixture
19+
def small_anndata():
20+
X = csr_matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
21+
obs = pd.DataFrame(index=["cell1", "cell2", "cell3"])
22+
var = pd.DataFrame(index=["gene1", "gene2", "gene3"])
23+
adata = AnnData(X, obs=obs, var=var)
24+
adata.layers["raw"] = X
25+
return adata
26+
27+
28+
def test_compute_similarity2():
29+
O = np.array([[1, 2, 3], [4, 5, 6]])
30+
P = np.array([[1, 2], [3, 4]])
31+
result = cytotrace.compute_similarity2(O, P)
32+
assert result.shape == (2, 3)
33+
assert np.allclose(result.T, np.corrcoef(O.T, P)[:3, 3:], atol=1e-5)
34+
35+
36+
def test_compute_similarity1():
37+
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
38+
result = cytotrace.compute_similarity1(A)
39+
assert result.shape == (3, 3)
40+
assert np.allclose(result, np.corrcoef(A.T))
41+
42+
43+
def test_compute_gcs():
44+
mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
45+
count = np.array([2, 3, 3])
46+
result = cytotrace.compute_gcs(mat, count, top_n_genes=2)
47+
assert result.shape == (3,)
48+
49+
50+
def test_threshold_and_normalize_similarity_matrix():
51+
sim = np.array(
52+
[
53+
[1.0, 0.8, 0.3, 0.1],
54+
[0.8, 1.0, 0.5, 0.2],
55+
[0.3, 0.5, 1.0, 0.7],
56+
[0.1, 0.2, 0.7, 1.0],
57+
]
58+
)
59+
60+
result = cytotrace.threshold_and_normalize_similarity_matrix(sim)
61+
62+
# check diagonal is zeroed out
63+
assert np.all(np.diag(result) == 0)
64+
65+
# check values below or equal to mean are zeroed out
66+
mean_sim = np.mean(sim)
67+
assert np.all(result[sim <= mean_sim] == 0)
68+
69+
# check non-zero rows are normalized to sum to 1
70+
non_zero_rows = np.where(result.sum(axis=1) > 0)[0]
71+
for row in non_zero_rows:
72+
assert_array_almost_equal(result[row].sum(), 1.0, decimal=6)
73+
74+
# check zero rows remain zero
75+
zero_rows = np.where(result.sum(axis=1) == 0)[0]
76+
assert np.all(result[zero_rows] == 0)
77+
78+
# check the result is sparse (contains zeros)
79+
assert np.sum(result == 0) > 0
80+
81+
# check the result preserves symmetry
82+
if np.allclose(sim, sim.T):
83+
assert np.allclose(result, result.T)
84+
85+
# check stronger similarities are preserved
86+
stronger_similarities = sim > np.mean(sim)
87+
assert np.all(result[stronger_similarities] >= 0)
88+
89+
# check weaker similarities are removed
90+
weaker_similarities = sim <= np.mean(sim)
91+
assert np.all(result[weaker_similarities] == 0)
92+
93+
# check behavior with all-zero input
94+
zero_sim = np.zeros_like(sim)
95+
zero_result = cytotrace.threshold_and_normalize_similarity_matrix(zero_sim)
96+
assert np.all(zero_result == 0)
97+
98+
# check behavior with negative values
99+
neg_sim = np.array([[-1, 0.5], [0.5, -1]])
100+
neg_result = cytotrace.threshold_and_normalize_similarity_matrix(neg_sim)
101+
assert np.all(neg_result >= 0)
102+
assert_array_almost_equal(neg_result, np.array([[0, 1], [1, 0]]))
103+
104+
105+
def test_diffused():
106+
markov = np.array([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2], [0.1, 0.3, 0.6]])
107+
gcs = np.array([1, 2, 3])
108+
result = cytotrace.diffused(markov, gcs)
109+
assert result.shape == gcs.shape
110+
111+
112+
def test_cytotrace_sparse(small_anndata, monkeypatch):
113+
result = cytotrace.cytotrace_sparse(small_anndata, layer="raw")
114+
115+
assert isinstance(result, dict)
116+
assert "CytoTRACE" in result
117+
assert "GCS" in result
118+
assert "cytoGenes" in result
119+
120+
assert "gcs" in small_anndata.obs.columns
121+
assert "cytotrace" in small_anndata.obs.columns
122+
assert "counts" in small_anndata.obs.columns
123+
assert "cytotrace" in small_anndata.var.columns
124+
assert "cytotrace_corrs" in small_anndata.var.columns
125+
126+
127+
def test_cytotrace_sparse_errors():
128+
adata = AnnData(X=np.array([[1, 2], [3, 4]]))
129+
adata.layers["raw"] = adata.X
130+
131+
with pytest.raises(
132+
NotImplementedError,
133+
):
134+
cytotrace.cytotrace_sparse(adata)
6135

7-
# print(cytotrace.__file__)
136+
with pytest.raises(
137+
KeyError,
138+
):
139+
cytotrace.cytotrace_sparse(adata, layer="non_existent")

0 commit comments

Comments
 (0)