1
1
"""Training and testing functionality for the de novo peptide sequencing model"""
2
- import logging
2
+ import logging , importlib
3
3
from pathlib import Path
4
+ import pytorch_lightning as pl
4
5
from depthcharge .data import AnnotatedSpectrumIndex
5
6
from casanovo .denovo import DeNovoDataModule , Spec2Pep
6
- from casanovo import config
7
- import pytorch_lightning as pl
8
7
9
- def train ():
10
- """Train a Casanovo model with options specified in config.py."""
8
+ def train (train_data_path , val_data_path , model_path , config_path ):
9
+ """Train a Casanovo model with options specified in config.py."""
11
10
11
+ #Use custom config file if specified
12
+ if config_path == None :
13
+ from casanovo import config
14
+ else :
15
+ importlib .machinery .SourceFileLoader ('config' , config_path ).load_module ()
16
+ import config
17
+
12
18
#Set random seed across PyTorch, numpy and python.random
13
19
pl .utilities .seed .seed_everything (seed = config .random_seed , workers = True )
14
20
15
21
#Index training and validation data
16
- train_mgf_file = Path (config . train_data_path )
17
- val_mgf_file = Path (config . val_data_path )
22
+ train_mgf_file = Path (train_data_path )
23
+ val_mgf_file = Path (val_data_path )
18
24
19
25
train_index = AnnotatedSpectrumIndex (config .train_annot_spec_idx_path , train_mgf_file , overwrite = config .train_spec_idx_overwrite )
20
26
val_index = AnnotatedSpectrumIndex (config .val_annot_spec_idx_path , val_mgf_file , overwrite = config .val_spec_idx_overwrite )
@@ -64,7 +70,7 @@ def train():
64
70
65
71
else :
66
72
model = Spec2Pep ().load_from_checkpoint (
67
- config . model_full_path ,
73
+ model_path ,
68
74
dim_model = config .dim_model ,
69
75
n_head = config .n_head ,
70
76
dim_feedforward = config .dim_feedforward ,
@@ -97,7 +103,7 @@ def train():
97
103
)
98
104
99
105
trainer = pl .Trainer (
100
- strategy = config .accelerator ,
106
+ accelerator = config .accelerator ,
101
107
logger = config .logger ,
102
108
gpus = config .gpus ,
103
109
max_epochs = config .max_epochs ,
@@ -108,7 +114,7 @@ def train():
108
114
else :
109
115
110
116
trainer = pl .Trainer (
111
- strategy = config .accelerator ,
117
+ accelerator = config .accelerator ,
112
118
logger = config .logger ,
113
119
gpus = config .gpus ,
114
120
max_epochs = config .max_epochs ,
@@ -118,12 +124,19 @@ def train():
118
124
#Train the model
119
125
trainer .fit (model , train_loader .train_dataloader (), val_loader .val_dataloader ())
120
126
121
- def test_denovo ():
127
+ def test_denovo (test_data_path , model_path , config_path ):
122
128
"""Test a pre-trained Casanovo model with options specified in config.py."""
123
129
130
+ #Use custom config file if specified
131
+ if config_path == None :
132
+ from casanovo import config
133
+ else :
134
+ importlib .machinery .SourceFileLoader ('config' , config_path ).load_module ()
135
+ import config
136
+
124
137
# Initialize the pre-trained model
125
138
model_trained = Spec2Pep ().load_from_checkpoint (
126
- config . model_full_path ,
139
+ model_path ,
127
140
dim_model = config .dim_model ,
128
141
n_head = config .n_head ,
129
142
dim_feedforward = config .dim_feedforward ,
@@ -137,7 +150,7 @@ def test_denovo():
137
150
n_log = config .n_log ,
138
151
)
139
152
#Index test data
140
- mgf_file = Path (config . test_data_path )
153
+ mgf_file = Path (test_data_path )
141
154
index = AnnotatedSpectrumIndex (config .test_annot_spec_idx_path , mgf_file , overwrite = config .test_spec_idx_overwrite )
142
155
143
156
#Initialize the data loader
@@ -154,7 +167,7 @@ def test_denovo():
154
167
155
168
#Create Trainer object
156
169
trainer = pl .Trainer (
157
- strategy = config .accelerator ,
170
+ accelerator = config .accelerator ,
158
171
logger = config .logger ,
159
172
gpus = config .gpus ,
160
173
max_epochs = config .max_epochs ,
@@ -163,5 +176,3 @@ def test_denovo():
163
176
164
177
#Run test
165
178
trainer .validate (model_trained , loaders .test_dataloader ())
166
-
167
-
0 commit comments