Skip to content

Commit f678533

Browse files
committed
add demo
1 parent 99413d0 commit f678533

File tree

6 files changed

+166
-13
lines changed

6 files changed

+166
-13
lines changed

demo/grna_demo.csv

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
1,TCTGTAAAGTGGCAAGCAGGAGTCTGCTACAATGGAGGAAAGGATTTTGCTGTATCTCTTGCC
2+
2,CAGGAGGGAAACATGGTTACTGCTCGCCAGGAACCTCGCCTGGTCCTGATTTCCCTGACCTGC

demo/grna_demo.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""
2+
An example of prediction of gRNA editing efficiency.
3+
4+
Example command (a sequence of 63 nucleotides):
5+
python grna_demo.py TTGCTGTATCTCTTGCCAGGCCCAAGGCTGCAGAGGGAATTGGTAATATACTTCATTTAATAA
6+
7+
Output results:
8+
0.20432067
9+
"""
10+
11+
import argparse
12+
import torch
13+
from mmcv.runner import load_checkpoint
14+
15+
from openbioseq.datasets.data_sources.bio_seq_source import binarize
16+
from openbioseq.models import build_model
17+
from openbioseq.datasets.utils import read_file
18+
19+
20+
def parse_args(argv=None):
    """Parse command-line options for the gRNA prediction demo.

    Args:
        argv (list[str] | None): Argument list to parse. Defaults to
            ``sys.argv[1:]`` when None, so existing callers are unaffected.

    Returns:
        argparse.Namespace: Parsed options ``input_seq``, ``input_file``
            and ``debug``.
    """
    parser = argparse.ArgumentParser(
        description='Process an input gRNA sequence to predict')
    parser.add_argument('--input_seq', type=str, default=None,
                        help='input sequence')
    parser.add_argument('--input_file', type=str, default=None,
                        help='path to an input file containing several sequences')
    parser.add_argument('--debug', action='store_true', default=False,
                        help='debug mode of demo')
    args = parser.parse_args(argv)
    return args
30+
31+
32+
def get_model_config(seq_len=63, embed_dim=64, patch_size=2):
    """Build the Transformer regression config and its checkpoint URL.

    Args:
        seq_len (int): Length of the input sequence in tokens.
        embed_dim (int): Embedding dimension of the Transformer.
        patch_size (int): Number of tokens merged into one patch.

    Returns:
        tuple: ``(model_cfg, checkpoint_url)`` where ``model_cfg`` is a
            dict consumed by ``build_model`` and ``checkpoint_url`` points
            to the released pretrained weights.
    """
    checkpoint = "https://github.com/Westlake-AI/OpenBioSeq/releases/download/v0.1.1/k562_layer4_p2_h4_d64_init_bs256_ep100.pth"

    # ceil(seq_len / patch_size): sequences that do not divide evenly get
    # one extra (padded) patch.
    num_patches = -(-seq_len // patch_size)

    backbone_cfg = dict(
        type='SequenceTransformer',
        arch=dict(
            embed_dims=embed_dim,
            num_layers=4,
            num_heads=4,
            feedforward_channels=embed_dim * 4),
        in_channels=4,
        patch_size=patch_size,
        seq_len=num_patches,
        norm_cfg=dict(type='LN', eps=1e-6),
        drop_rate=0.1,
        drop_path_rate=0.1,
        init_values=0.1,
        final_norm=True,
        out_indices=-1,  # take features from the last layer
        with_cls_token=False,
        output_cls_token=False)

    head_cfg = dict(
        type='RegHead',
        loss=dict(type='RegressionLoss', mode='huber_loss',
                  loss_weight=1.0, reduction='mean',
                  activate='sigmoid', alpha=0.2, gamma=1.0, beta=1.0,
                  residual=False),
        with_avg_pool=True, in_channels=embed_dim, out_channels=1)

    model = dict(
        type='Classification',
        pretrained=None,
        backbone=backbone_cfg,
        head=head_cfg)

    return model, checkpoint
66+
67+
68+
def main():
    """Predict gRNA editing efficiency for sequences given on the CLI.

    Reads sequences from ``--input_seq``, ``--input_file`` or a built-in
    debug sample, one-hot encodes them, loads the released checkpoint and
    prints the model prediction.

    Raises:
        ValueError: If no input source is given, or a sequence fails to
            binarize (e.g. contains characters outside A/C/G/T).
    """
    args = parse_args()
    if args.debug:
        input_seq = ["TTGCTGTATCTCTTGCCAGGCCCAAGGCTGCAGAGGGAATTGGTAATATACTTCATTTAATAA"]
    elif args.input_seq is not None:
        input_seq = [args.input_seq]
    elif args.input_file is not None:
        # strip trailing newlines left by the file reader
        input_seq = [line.replace('\n', '') for line in read_file(args.input_file)]
    else:
        print(args)
        # Original used `assert False and "msg"`, which raises a bare
        # AssertionError (the message is dead code) and vanishes under -O.
        raise ValueError("Invalid input args: pass --input_seq or --input_file")

    # one-hot encode the input: 4 channels (A/C/G/T), fixed length of 63
    seq_len, key_num = 63, 4
    key_mapping = dict(A=0, C=1, G=2, T=3)
    try:
        input_seq = binarize(
            input_seq, mapping=key_mapping, max_seq_length=seq_len, data_splitor=',')
    except ValueError as err:
        raise ValueError("Please check the input sequence") from err

    # build the model and load checkpoint
    cfg_model, checkpoint = get_model_config(seq_len=seq_len)
    model = build_model(cfg_model)
    load_checkpoint(model, checkpoint, map_location='cpu')

    # inference: batch the encoded sequences as (batch, channels, length)
    if len(input_seq) == 1:
        input_seq = input_seq[0].unsqueeze(0)
    else:
        input_seq = torch.concat(input_seq).view(-1, key_num, seq_len)

    output = model(input_seq, mode='inference').detach().cpu().numpy()
    print("Prediction:", output)
105+
106+
107+
# Entry point when executed as a script.
if __name__ == '__main__':
    main()

openbioseq/datasets/data_sources/bio_seq_source.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def binarize(data_list, mapping, max_seq_length=None, data_splitor=None):
2626
token_list.append(onehot_seq)
2727
except:
2828
print(f"Error seq:", _seq)
29+
raise ValueError
2930
return token_list
3031

3132

openbioseq/models/classifiers/base_model.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,22 @@ def forward_test(self, data, **kwargs):
7474
"""
7575
pass
7676

77+
def forward_inference(self, data, **kwargs):
    """Run a plain forward pass and return the final prediction.

    Args:
        data (Tensor): Input batch. Typically these should be mean
            centered and std scaled.
        kwargs (keyword arguments): Specific to concrete implementation.

    Returns:
        Tensor: The first head output, i.e. the final model prediction.
    """
    features = self.backbone(data)
    if self.with_neck:
        features = self.neck(features)
    outputs = self.head(features)
    return outputs[0]
92+
7793
def forward_vis(self, data, **kwargs):
7894
"""Forward backbone features for visualization.
7995
@@ -122,6 +138,8 @@ def forward(self, data, mode='train', **kwargs):
122138
return self.forward_train(data, **kwargs)
123139
elif mode == 'test':
124140
return self.forward_test(data, **kwargs)
141+
elif mode == 'inference':
142+
return self.forward_inference(data, **kwargs)
125143
elif mode == 'calibration':
126144
return self.forward_calibration(data, **kwargs)
127145
elif mode == 'extract':

tools/model_converters/publish_model.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,51 @@
1+
"""
2+
Extract parameters and publish the model.
3+
4+
Example command:
5+
python tools/model_converters/publish_model.py [PATH/to/checkpoint] [PATH/to/output]
6+
"""
17
import argparse
28
import subprocess
39

10+
import torch
11+
412

513
def parse_args(argv=None):
    """Parse command-line options for checkpoint publishing.

    Args:
        argv (list[str] | None): Argument list to parse. Defaults to
            ``sys.argv[1:]`` when None, so existing callers are unaffected.

    Returns:
        argparse.Namespace: Parsed options ``in_file``, ``out_file``
            and ``decode``.
    """
    parser = argparse.ArgumentParser(
        description='Process a checkpoint to be published')
    parser.add_argument('in_file', help='input checkpoint filename')
    parser.add_argument('out_file', help='output checkpoint filename')
    parser.add_argument('--decode', action='store_true', default=False,
                        help='whether to add sha256sum in the output name')
    args = parser.parse_args(argv)
    return args
1122

1223

13-
def process_checkpoint(in_file):
14-
tmp_file = in_file + ".tmp"
15-
subprocess.Popen(['cp', in_file, tmp_file])
16-
sha = subprocess.check_output(['sha256sum', tmp_file]).decode()
17-
out_file = in_file
18-
if out_file.endswith('.pth'):
19-
out_file = out_file[:-4]
20-
final_file = out_file + f'-{sha[:8]}.pth'
21-
assert final_file != in_file, \
22-
"The output filename is the same as the input file."
23-
print("Output file: {}".format(final_file))
24-
subprocess.Popen(['mv', tmp_file, final_file])
24+
def process_checkpoint(in_file, out_file, decode=False):
    """Strip training-only state from a checkpoint and save a publish copy.

    Args:
        in_file (str): Path to the input checkpoint.
        out_file (str): Path to write the published checkpoint.
        decode (bool): If True, rename the output to append the first 8
            characters of its sha256sum (requires the ``sha256sum`` and
            ``mv`` shell utilities).
    """
    checkpoint = torch.load(in_file, map_location='cpu')
    # remove optimizer for smaller file size
    if 'optimizer' in checkpoint:
        del checkpoint['optimizer']
    # if it is necessary to remove some sensitive data in checkpoint['meta'],
    # add the code here.

    # Compare versions numerically: the original string comparison
    # (torch.__version__ >= '1.6') is lexicographic, so '1.10' < '1.6'
    # and the branch was wrong on torch >= 1.10.
    version = tuple(int(p) for p in torch.__version__.split('+')[0].split('.')[:2])
    if version >= (1, 6):
        # keep the legacy (non-zipfile) format for wider compatibility
        torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False)
    else:
        torch.save(checkpoint, out_file)

    if decode:
        sha = subprocess.check_output(['sha256sum', out_file]).decode()
        if out_file.endswith('.pth'):
            out_file_name = out_file[:-4]
        else:
            out_file_name = out_file
        final_file = out_file_name + f'-{sha[:8]}.pth'
        # run synchronously: Popen returned before the rename finished,
        # so callers could observe a missing/renamed file mid-flight
        subprocess.run(['mv', out_file, final_file], check=True)
2544

2645

2746
def main():
    """CLI entry point: parse options and publish the checkpoint."""
    options = parse_args()
    process_checkpoint(options.in_file, options.out_file, options.decode)
3049

3150

3251
if __name__ == '__main__':

tools/train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ def parse_args():
2828
help='the dir to save logs and models')
2929
parser.add_argument(
3030
'--resume_from', help='the checkpoint file to resume from')
31+
parser.add_argument(
32+
'--auto_resume',
33+
action='store_true',
34+
help='resume from the latest checkpoint automatically')
3135
parser.add_argument(
3236
'--pretrained', default=None, help='pretrained model file')
3337
parser.add_argument(
@@ -89,6 +93,7 @@ def main():
8993
osp.splitext(osp.basename(args.config))[0])
9094
if args.resume_from is not None:
9195
cfg.resume_from = args.resume_from
96+
cfg.auto_resume = args.auto_resume
9297
cfg.gpus = args.gpus
9398

9499
# check memcached package exists

0 commit comments

Comments
 (0)