Skip to content

Commit d139243

Browse files
authored
Merge pull request #62 from FrancescaDr/one_hot_encode
Features for InterScale
2 parents 402a0a8 + bd46b54 commit d139243

File tree

6 files changed

+121
-14
lines changed

6 files changed

+121
-14
lines changed

src/geome/ann2data/base/abstract.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,31 +70,61 @@ def create_data_obj(self, adata: AnnData) -> Data:
7070
obj[field] = self.merge_field(adata, field, locations)
7171
return Data(**obj)
7272

73-
def __call__(self, adata: AnnData | Iterable[AnnData]) -> Iterable[Data]:
73+
def __call__(
    self,
    adata: AnnData | Iterable[AnnData],
    save_subadata: bool = False,
    save_data: bool = False,
    save_preprocessed: bool = False,
) -> tuple[list[Data], dict[str, object] | None]:
    """Convert an AnnData object to PyTorch compatible data objects.

    Args:
    ----
    adata: The AnnData object to be converted.
    save_subadata: If True, save the subadata objects.
    save_data: If True, save the data objects.
    save_preprocessed: If True, save the preprocessed AnnData object.

    Returns
    -------
    A ``(data_objects, res)`` tuple. ``data_objects`` is the list of
    PyTorch Geometric compatible data objects, one per item iterated.
    ``res`` is ``None`` when no ``save_*`` flag is set; otherwise it is a
    dict that may hold the keys ``'preprocessed'``, ``'subadata'`` and
    ``'data'`` for the flags that were requested.

    """
    # NOTE(review): this method returns a tuple instead of yielding, so
    # wrappers that do ``list(self(adata))`` (e.g. ``to_list``) now get
    # ``[data_objects, res]`` rather than a flat list of data objects.
    keep_intermediates = save_subadata or save_data or save_preprocessed
    res = {} if keep_intermediates else None

    # do the given preprocessing steps.
    if self._preprocess is not None:
        adata = self._preprocess(adata)
        # NOTE(review): recorded only when a preprocessing step is
        # configured; confirm this nesting matches the intent.
        if save_preprocessed:
            res["preprocessed"] = adata

    # convert adata to iterable if it is specified
    adata_iter = adata
    if self._adata2iterable is not None:
        adata_iter = self._adata2iterable(adata)

    if save_subadata:
        res["subadata"] = []
    if save_data:
        res["data"] = []

    # iterate through adata.
    data_objects = []
    for subadata in adata_iter:
        if self._transform is not None:
            subadata = self._transform(subadata)

        current_data_obj = self.create_data_obj(subadata)
        data_objects.append(current_data_obj)

        if save_subadata:
            res["subadata"].append(subadata)
        if save_data:
            res["data"].append(current_data_obj)

    return data_objects, res
127+
98128

99129
def to_list(self, adata: AnnData | Iterable[AnnData]) -> list[Data]:
100130
"""Convert an AnnData object to a list of PyTorch compatible data objects.
@@ -108,3 +138,4 @@ def to_list(self, adata: AnnData | Iterable[AnnData]) -> list[Data]:
108138
A list of PyTorch Geometric compatible data objects.
109139
"""
110140
return list(self(adata))
141+

src/geome/ann2data/basic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Callable
3+
from typing import Any, Callable
44

55
import numpy as np
66
import pandas as pd
@@ -22,6 +22,8 @@ def __init__(
2222
adata2iter: Callable[[AnnData], AnnData] | None = None,
2323
preprocess: list[Callable[[AnnData], AnnData]] | None = None,
2424
transform: list[Callable[[AnnData], AnnData]] | None = None,
25+
*args: Any,
26+
**kwargs: Any,
2527
) -> None:
2628
"""Convert anndata object into a dictionary of arrays.
2729
@@ -46,7 +48,7 @@ def __init__(
4648
transform: List of functions to transform the AnnData object after preprocessing.
4749
edge_index_key: Key for the edge index in the converted data. Defaults to 'edge_index'.
4850
"""
49-
super().__init__(fields, adata2iter, preprocess, transform)
51+
super().__init__(fields, adata2iter, preprocess, transform, *args, **kwargs)
5052

5153
self._preprocess = preprocess
5254
self._transform = transform

src/geome/iterables/to_category_iterator.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Iterator
44
from dataclasses import dataclass
5-
from typing import Literal
5+
from typing import Any, Callable, Literal, Optional
66

77
from anndata import AnnData
88

@@ -21,12 +21,12 @@ class ToCategoryIterator(ToIterable):
2121
axis (int | str): The axis along which to iterate over the categories. Can be either 0, 1, "obs" or "var".
2222
0 or "obs" means the categories are in the observation axis.
2323
1 or "var" means the categories are in the variable axis.
24-
preserve_categories (bool): Preserves the categories in the resulting AnnData obs and var Series if `preserve_categories` is True.
24+
preserve_categories (list): If not None, preserves the indicated categories from the Anndata 'obs' and 'var'
2525
"""
2626

2727
category: str
2828
axis: Literal[0, 1, "obs", "var"] = "obs"
29-
preserve_categories: bool = True
29+
preserve_categories: Optional[list[str]] = []
3030

3131
def __post_init__(self):
3232
if self.axis not in (0, 1, "obs", "var"):
@@ -51,13 +51,15 @@ def __call__(self, adata: AnnData) -> Iterator[AnnData]:
5151
cats_df = get_from_loc(adata, f"{self.axis}/{self.category}")
5252
cats = cats_df.dtypes.categories
5353
preserved_categories = {"obs": {}, "var": {}}
54-
if self.preserve_categories:
55-
for axis in ("obs", "var"):
56-
adata_axis = getattr(adata, axis)
57-
if adata_axis is not None:
58-
for key in adata_axis.keys():
59-
if adata_axis[key].dtype.name == "category":
60-
preserved_categories[axis][key] = adata_axis[key].cat.categories
54+
55+
if self.preserve_categories is not None:
56+
for key in self.preserve_categories:
57+
for axis in ("obs", "var"):
58+
adata_axis = getattr(adata, axis)
59+
if adata_axis is not None:
60+
if key in adata_axis.keys():
61+
if adata_axis[key].dtype.name == "category":
62+
preserved_categories[axis][key] = adata_axis[key].cat.categories
6163

6264
for cat in cats:
6365
# TODO(syelman): is this wise? Maybe create copy only if preserve_categories is True?

src/geome/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .base.transform import Transform
55
from .categorize import Categorize
66
from .compose import Compose
7+
from .one_hot_encode import SaveOneHotEncodeLabels
78
from .subset import Subset
89

910
__all__ = [
@@ -15,4 +16,5 @@
1516
"AddEdgeIndex",
1617
"AddEdgeIndexFromAdj",
1718
"Subset",
19+
"SaveOneHotEncodeLabels",
1820
]
src/geome/transforms/one_hot_encode.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Literal
5+
6+
import pandas as pd
7+
from anndata import AnnData
8+
9+
from .base.transform import Transform
10+
11+
12+
@dataclass
class SaveOneHotEncodeLabels(Transform):
    """One-hot encode specified columns in an AnnData object and store them in the matching matrix slot.

    Args:
    -----
    keys (Union[str, List[str]]): Columns to be one-hot encoded. A single string is treated as a one-element list.
    axis (Literal["obs", "var"]): Axis on which the columns are located. 'obs' for observation, 'var' for variables.
    key_added (str): Key under which the encoded data is stored in ``obsm``/``varm``; the label mappings are stored
        in ``uns`` under ``f"{key_added}_mappings"``.

    Methods
    -------
    __call__(adata: AnnData) -> AnnData:
        Converts the specified columns to one-hot encoded format, updates `adata` in place and returns it.
    """

    keys: str | list[str]
    axis: Literal["obs", "var"]
    key_added: str

    def __post_init__(self):
        # Accept a bare column name as shorthand for a single-element list.
        if isinstance(self.keys, str):
            self.keys = [self.keys]

    def __call__(self, adata: AnnData) -> AnnData:
        """
        One-hot encode the specified columns and store the result in the AnnData object.

        Parameters
        ----------
        adata : AnnData
            The annotated data matrix to be updated with one-hot encoded data.

        Returns
        -------
        AnnData
            The same `adata` object, modified in place.
        """
        matrix_key = f"{self.axis}m"  # e.g., 'obsm' or 'varm'
        annotations = getattr(adata, self.axis)

        # One dummy-encoded frame per requested column, in the order of `keys`.
        encoded_frames = [pd.get_dummies(annotations[key]) for key in self.keys]

        # For each column, remember which label each dummy column stands for.
        label_mappings = {
            key: frame.columns.tolist()
            for key, frame in zip(self.keys, encoded_frames)
        }

        # Save the encoded data and mappings in the appropriate AnnData structure.
        # `pd.concat` already returns a DataFrame, so no extra wrapping is needed.
        getattr(adata, matrix_key)[self.key_added] = pd.concat(encoded_frames, axis=1)
        adata.uns[f"{self.key_added}_mappings"] = label_mappings

        return adata

tests/transforms/test_one_hot_encode_labels.py

Whitespace-only changes.

0 commit comments

Comments
 (0)