
Commit 8bb3ceb

update seq transformer
1 parent 58ac558 commit 8bb3ceb

13 files changed (+250, -483 lines)

configs/regression/DNA/transformer/deit/deit_t_dim64_l512_f20_bs1024_ep100.py

Lines changed: 4 additions & 6 deletions
@@ -11,12 +11,13 @@
     type='Classification',
     pretrained=None,
     backbone=dict(
-        type='DNATransformer',
+        type='SequenceTransformer',
         arch={'embed_dims': embed_dim,
               'num_layers': 12,
               'num_heads': embed_dim // 16,
               'feedforward_channels': embed_dim * 4},
         in_channels=4,
+        padding_index=0,
         seq_len=seq_len,
         norm_cfg=dict(type='LN', eps=1e-6),
         drop_rate=0.1,
@@ -44,16 +45,13 @@
     data_type="regression", target_type='total',
     filter_condition=20, max_seq_length=512,
 )
-
 data = dict(
     samples_per_gpu=64, # bs64 x 8gpu x 2 accu = bs1024
     workers_per_gpu=4,
     train=dict(
-        data_source=dict(
-            root=data_root+"train", **data_source_cfg)),
+        data_source=dict(root=data_root+"train", **data_source_cfg)),
     val=dict(
-        data_source=dict(
-            root=data_root+"test", **data_source_cfg)),
+        data_source=dict(root=data_root+"test", **data_source_cfg)),
 )
 update_interval = 2
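Aside: the new `padding_index=0` argument presumably feeds the token-embedding path used when the backbone runs on integer token ids rather than one-hot channels (the renamed f80 config below pairs it with `with_embedding=True`, commented as using `nn.Embedding`). A minimal sketch of that kind of embedding stem, assuming index 0 is reserved for padding and A/C/G/T occupy 1..4; `vocab_size` and `embed` are illustrative names, not taken from this repo:

import torch
import torch.nn as nn

embed_dim = 64
vocab_size = 1 + 4                             # assumed: 0 = padding, 1..4 = A/C/G/T

# padding_idx pins the padding row to zeros and keeps it out of gradient updates
embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

tokens = torch.tensor([[1, 2, 3, 4, 0, 0]])    # toy right-padded sequence
feats = embed(tokens)                          # (1, 6, 64); padded positions map to zero vectors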

configs/regression/DNA/transformer/deit/deit_t_dim64_l512_f5_bs1024_ep100.py renamed to configs/regression/DNA/transformer/deit/deit_t_dim64_l512_f80_bs1024_ep50.py

Lines changed: 7 additions & 8 deletions
@@ -11,19 +11,21 @@
     type='Classification',
     pretrained=None,
     backbone=dict(
-        type='DNATransformer',
+        type='SequenceTransformer',
         arch={'embed_dims': embed_dim,
               'num_layers': 12,
               'num_heads': embed_dim // 16,
               'feedforward_channels': embed_dim * 4},
         in_channels=4,
+        padding_index=0,
         seq_len=seq_len,
         norm_cfg=dict(type='LN', eps=1e-6),
         drop_rate=0.1,
         drop_path_rate=0.1,
         init_values=0.1,
         final_norm=True,
         out_indices=-1, # last layer
+        with_embedding=True, # use `nn.Embedding`
         with_cls_token=False,
         output_cls_token=False,
     ),
@@ -42,18 +44,15 @@
     file_list=None, # use all splits
     word_splitor="", data_splitor=",", mapping_name="ACGT", # gRNA tokenize
     data_type="regression", target_type='total',
-    filter_condition=5, max_seq_length=512,
+    filter_condition=80, max_seq_length=512,
 )
-
 data = dict(
     samples_per_gpu=64, # bs64 x 8gpu x 2 accu = bs1024
-    workers_per_gpu=4,
+    workers_per_gpu=2,
     train=dict(
-        data_source=dict(
-            root=data_root+"train", **data_source_cfg)),
+        data_source=dict(root=data_root+"train", **data_source_cfg)),
     val=dict(
-        data_source=dict(
-            root=data_root+"test", **data_source_cfg)),
+        data_source=dict(root=data_root+"test", **data_source_cfg)),
 )
 update_interval = 2
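These configs appear to follow the mmcv-style Python config convention (note fp16 = dict(type='mmcv', ...) in the new pre-training config below), so a renamed file like this one can usually be sanity-checked by loading it and inspecting the merged dict. A minimal sketch, assuming an mmcv 1.x environment and running from the repository root:

from mmcv import Config

cfg = Config.fromfile(
    'configs/regression/DNA/transformer/deit/deit_t_dim64_l512_f80_bs1024_ep50.py')
print(cfg.model.backbone.type)               # expected: 'SequenceTransformer'
print(cfg.data_source_cfg.filter_condition)  # expected: 80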

configs/regression/_base_/datasets/DNA/dna.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@
     samples_per_gpu=100,
     workers_per_gpu=2,
     eval_param=dict(
-        metric=['mse', 'spearman'],
+        metric=['mse', 'spearman', 'pearson'],
         metric_options=dict(average_mode='mean')
     ),
 )
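For reference, the newly added 'pearson' metric is the linear counterpart of the existing 'spearman' rank correlation. The repo's own evaluation code is not shown in this commit; a minimal numpy/scipy sketch of the three configured metrics looks like:

import numpy as np
from scipy import stats

pred = np.array([0.10, 0.40, 0.35, 0.80])
target = np.array([0.00, 0.50, 0.30, 0.90])

mse = np.mean((pred - target) ** 2)          # squared regression error
spearman, _ = stats.spearmanr(pred, target)  # monotonic (rank) agreement
pearson, _ = stats.pearsonr(pred, target)    # linear agreement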
New file, lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+_base_ = [
+    '../../../_base_/datasets/DNA/dna_pretrain.py',
+    '../../../_base_/default_runtime.py',
+]
+
+embed_dim = 64
+seq_len = 512
+patch_size = 1
+
+# model settings
+model = dict(
+    type='BERT',
+    pretrained=None,
+    mask_ratio=0.15, # BERT 15%
+    backbone=dict(
+        type='SimMIMTransformer',
+        arch={'embed_dims': embed_dim,
+              'num_layers': 12,
+              'num_heads': embed_dim // 16,
+              'feedforward_channels': embed_dim * 4},
+        in_channels=4,
+        padding_index=0,
+        seq_len=seq_len,
+        mask_layer=10,
+        mask_ratio=0.15, # BERT 15%
+        mask_token='learnable',
+        norm_cfg=dict(type='LN', eps=1e-6),
+        drop_rate=0., # no dropout for pre-training
+        drop_path_rate=0.1,
+        final_norm=True,
+        out_indices=-1, # last layer
+        with_embedding=True, # use `nn.Embedding`
+        with_cls_token=True,
+        output_cls_token=True,
+    ),
+    neck=dict(
+        type='SimMIMNeck', feature_Nd="1d",
+        in_channels=embed_dim, out_channels=5, encoder_stride=patch_size),
+    head=dict(
+        type='MIMHead',
+        loss=dict(type='CrossEntropyLoss',
+            use_soft=False, use_sigmoid=False, reduction='none', loss_weight=1.0),
+        feature_Nd="1d", unmask_weight=0., encoder_in_channels=5,
+    ),
+    init_cfg=[
+        dict(type='TruncNormal', layer=['Conv1d', 'Linear'], std=0.02, bias=0.),
+        dict(type='Constant', layer=['LayerNorm'], val=1., bias=0.)
+    ],
+)
+
+# dataset settings
+data_root = 'data/dna/'
+data_source_cfg = dict(
+    type='DNASeqDataset',
+    file_list=None, # use all splits
+    # file_list=["train_0.csv",], # use all splits
+    word_splitor="", data_splitor=",", mapping_name="ACGT", # gRNA tokenize
+    has_labels=True, return_label=False, # pre-training
+    data_type="regression", target_type='total',
+    filter_condition=5, max_seq_length=seq_len,
+)
+data = dict(
+    samples_per_gpu=2, # bs64 x 8gpu x 2 accu = bs1024
+    workers_per_gpu=2,
+    train=dict(
+        data_source=dict(root=data_root+"train", **data_source_cfg)),
+)
+update_interval = 2 # bs64 x 8gpu x 2 accu = bs1024
+
+# optimizer
+optimizer = dict(
+    type='AdamW',
+    lr=1e-3,
+    weight_decay=1e-2, eps=1e-8, betas=(0.9, 0.999),
+    paramwise_options={
+        '(bn|ln|gn)(\d+)?.(weight|bias)': dict(weight_decay=0.),
+        'norm': dict(weight_decay=0.),
+        'bias': dict(weight_decay=0.),
+        'cls_token': dict(weight_decay=0.),
+        'pos_embed': dict(weight_decay=0.),
+        'mask_token': dict(weight_decay=0.),
+    })
+
+# apex
+use_fp16 = False
+fp16 = dict(type='mmcv', loss_scale=dict(mode='dynamic'))
+optimizer_config = dict(
+    grad_clip=dict(max_norm=10.0), update_interval=1)
+
+# learning policy
+lr_config = dict(
+    policy='CosineAnnealing',
+    by_epoch=False, min_lr=1e-5,
+    warmup='linear',
+    warmup_iters=5, warmup_by_epoch=True,
+    warmup_ratio=1e-5,
+)
+
+# checkpoint
+checkpoint_config = dict(interval=1, max_keep_ckpts=1)
+
+# runtime settings
+runner = dict(type='EpochBasedRunner', max_epochs=100)
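The config above sets up a BERT-style objective: 15% of positions (mask_ratio=0.15) are replaced with a learnable mask token at mask_layer=10, and an unreduced CrossEntropyLoss reconstructs the original tokens over 5 classes (matching out_channels=5 and encoder_in_channels=5). The repo's masking helper is not part of this commit; a minimal standalone sketch of random position masking at that 15% ratio, with illustrative names only, could look like:

import torch

def random_token_mask(x, mask_ratio=0.15):
    """Return a (B, L) 0/1 mask flagging roughly mask_ratio of positions per sample."""
    B, L, _ = x.shape
    num_mask = int(L * mask_ratio)
    noise = torch.rand(B, L, device=x.device)  # iid score per position
    ids = noise.argsort(dim=1)                 # random permutation per sample
    mask = torch.zeros(B, L, device=x.device)
    mask.scatter_(1, ids[:, :num_mask], 1.0)   # flag the first num_mask indices
    return mask

feats = torch.randn(2, 512, 64)
mask = random_token_mask(feats)                # 76 of 512 positions flagged per sample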
New file, lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# dataset settings
+data_root = 'data/dna/'
+data_source_cfg = dict(
+    type='DNASeqDataset',
+    file_list=None, # use all splits
+    word_splitor="", data_splitor=",", mapping_name="ACGT", # gRNA tokenize
+    has_labels=True, return_label=False, # pre-training
+    data_type="regression", target_type='total',
+    filter_condition=5, max_seq_length=512
+)
+
+dataset_type = 'ExtractDataset'
+sample_norm_cfg = dict(mean=[0,], std=[1,])
+train_pipeline = [
+    dict(type='ToTensor'),
+]
+test_pipeline = [
+    dict(type='ToTensor'),
+]
+# prefetch
+prefetch = False
+
+data = dict(
+    samples_per_gpu=256,
+    workers_per_gpu=4,
+    drop_last=True,
+    train=dict(
+        type=dataset_type,
+        data_source=dict(
+            root=data_root+"train",
+            **data_source_cfg),
+        pipeline=train_pipeline,
+        prefetch=prefetch,
+    ),
+)
+
+# checkpoint
+checkpoint_config = dict(interval=1, max_keep_ckpts=1)
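With word_splitor="" and mapping_name="ACGT", the data source presumably tokenizes each sequence character by character over the A/C/G/T alphabet. A toy sketch of that idea, assuming index 0 stays reserved for padding (consistent with padding_index=0 in the backbone configs); the +1 offset and the helper name are illustrative, not read from the dataset code:

MAPPING = "ACGT"
PAD_ID = 0  # assumed to line up with padding_index=0

def encode(seq: str, max_len: int = 512) -> list:
    ids = [MAPPING.index(ch) + 1 for ch in seq[:max_len]]  # A->1, C->2, G->3, T->4
    return ids + [PAD_ID] * (max_len - len(ids))           # right-pad to max_len

print(encode("ACGTT", max_len=8))  # [1, 2, 3, 4, 4, 0, 0, 0]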

openbioseq/datasets/data_sources/dna_seq_source.py

Lines changed: 3 additions & 1 deletion
@@ -48,6 +48,7 @@ def __init__(self,
                  data_splitor=" ",
                  mapping_name="ACGT",
                  has_labels=True,
+                 return_label=True,
                  target_type='',
                  filter_condition=0,
                  data_type="classification",
@@ -71,6 +72,7 @@ def __init__(self,

         # instance vars
         self.has_labels = len(lines[0].split(data_splitor)) >= 2 and has_labels
+        self.return_label = return_label
         self.data_type = data_type
         self.max_seq_length = max_seq_length
         self.filter_condition = filter_condition
@@ -125,7 +127,7 @@ def get_length(self):

     def get_sample(self, idx):
         seq = self.data_list[idx]
-        if self.has_labels:
+        if self.has_labels and self.return_label:
             target = self.labels[idx]
             return seq, target
         else:
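The new return_label flag lets the same labeled CSVs serve both stages: the pre-training configs above pass return_label=False, so get_sample yields the sequence alone, while the regression configs keep the default True and receive (seq, target) pairs. A toy illustration of the two modes (not the real DNASeqDataset constructor):

class ToySource:
    def __init__(self, return_label=True):
        self.data_list = ["ACGT", "GGTA"]
        self.labels = [0.7, 0.2]
        self.has_labels = True
        self.return_label = return_label

    def get_sample(self, idx):
        seq = self.data_list[idx]
        if self.has_labels and self.return_label:
            return seq, self.labels[idx]  # supervised: sequence plus target
        return seq                        # pre-training: sequence only

print(ToySource(return_label=True).get_sample(0))   # ('ACGT', 0.7)
print(ToySource(return_label=False).get_sample(0))  # ACGT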

openbioseq/models/backbones/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -4,7 +4,6 @@
 from .resnet import ResNet, ResNet_CIFAR, ResNet_Mix, ResNet_Mix_CIFAR
 from .seq_lstm import SequenceLSTM
 from .seq_transformer import SequenceTransformer
-from .seq_embed_transformer import DNATransformer
 from .timm_backbone import TIMMBackbone
 from .uniformer import UniFormer
 from .van import VAN
@@ -16,5 +15,5 @@
     'MAETransformer', 'MAEViT', 'MIMVisionTransformer', 'SimMIMTransformer', 'SimMIMViT',
     'ResNet', 'ResNet_CIFAR', 'ResNet_Mix', 'ResNet_Mix_CIFAR',
     'SequenceLSTM', 'SequenceTransformer', 'TIMMBackbone', 'TransformerEncoderLayer',
-    'UniFormer', 'VAN', 'VisionTransformer', 'WideResNet', 'DNATransformer'
+    'UniFormer', 'VAN', 'VisionTransformer', 'WideResNet',
 ]

openbioseq/models/backbones/mim_vit.py

Lines changed: 19 additions & 6 deletions
@@ -92,7 +92,13 @@ def _init_weights(self, m):
     def forward(self, x):
         """ MAE backbone only used for MAE model """
         B = x.shape[0]
-        x, _ = self.patch_embed(x)
+        if not self.with_embedding:
+            x, _ = self.patch_embed(x)
+        else:
+            if x.dtype != torch.long: # must be indice
+                x = x.type(torch.long).clamp(0, x.size(1)-1)
+            x = self.embedding_layer(x)
+
         # add pos embed w/o cls token
         x = x + self.pos_embed[:, 1:, :]
         # masking: length -> length * mask_ratio
@@ -360,16 +366,23 @@ def forward(self, x, mask=None):
         Returns:
             tuple: A tuple containing features from multi-stages.
         """
-        x, seq_len = self.patch_embed(x)
+        if not self.with_embedding:
+            x, seq_len = self.patch_embed(x)
+        else:
+            if x.dtype != torch.long: # must be indice
+                x = x.type(torch.long).clamp(0, x.size(1)-1)
+            x = self.embedding_layer(x)
+            seq_len = self.seq_len

         if self.mask_layer == 0:
             if mask is None:
                 mask = simmim_random_masking(x, self.mask_ratio)
             x = forward_simmim_masking(
                 x, self.mask_token, mask, self.mask_mode)

-        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
+        if self.with_cls_token or self.output_cls_token:
+            cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
+            x = torch.cat((cls_tokens, x), dim=1)
         x = x + self.resize_pos_embed(
             self.pos_embed,
             src_shape=self.patch_resolution,
@@ -378,7 +391,7 @@ def forward(self, x, mask=None):
             num_extra_tokens=self.num_extra_tokens)
         x = self.drop_after_pos(x)

-        if not self.with_cls_token:
+        if self.with_cls_token and not self.output_cls_token:
             # Remove class token for transformer encoder input
             x = x[:, 1:]

@@ -393,7 +406,7 @@ def forward(self, x, mask=None):
             if i == len(self.layers) - 1 and self.final_norm:
                 x = self.norm1(x)

-        if self.with_cls_token:
+        if self.output_cls_token:
             x = x[:, 1:]

         return (x, mask)
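For context, the cls-token handling that the new with_cls_token / output_cls_token flags gate follows the usual ViT pattern: one learnable token is expanded across the batch, concatenated in front of the sequence, and later sliced off again. A minimal standalone sketch of that pattern (not the repo's class):

import torch
import torch.nn as nn

B, L, C = 2, 512, 64
x = torch.randn(B, L, C)                        # token features from the stem/embedding
cls_token = nn.Parameter(torch.zeros(1, 1, C))  # single learnable token

cls_tokens = cls_token.expand(B, -1, -1)        # (B, 1, C), weights shared across the batch
x = torch.cat((cls_tokens, x), dim=1)           # (B, L + 1, C)
x_without_cls = x[:, 1:]                        # slice it back off, as the gated branches above do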
