|
1 | 1 | """Tests for `pyrovelocity.analysis.cytotrace` module."""
|
2 | 2 |
|
3 | 3 |
|
| 4 | +import numpy as np |
| 5 | +import pandas as pd |
| 6 | +import pytest |
| 7 | +from anndata import AnnData |
| 8 | +from numpy.testing import assert_array_almost_equal |
| 9 | +from scipy.sparse import csr_matrix |
| 10 | + |
| 11 | +from pyrovelocity.analysis import cytotrace |
| 12 | + |
| 13 | + |
def test_load_cytotrace():
    """Print the file path of the imported `cytotrace` module."""
    module_path = cytotrace.__file__
    print(module_path)
| 16 | + |
| 17 | + |
@pytest.fixture
def small_anndata():
    """Build a 3x3 sparse `AnnData` with a "raw" layer mirroring `X`.

    Observations are named cell1..cell3 and variables gene1..gene3.
    """
    cell_names = ["cell1", "cell2", "cell3"]
    gene_names = ["gene1", "gene2", "gene3"]
    counts = csr_matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

    adata = AnnData(
        counts,
        obs=pd.DataFrame(index=cell_names),
        var=pd.DataFrame(index=gene_names),
    )
    # The layer is assigned the very same matrix object that was passed as X.
    adata.layers["raw"] = counts
    return adata
| 26 | + |
| 27 | + |
def test_compute_similarity2():
    """Check `compute_similarity2` shape and agreement with `np.corrcoef`."""
    mat_o = np.array([[1, 2, 3], [4, 5, 6]])
    mat_p = np.array([[1, 2], [3, 4]])

    similarity = cytotrace.compute_similarity2(mat_o, mat_p)

    assert similarity.shape == (2, 3)

    # np.corrcoef stacks the rows of both inputs as variables; the [:3, 3:]
    # block holds the correlations of mat_o's columns with mat_p's rows.
    expected = np.corrcoef(mat_o.T, mat_p)[:3, 3:]
    assert np.allclose(similarity.T, expected, atol=1e-5)
| 34 | + |
| 35 | + |
def test_compute_similarity1():
    """Check `compute_similarity1` against column-wise `np.corrcoef`."""
    matrix = np.array(
        [
            [1, 2, 3],
            [4, 5, 6],
            [7, 8, 9],
        ]
    )

    similarity = cytotrace.compute_similarity1(matrix)

    assert similarity.shape == (3, 3)
    # Transposing makes np.corrcoef treat the columns of `matrix` as variables.
    expected = np.corrcoef(matrix.T)
    assert np.allclose(similarity, expected)
| 41 | + |
| 42 | + |
def test_compute_gcs():
    """Check that `compute_gcs` returns a length-3 vector for a 3x3 input."""
    matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    counts = np.array([2, 3, 3])

    gcs = cytotrace.compute_gcs(matrix, counts, top_n_genes=2)

    # Only the output shape is verified here, not the values.
    assert gcs.shape == (3,)
| 48 | + |
| 49 | + |
def test_threshold_and_normalize_similarity_matrix():
    """Check thresholding and row normalization of a similarity matrix.

    Covers a symmetric 4x4 input, an all-zero input, and a 2x2 input with
    negative diagonal entries.
    """
    sim = np.array(
        [
            [1.0, 0.8, 0.3, 0.1],
            [0.8, 1.0, 0.5, 0.2],
            [0.3, 0.5, 1.0, 0.7],
            [0.1, 0.2, 0.7, 1.0],
        ]
    )
    threshold = np.mean(sim)
    at_or_below = sim <= threshold
    above = sim > threshold

    normalized = cytotrace.threshold_and_normalize_similarity_matrix(sim)
    row_sums = normalized.sum(axis=1)

    # the diagonal must be zeroed out
    assert np.all(np.diag(normalized) == 0)

    # entries at or below the input mean must be zeroed out
    assert np.all(normalized[at_or_below] == 0)

    # every row that kept some weight must sum to 1
    assert_array_almost_equal(row_sums[row_sums > 0], 1.0, decimal=6)

    # rows without any weight must be entirely zero
    assert np.all(normalized[row_sums == 0] == 0)

    # the output must contain at least one zero entry
    assert np.sum(normalized == 0) > 0

    # a symmetric input is expected to give a symmetric output here
    if np.allclose(sim, sim.T):
        assert np.allclose(normalized, normalized.T)

    # entries above the mean must be non-negative in the output
    assert np.all(normalized[above] >= 0)

    # entries at or below the mean must be removed
    assert np.all(normalized[at_or_below] == 0)

    # an all-zero input must give an all-zero output
    all_zero = np.zeros_like(sim)
    all_zero_out = cytotrace.threshold_and_normalize_similarity_matrix(all_zero)
    assert np.all(all_zero_out == 0)

    # negative entries must not survive; each row keeps its single 0.5 entry,
    # which normalizes to 1
    with_negatives = np.array([[-1, 0.5], [0.5, -1]])
    negatives_out = cytotrace.threshold_and_normalize_similarity_matrix(
        with_negatives
    )
    assert np.all(negatives_out >= 0)
    assert_array_almost_equal(negatives_out, np.array([[0, 1], [1, 0]]))
| 103 | + |
| 104 | + |
def test_diffused():
    """Check that `diffused` returns an array shaped like the input scores."""
    transition = np.array(
        [
            [0.7, 0.2, 0.1],
            [0.3, 0.5, 0.2],
            [0.1, 0.3, 0.6],
        ]
    )
    scores = np.array([1, 2, 3])

    smoothed = cytotrace.diffused(transition, scores)

    # Only the output shape is verified here, not the values.
    assert smoothed.shape == scores.shape
| 110 | + |
| 111 | + |
def test_cytotrace_sparse(small_anndata, monkeypatch):
    """Check the result keys and annotations written by `cytotrace_sparse`.

    The function is run on the "raw" layer of the `small_anndata` fixture.
    """
    # NOTE(review): `monkeypatch` is requested but never used in this test.
    output = cytotrace.cytotrace_sparse(small_anndata, layer="raw")

    assert isinstance(output, dict)
    for key in ("CytoTRACE", "GCS", "cytoGenes"):
        assert key in output

    # Per-cell annotations are expected in `.obs`.
    obs_columns = small_anndata.obs.columns
    for column in ("gcs", "cytotrace", "counts"):
        assert column in obs_columns

    # Per-gene annotations are expected in `.var`.
    var_columns = small_anndata.var.columns
    for column in ("cytotrace", "cytotrace_corrs"):
        assert column in var_columns
| 125 | + |
| 126 | + |
def test_cytotrace_sparse_errors():
    """Check the exceptions `cytotrace_sparse` raises for unsupported input."""
    dense_adata = AnnData(X=np.array([[1, 2], [3, 4]]))
    dense_adata.layers["raw"] = dense_adata.X

    # Default arguments on this dense input must raise NotImplementedError.
    # NOTE(review): presumably because the data is not sparse — confirm.
    with pytest.raises(NotImplementedError):
        cytotrace.cytotrace_sparse(dense_adata)

    # A layer name that does not exist must raise KeyError.
    with pytest.raises(KeyError):
        cytotrace.cytotrace_sparse(dense_adata, layer="non_existent")
0 commit comments