Commit 58ac558

fix dna

1 parent 2bc0a52 commit 58ac558

6 files changed (+153, -50 lines changed)

New file

Lines changed: 93 additions & 0 deletions

@@ -0,0 +1,93 @@
+_base_ = [
+    '../../../_base_/datasets/DNA/dna.py',
+    '../../../_base_/default_runtime.py',
+]
+
+embed_dim = 64
+seq_len = 512
+
+# model settings
+model = dict(
+    type='Classification',
+    pretrained=None,
+    backbone=dict(
+        type='DNATransformer',
+        arch={'embed_dims': embed_dim,
+              'num_layers': 12,
+              'num_heads': embed_dim // 16,
+              'feedforward_channels': embed_dim * 4},
+        in_channels=4,
+        seq_len=seq_len,
+        norm_cfg=dict(type='LN', eps=1e-6),
+        drop_rate=0.1,
+        drop_path_rate=0.1,
+        init_values=0.1,
+        final_norm=True,
+        out_indices=-1,  # last layer
+        with_cls_token=False,
+        output_cls_token=False,
+    ),
+    head=dict(
+        type='RegHead',
+        loss=dict(type='RegressionLoss', mode='huber_loss',
+                  loss_weight=1.0, reduction='mean',
+                  activate='sigmoid', alpha=0.2, gamma=1.0, beta=1.0, residual=False),
+        with_avg_pool=True, in_channels=embed_dim, out_channels=1),
+)
+
+# dataset settings
+data_root = 'data/dna/'
+data_source_cfg = dict(
+    type='DNASeqDataset',
+    file_list=None,  # use all splits
+    word_splitor="", data_splitor=",", mapping_name="ACGT",  # gRNA tokenize
+    data_type="regression", target_type='total',
+    filter_condition=20, max_seq_length=512,
+)
+
+data = dict(
+    samples_per_gpu=64,  # bs64 x 8gpu x 2 accu = bs1024
+    workers_per_gpu=4,
+    train=dict(
+        data_source=dict(
+            root=data_root+"train", **data_source_cfg)),
+    val=dict(
+        data_source=dict(
+            root=data_root+"test", **data_source_cfg)),
+)
+update_interval = 2
+
+# optimizer
+optimizer = dict(
+    type='AdamW',
+    lr=5e-3,
+    weight_decay=1e-2, eps=1e-8, betas=(0.9, 0.999),
+    paramwise_options={
+        '(bn|ln|gn)(\d+)?.(weight|bias)': dict(weight_decay=0.),
+        'norm': dict(weight_decay=0.),
+        'bias': dict(weight_decay=0.),
+        'pos_embed': dict(weight_decay=0.),
+        'gamma': dict(weight_decay=0.),
+        'noise_sigma': dict(weight_decay=0., lr_mult=1e-1),
+    })
+
+# apex
+use_fp16 = False
+fp16 = dict(type='mmcv', loss_scale='dynamic')
+optimizer_config = dict(
+    grad_clip=dict(max_norm=5.0), update_interval=update_interval)
+
+# learning policy
+lr_config = dict(
+    policy='CosineAnnealing',
+    by_epoch=False, min_lr=1e-5,
+    warmup='linear',
+    warmup_iters=1, warmup_by_epoch=True,
+    warmup_ratio=1e-5,
+)
+
+# checkpoint
+checkpoint_config = dict(interval=1, max_keep_ckpts=1)
+
+# runtime settings
+runner = dict(type='EpochBasedRunner', max_epochs=50)
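
As a quick sanity check (an illustrative sketch, not part of the commit; assumes mmcv is installed), these files are plain-Python mmcv configs and can be loaded and inspected with Config.fromfile. The path below is the renamed config from this commit:

    from mmcv import Config

    # Load the config and print the fields this commit touches.
    cfg = Config.fromfile(
        'configs/regression/DNA/transformer/deit/deit_t_dim64_l512_f5_bs1024_ep100.py')
    print(cfg.model.backbone.arch)  # {'embed_dims': 64, 'num_layers': 12, ...}
    print(cfg.optimizer.lr)         # 5e-3
    print(cfg.runner.max_epochs)    # 50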

configs/regression/DNA/transformer/layer4/layer4_p2_h4_d64_init_bs256_ep100.py renamed to configs/regression/DNA/transformer/deit/deit_t_dim64_l512_f5_bs1024_ep100.py

Lines changed: 36 additions & 11 deletions

@@ -2,17 +2,20 @@
     '../../../_base_/datasets/DNA/dna.py',
     '../../../_base_/default_runtime.py',
 ]
+
 embed_dim = 64
-patch_size = 2
-seq_len = 1024
+seq_len = 512
 
 # model settings
 model = dict(
     type='Classification',
     pretrained=None,
     backbone=dict(
         type='DNATransformer',
-        arch='deit-s',
+        arch={'embed_dims': embed_dim,
+              'num_layers': 12,
+              'num_heads': embed_dim // 16,
+              'feedforward_channels': embed_dim * 4},
         in_channels=4,
         seq_len=seq_len,
         norm_cfg=dict(type='LN', eps=1e-6),
@@ -32,37 +35,59 @@
         with_avg_pool=True, in_channels=embed_dim, out_channels=1),
 )
 
+# dataset settings
+data_root = 'data/dna/'
+data_source_cfg = dict(
+    type='DNASeqDataset',
+    file_list=None,  # use all splits
+    word_splitor="", data_splitor=",", mapping_name="ACGT",  # gRNA tokenize
+    data_type="regression", target_type='total',
+    filter_condition=5, max_seq_length=512,
+)
+
+data = dict(
+    samples_per_gpu=64,  # bs64 x 8gpu x 2 accu = bs1024
+    workers_per_gpu=4,
+    train=dict(
+        data_source=dict(
+            root=data_root+"train", **data_source_cfg)),
+    val=dict(
+        data_source=dict(
+            root=data_root+"test", **data_source_cfg)),
+)
+update_interval = 2
+
 # optimizer
 optimizer = dict(
     type='AdamW',
-    lr=3e-3,
-    weight_decay=5e-2, eps=1e-8, betas=(0.9, 0.999),
+    lr=5e-3,
+    weight_decay=1e-2, eps=1e-8, betas=(0.9, 0.999),
     paramwise_options={
         '(bn|ln|gn)(\d+)?.(weight|bias)': dict(weight_decay=0.),
         'norm': dict(weight_decay=0.),
         'bias': dict(weight_decay=0.),
         'pos_embed': dict(weight_decay=0.),
         'gamma': dict(weight_decay=0.),
-        # 'noise_sigma': dict(weight_decay=0., lr_mult=1e-1),
+        'noise_sigma': dict(weight_decay=0., lr_mult=1e-1),
     })
 
 # apex
 use_fp16 = False
-fp16 = dict(type='apex', loss_scale=dict(mode='dynamic'))
+fp16 = dict(type='mmcv', loss_scale='dynamic')
 optimizer_config = dict(
-    grad_clip=dict(max_norm=5.0), update_interval=1)
+    grad_clip=dict(max_norm=5.0), update_interval=update_interval)
 
 # learning policy
 lr_config = dict(
     policy='CosineAnnealing',
     by_epoch=False, min_lr=1e-5,
     warmup='linear',
-    warmup_iters=5, warmup_by_epoch=True,
+    warmup_iters=1, warmup_by_epoch=True,
     warmup_ratio=1e-5,
 )
 
 # checkpoint
-checkpoint_config = dict(interval=100, max_keep_ckpts=1)
+checkpoint_config = dict(interval=1, max_keep_ckpts=1)
 
 # runtime settings
-runner = dict(type='EpochBasedRunner', max_epochs=100)
+runner = dict(type='EpochBasedRunner', max_epochs=50)
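
A note on the batch-size arithmetic behind the samples_per_gpu comment (an illustrative sketch, not part of the commit): with mmcv-style optimizer hooks, update_interval=2 accumulates gradients for two iterations per optimizer step, so the effective batch works out as:

    # Effective batch size under gradient accumulation (values from this config;
    # the 8-GPU count is assumed from the "bs64 x 8gpu x 2 accu" comment).
    samples_per_gpu = 64
    num_gpus = 8
    update_interval = 2
    effective_batch = samples_per_gpu * num_gpus * update_interval
    print(effective_batch)  # 1024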

configs/regression/_base_/datasets/DNA/dna.py

Lines changed: 4 additions & 4 deletions

@@ -5,7 +5,7 @@
     file_list=None,  # use all splits
     word_splitor="", data_splitor=",", mapping_name="ACGT",  # gRNA tokenize
     data_type="regression", target_type='total',
-    filter_condition=0
+    filter_condition=5, max_seq_length=512
 )
 
 dataset_type = 'RegressionDataset'
@@ -41,8 +41,8 @@
 
 # validation hook
 evaluation = dict(
-    initial=False,
-    interval=5,
+    initial=True,
+    interval=1,
     samples_per_gpu=100,
     workers_per_gpu=2,
     eval_param=dict(
@@ -52,4 +52,4 @@
 )
 
 # checkpoint
-checkpoint_config = dict(interval=200, max_keep_ckpts=1)
+checkpoint_config = dict(interval=1, max_keep_ckpts=1)

openbioseq/datasets/data_sources/dna_seq_source.py

Lines changed: 5 additions & 2 deletions

@@ -7,7 +7,6 @@
 from ..utils import read_file
 
 
-
 @DATASOURCES.register_module
 class DNASeqDataset(object):
     """The implementation for loading any bio seqences.
@@ -95,7 +94,11 @@ def __init__(self,
             # data = [mapping[tok] for tok in l[self.col_names.index('seq')]] + [0] * padding
             data_list = list(map(mapping.get, l[self.col_names.index('seq')]))
             padding = self.max_seq_length - len(data_list)
-            data = data_list + [0] * padding
+            if padding < 0:
+                data = data_list[:self.max_seq_length]
+            else:
+                data = data_list + [0] * padding
+
             label = l[self.col_names.index(self.target_type)]
 
             if self.data_type == "classification":
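
Before this fix, an over-long sequence made padding negative, so [0] * padding was an empty list and the sequence passed through unpadded and longer than max_seq_length. A standalone sketch of the new pad-or-truncate behavior (hypothetical helper name, not the dataset API):

    def pad_or_truncate(tokens, max_len, pad_id=0):
        """Right-pad short sequences; truncate over-long ones."""
        if len(tokens) > max_len:
            return tokens[:max_len]
        return tokens + [pad_id] * (max_len - len(tokens))

    print(pad_or_truncate([1, 2, 3], 5))           # [1, 2, 3, 0, 0]
    print(pad_or_truncate([1, 2, 3, 4, 5, 6], 5))  # [1, 2, 3, 4, 5]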

openbioseq/models/backbones/seq_embed_transformer.py

Lines changed: 5 additions & 30 deletions

@@ -1,7 +1,5 @@
 import math
 from typing import Sequence
-from functools import reduce
-from operator import mul
 
 import numpy as np
 import torch
@@ -13,8 +11,7 @@
 from mmcv.utils.parrots_wrapper import _BatchNorm
 
 from openbioseq.utils import get_root_logger, print_log
-from ..utils import resize_pos_embed, PatchEmbed1d, ConvPatchEmbed1d, \
-    build_1d_sincos_position_embedding
+from ..utils import resize_pos_embed, build_1d_sincos_position_embedding
 from ..builder import BACKBONES
 from .base_backbone import BaseBackbone
 from .vision_transformer import TransformerEncoderLayer
@@ -203,26 +200,6 @@ def __init__(self,
             padding_idx=padding_index
         )
         self.embedding_layer = nn.Embedding(**_seq_cfg)
-        # _patch_cfg = dict(
-        #     in_channels=in_channels,
-        #     input_size=seq_len,
-        #     embed_dims=self.embed_dims,
-        #     conv_type='Conv1d',
-        #     kernel_size=patch_size,
-        #     stride=patch_size if patchfied else patch_size // 2,
-        # )
-        # if stem_layer <= 1:
-        #     _patch_cfg.update(patch_cfg)
-        #     self.patch_embed = PatchEmbed1d(**_patch_cfg)
-        # else:
-        #     _patch_cfg.update(dict(
-        #         num_layers=stem_layer,
-        #         act_cfg=act_cfg,
-        #     ))
-        #     _patch_cfg.update(patch_cfg)
-        #     self.patch_embed = ConvPatchEmbed1d(**_patch_cfg)
-        # self.patch_resolution = self.patch_embed.init_out_size
-        # self.num_patches = self.patch_embed.init_out_size
 
         # Set cls token
         if output_cls_token:
@@ -328,11 +305,7 @@ def init_weights(self, pretrained=None):
                     cls_token=True)
                 self.pos_embed.data.copy_(pos_emb)
                 self.pos_embed.requires_grad = False
-            # xavier_uniform initialization for PatchEmbed1d
-            # if isinstance(self.patch_embed, PatchEmbed1d):
-            #     val = math.sqrt(
-            #         6. / float(3 * reduce(mul, self.patch_size, 1) + self.embed_dims))
-            #     uniform_init(self.patch_embed.projection, -val, val, bias=0)
+
         # initialization for linear layers
         for name, m in self.named_modules():
             if isinstance(m, nn.Linear):
@@ -373,8 +346,10 @@ def resize_pos_embed(*args, **kwargs):
 
     def forward(self, x):
         B = x.shape[0]
+        if x.dtype != torch.long:  # must be indice
+            x = x.type(torch.long).clamp(0, x.size(1)-1)
         x = self.embedding_layer(x)
-
+
         if self.cls_token is not None:
             cls_tokens = self.cls_token.expand(B, -1, -1)
             x = torch.cat((cls_tokens, x), dim=1)
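
The new guard in forward() casts the input to integer indices before the embedding lookup, since nn.Embedding only accepts long tensors. A standalone illustration (note the sketch clamps to the vocabulary size, while the committed code clamps to x.size(1)-1, the sequence length):

    import torch
    import torch.nn as nn

    embed = nn.Embedding(num_embeddings=5, embedding_dim=8, padding_idx=0)

    x = torch.tensor([[1.0, 4.0, 0.0]])  # float tokens, e.g. from a generic loader
    if x.dtype != torch.long:            # embedding lookup requires integer indices
        x = x.long().clamp(0, embed.num_embeddings - 1)
    print(embed(x).shape)                # torch.Size([1, 3, 8])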

tools/analysis_tools/get_flops.py

Lines changed: 10 additions & 3 deletions

@@ -1,4 +1,9 @@
-# Copyright (c) OpenMMLab. All rights reserved.
+"""
+An example to count Params and FLOPs.
+
+Example command:
+   python tools/get_flops.py [PATH_TO_config] --channel 4 --shape 512
+"""
 import argparse
 
 from mmcv import Config
@@ -13,13 +18,13 @@ def parse_args():
     parser.add_argument(
         '--channel',
         type=int,
-        default=3,
+        default=4,
         help='input data channel')
     parser.add_argument(
         '--shape',
         type=int,
         nargs='+',
-        default=[224, 224],
+        default=[512],
         help='input data size')
     args = parser.parse_args()
     return args
@@ -36,6 +41,8 @@ def main():
         input_shape = (in_channel, ) + tuple(args.shape)
     else:
         raise ValueError('invalid input shape')
+    if args.channel == 0:  # using nn.Embedding in the model
+        input_shape = input_shape[1:]
 
     cfg = Config.fromfile(args.config)
     model = build_model(cfg.model)
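
An illustration of the new shape handling (a sketch, not part of the commit): with --channel 0 the dummy input drops its channel dimension, matching models that consume token indices through nn.Embedding:

    # Mirror of the branch added above, for both channel settings.
    for in_channel, shape in [(4, [512]), (0, [512])]:
        input_shape = (in_channel, ) + tuple(shape)
        if in_channel == 0:  # using nn.Embedding in the model
            input_shape = input_shape[1:]
        print(in_channel, '->', input_shape)  # 4 -> (4, 512); 0 -> (512,)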
