Skip to content

Commit 99413d0

Browse files
committed
kmer
1 parent 8bb3ceb commit 99413d0

File tree

5 files changed

+38
-140
lines changed

5 files changed

+38
-140
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,5 @@ configs/selfsup/processed
137137
# configs/selfsup
138138
*.json
139139
*.toml
140-
*.ipynb
140+
*.ipynb
141+
pretrained_model/

configs/regression/DNA/transformer/deit/deit_t_dim64_l512_f20_bs1024_ep100.py renamed to configs/regression/DNA/transformer/deit/deit_t_dim64_l512_bs256_ep100.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
'num_layers': 12,
1717
'num_heads': embed_dim // 16,
1818
'feedforward_channels': embed_dim * 4},
19-
in_channels=4,
19+
in_channels=4096,
2020
padding_index=0,
2121
seq_len=seq_len,
2222
norm_cfg=dict(type='LN', eps=1e-6),
@@ -27,6 +27,7 @@
2727
out_indices=-1, # last layer
2828
with_cls_token=False,
2929
output_cls_token=False,
30+
with_embedding=True,
3031
),
3132
head=dict(
3233
type='RegHead',
@@ -41,19 +42,19 @@
4142
data_source_cfg = dict(
4243
type='DNASeqDataset',
4344
file_list=None, # use all splits
44-
word_splitor="", data_splitor=",", mapping_name="ACGT", # gRNA tokenize
45+
word_splitor=" ", data_splitor=",", # gRNA tokenize
4546
data_type="regression", target_type='total',
46-
filter_condition=20, max_seq_length=512,
47+
max_seq_length=512,
4748
)
4849
data = dict(
49-
samples_per_gpu=64, # bs64 x 8gpu x 2 accu = bs1024
50+
samples_per_gpu=128, # 256
5051
workers_per_gpu=4,
5152
train=dict(
5253
data_source=dict(root=data_root+"train", **data_source_cfg)),
5354
val=dict(
5455
data_source=dict(root=data_root+"test", **data_source_cfg)),
5556
)
56-
update_interval = 2
57+
update_interval = 1
5758

5859
# optimizer
5960
optimizer = dict(

configs/regression/DNA/transformer/deit/deit_t_dim64_l512_f80_bs1024_ep50.py

Lines changed: 0 additions & 92 deletions
This file was deleted.

configs/regression/_base_/datasets/DNA/dna.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
data_root = 'data/dna/'
33
data_source_cfg = dict(
44
type='DNASeqDataset',
5-
file_list=None, # use all splits
6-
word_splitor="", data_splitor=",", mapping_name="ACGT", # gRNA tokenize
7-
data_type="regression", target_type='total',
8-
filter_condition=5, max_seq_length=512
5+
file_list=None, k=6, padding_idx=0,
6+
word_splitor=" ", data_splitor=",",
7+
data_type="regression", target_type='total', max_seq_length=512
98
)
109

1110
dataset_type = 'RegressionDataset'
@@ -44,7 +43,7 @@
4443
initial=True,
4544
interval=1,
4645
samples_per_gpu=100,
47-
workers_per_gpu=2,
46+
workers_per_gpu=4,
4847
eval_param=dict(
4948
metric=['mse', 'spearman', 'pearson'],
5049
metric_options=dict(average_mode='mean')

openbioseq/datasets/data_sources/dna_seq_source.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
import torch
3-
3+
from itertools import product
44
from tqdm import tqdm
55
from openbioseq.utils import print_log
66
from ..registry import DATASOURCES
@@ -17,14 +17,12 @@ class DNASeqDataset(object):
1717
validation training, e.g., file_list=['train_1.txt',].
1818
word_splitor (str): Split the data string.
1919
data_splitor (str): Split each sequence in the data.
20-
mapping_name (str): Predefined mapping for the bio string.
2120
return_label (bool): Whether to return supervised labels.
2221
data_type (str): Type of the data.
2322
"""
2423

2524
CLASSES = None
26-
27-
ACGT = dict(N=0, A=1, C=2, G=3, T=4)
25+
toks = ['A', 'C', 'G', 'T']
2826
col_names = ['pos1',
2927
'pos2',
3028
'pos3',
@@ -39,26 +37,24 @@ class DNASeqDataset(object):
3937
'seq',
4038
'umi',
4139
'total']
42-
AminoAcids = dict()
4340

4441
def __init__(self,
4542
root,
4643
file_list=None,
4744
word_splitor="",
4845
data_splitor=" ",
49-
mapping_name="ACGT",
5046
has_labels=True,
5147
return_label=True,
5248
target_type='',
53-
filter_condition=0,
49+
k=6,
50+
padding_idx=0,
5451
data_type="classification",
5552
max_seq_length=1024,
5653
max_data_length=None):
5754
assert file_list is None or isinstance(file_list, list)
5855
assert word_splitor in ["", " ", ",", ";", ".",]
5956
assert data_splitor in [" ", ",", ";", ".", "\t",]
6057
assert word_splitor != data_splitor
61-
assert mapping_name in ["ACGT", "AminoAcids",]
6258
assert data_type in ["classification", "regression",]
6359
assert target_type in ['umi', 'total']
6460

@@ -75,46 +71,39 @@ def __init__(self,
7571
self.return_label = return_label
7672
self.data_type = data_type
7773
self.max_seq_length = max_seq_length
78-
self.filter_condition = filter_condition
7974
self.target_type = target_type
80-
75+
self.padding_idx = padding_idx
76+
self.kmer2idx = {''.join(x) : i for i, x in enumerate(product(self.toks, repeat=k), start=1)}
8177
print_log("Total file length: {}".format(len(lines)), logger='root')
8278

8379
# preprocessing
84-
mapping = getattr(self, mapping_name) # mapping str to ints
8580
self.data_list, self.labels = [], []
8681
for l in tqdm(lines, desc='Data preprocessing:'):
8782
l = l.strip().split(data_splitor)
83+
kmer_seq = l[self.col_names.index('seq')].split(word_splitor)
84+
kmer_idx_seq = list(map(self.kmer2idx.get, kmer_seq))
85+
padding = self.max_seq_length - len(kmer_idx_seq)
8886

89-
# filtering
90-
con_g = int(l[self.col_names.index('g_total_count')]) > self.filter_condition
91-
con_r = int(l[self.col_names.index('r_total_count')]) > self.filter_condition
92-
con = con_g & con_r
93-
94-
if con:
95-
if self.has_labels:
96-
# data = [mapping[tok] for tok in l[self.col_names.index('seq')]] + [0] * padding
97-
data_list = list(map(mapping.get, l[self.col_names.index('seq')]))
98-
padding = self.max_seq_length - len(data_list)
99-
if padding < 0:
100-
data = data_list[:self.max_seq_length]
101-
else:
102-
data = data_list + [0] * padding
87+
if padding < 0:
88+
data = kmer_idx_seq[:self.max_seq_length]
89+
else:
90+
data = kmer_idx_seq + [padding_idx] * padding
10391

104-
label = l[self.col_names.index(self.target_type)]
105-
106-
if self.data_type == "classification":
107-
label = torch.tensor(float(label)).type(torch.LongTensor)
108-
else:
109-
label = torch.tensor(float(label)).type(torch.float32)
110-
111-
self.labels.append(label)
92+
if self.has_labels:
93+
label = l[self.col_names.index(self.target_type)]
94+
95+
if self.data_type == "classification":
96+
label = torch.tensor(float(label)).type(torch.LongTensor)
11297
else:
113-
# assert self.return_label is False
114-
label = None
115-
data = l.strip()[self.col_names.index['seq']]
98+
label = torch.tensor(float(label)).type(torch.float32)
99+
100+
self.labels.append(label)
101+
else:
102+
# assert self.return_label is False
103+
label = None
104+
data = l.strip()[self.col_names.index['seq']]
116105

117-
self.data_list.append(data)
106+
self.data_list.append(data)
118107

119108
if max_data_length is not None:
120109
assert isinstance(max_data_length, (int, float))

0 commit comments

Comments
 (0)