1
1
import os
2
2
import torch
3
-
3
+ from itertools import product
4
4
from tqdm import tqdm
5
5
from openbioseq .utils import print_log
6
6
from ..registry import DATASOURCES
@@ -17,14 +17,12 @@ class DNASeqDataset(object):
17
17
validation training, e.g., file_list=['train_1.txt',].
18
18
word_splitor (str): Split the data string.
19
19
data_splitor (str): Split each sequence in the data.
20
- mapping_name (str): Predefined mapping for the bio string.
21
20
return_label (bool): Whether to return supervised labels.
22
21
data_type (str): Type of the data.
23
22
"""
24
23
25
24
CLASSES = None
26
-
27
- ACGT = dict (N = 0 , A = 1 , C = 2 , G = 3 , T = 4 )
25
+ toks = ['A' , 'C' , 'G' , 'T' ]
28
26
col_names = ['pos1' ,
29
27
'pos2' ,
30
28
'pos3' ,
@@ -39,26 +37,24 @@ class DNASeqDataset(object):
39
37
'seq' ,
40
38
'umi' ,
41
39
'total' ]
42
- AminoAcids = dict ()
43
40
44
41
def __init__ (self ,
45
42
root ,
46
43
file_list = None ,
47
44
word_splitor = "" ,
48
45
data_splitor = " " ,
49
- mapping_name = "ACGT" ,
50
46
has_labels = True ,
51
47
return_label = True ,
52
48
target_type = '' ,
53
- filter_condition = 0 ,
49
+ k = 6 ,
50
+ padding_idx = 0 ,
54
51
data_type = "classification" ,
55
52
max_seq_length = 1024 ,
56
53
max_data_length = None ):
57
54
assert file_list is None or isinstance (file_list , list )
58
55
assert word_splitor in ["" , " " , "," , ";" , "." ,]
59
56
assert data_splitor in [" " , "," , ";" , "." , "\t " ,]
60
57
assert word_splitor != data_splitor
61
- assert mapping_name in ["ACGT" , "AminoAcids" ,]
62
58
assert data_type in ["classification" , "regression" ,]
63
59
assert target_type in ['umi' , 'total' ]
64
60
@@ -75,46 +71,39 @@ def __init__(self,
75
71
self .return_label = return_label
76
72
self .data_type = data_type
77
73
self .max_seq_length = max_seq_length
78
- self .filter_condition = filter_condition
79
74
self .target_type = target_type
80
-
75
+ self .padding_idx = padding_idx
76
+ self .kmer2idx = {'' .join (x ) : i for i , x in enumerate (product (self .toks , repeat = k ), start = 1 )}
81
77
print_log ("Total file length: {}" .format (len (lines )), logger = 'root' )
82
78
83
79
# preprocessing
84
- mapping = getattr (self , mapping_name ) # mapping str to ints
85
80
self .data_list , self .labels = [], []
86
81
for l in tqdm (lines , desc = 'Data preprocessing:' ):
87
82
l = l .strip ().split (data_splitor )
83
+ kmer_seq = l [self .col_names .index ('seq' )].split (word_splitor )
84
+ kmer_idx_seq = list (map (self .kmer2idx .get , kmer_seq ))
85
+ padding = self .max_seq_length - len (kmer_idx_seq )
88
86
89
- # filtering
90
- con_g = int (l [self .col_names .index ('g_total_count' )]) > self .filter_condition
91
- con_r = int (l [self .col_names .index ('r_total_count' )]) > self .filter_condition
92
- con = con_g & con_r
93
-
94
- if con :
95
- if self .has_labels :
96
- # data = [mapping[tok] for tok in l[self.col_names.index('seq')]] + [0] * padding
97
- data_list = list (map (mapping .get , l [self .col_names .index ('seq' )]))
98
- padding = self .max_seq_length - len (data_list )
99
- if padding < 0 :
100
- data = data_list [:self .max_seq_length ]
101
- else :
102
- data = data_list + [0 ] * padding
87
+ if padding < 0 :
88
+ data = kmer_idx_seq [:self .max_seq_length ]
89
+ else :
90
+ data = kmer_idx_seq + [padding_idx ] * padding
103
91
104
- label = l [self .col_names .index (self .target_type )]
105
-
106
- if self .data_type == "classification" :
107
- label = torch .tensor (float (label )).type (torch .LongTensor )
108
- else :
109
- label = torch .tensor (float (label )).type (torch .float32 )
110
-
111
- self .labels .append (label )
92
+ if self .has_labels :
93
+ label = l [self .col_names .index (self .target_type )]
94
+
95
+ if self .data_type == "classification" :
96
+ label = torch .tensor (float (label )).type (torch .LongTensor )
112
97
else :
113
- # assert self.return_label is False
114
- label = None
115
- data = l .strip ()[self .col_names .index ['seq' ]]
98
+ label = torch .tensor (float (label )).type (torch .float32 )
99
+
100
+ self .labels .append (label )
101
+ else :
102
+ # assert self.return_label is False
103
+ label = None
104
+ data = l .strip ()[self .col_names .index ['seq' ]]
116
105
117
- self .data_list .append (data )
106
+ self .data_list .append (data )
118
107
119
108
if max_data_length is not None :
120
109
assert isinstance (max_data_length , (int , float ))
0 commit comments