Skip to content

Commit cc8afa3

Browse files
committed
Add more CLI options and allow specifying a custom config file
1 parent 3fa0556 commit cc8afa3

File tree

4 files changed: 54 additions and 33 deletions

casanovo/casanovo.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,42 @@
11
"""The command line entry point for casanovo"""
2-
import sys, logging
2+
import click, logging
33
from casanovo.denovo import train, test_denovo
44

55

6-
def main():
6+
@click.command()
7+
@click.option("--mode", default='test', help="Choose to train a model or test denovo predictions")
8+
@click.option("--model_path", help="Specify path to pre-trained model weights for testing or to continue to train")
9+
@click.option("--train_data_path", help="Specify path to mgf files to be used as training data")
10+
@click.option("--val_data_path", help="Specify path to mgf files to be used as validation data")
11+
@click.option("--test_data_path", help="Specify path to mgf files to be used as test data")
12+
@click.option("--config_path", help="Specify path to config file which includes data and model related options")
13+
14+
def main(
15+
mode,
16+
model_path,
17+
train_data_path,
18+
val_data_path,
19+
test_data_path,
20+
config_path
21+
):
722
"""The command line function"""
823
logging.basicConfig(
924
level=logging.INFO,
1025
format="%(levelname)s: %(message)s",
1126
)
12-
if sys.argv[1:][0] == 'train':
27+
if mode == 'train':
1328

1429
logging.info('Training Casanovo...')
15-
train()
30+
train(train_data_path, val_data_path, model_path, config_path)
1631

17-
elif sys.argv[1:][0] == 'test':
32+
elif mode == 'test':
1833

1934
logging.info('Testing Casanovo...')
20-
test_denovo()
35+
test_denovo(test_data_path, model_path, config_path)
2136

2237
pass
2338

2439

2540
if __name__ == "__main__":
2641
main()
42+

casanovo/config.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,15 @@
77
random_seed = 454
88

99
#Train data options
10-
train_data_path = ''
11-
train_annot_spec_idx_path = ''
10+
train_annot_spec_idx_path = os.path.join(os.getcwd(),'casanovo_train.hdf5') #path to write the training data index file
1211
train_spec_idx_overwrite = True
1312

1413
#Validation data options
15-
val_data_path = ''
16-
val_annot_spec_idx_path = ''
14+
val_annot_spec_idx_path = os.path.join(os.getcwd(),'casanovo_val.hdf5') #path to write the validation data index file
1715
val_spec_idx_overwrite = True
1816

1917
#Test data options
20-
test_data_path = ''
21-
test_annot_spec_idx_path = ''
18+
test_annot_spec_idx_path = os.path.join(os.getcwd(),'casanovo_test.hdf5') #path to write the test data index file
2219
test_spec_idx_overwrite = True
2320

2421
#Preprocessing parameters
@@ -28,7 +25,7 @@
2825

2926
#Hardware options
3027
num_workers = 0
31-
gpus = [0] #None for CPU, int list to specify GPUs
28+
gpus = None #None for CPU, int list to specify GPUs
3229

3330
#Model options
3431
max_charge = 10
@@ -86,12 +83,8 @@
8683
num_sanity_val_steps = 0
8784

8885
train_from_scratch = True
89-
model_full_path = ''
9086

9187
save_model = False
9288
model_save_folder_path = ''
9389
save_weights_only = True
9490
every_n_epochs = 1
95-
96-
97-

casanovo/denovo/train_test.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
"""Training and testing functionality for the de novo peptide sequencing model"""
2-
import logging
2+
import logging, importlib
33
from pathlib import Path
4+
import pytorch_lightning as pl
45
from depthcharge.data import AnnotatedSpectrumIndex
56
from casanovo.denovo import DeNovoDataModule, Spec2Pep
6-
from casanovo import config
7-
import pytorch_lightning as pl
87

9-
def train():
10-
"""Train a Casanovo model with options specified in config.py."""
8+
def train(train_data_path, val_data_path, model_path, config_path):
9+
"""Train a Casanovo model with options specified in config.py."""
1110

11+
#Use custom config file if specified
12+
if config_path == None:
13+
from casanovo import config
14+
else:
15+
importlib.machinery.SourceFileLoader('config', config_path).load_module()
16+
import config
17+
1218
#Set random seed across PyTorch, numpy and python.random
1319
pl.utilities.seed.seed_everything(seed=config.random_seed, workers=True)
1420

1521
#Index training and validation data
16-
train_mgf_file = Path(config.train_data_path)
17-
val_mgf_file = Path(config.val_data_path)
22+
train_mgf_file = Path(train_data_path)
23+
val_mgf_file = Path(val_data_path)
1824

1925
train_index = AnnotatedSpectrumIndex(config.train_annot_spec_idx_path, train_mgf_file, overwrite=config.train_spec_idx_overwrite)
2026
val_index = AnnotatedSpectrumIndex(config.val_annot_spec_idx_path, val_mgf_file, overwrite=config.val_spec_idx_overwrite)
@@ -64,7 +70,7 @@ def train():
6470

6571
else:
6672
model = Spec2Pep().load_from_checkpoint(
67-
config.model_full_path,
73+
model_path,
6874
dim_model=config.dim_model,
6975
n_head=config.n_head,
7076
dim_feedforward=config.dim_feedforward,
@@ -97,7 +103,7 @@ def train():
97103
)
98104

99105
trainer = pl.Trainer(
100-
strategy=config.accelerator,
106+
accelerator=config.accelerator,
101107
logger=config.logger,
102108
gpus=config.gpus,
103109
max_epochs=config.max_epochs,
@@ -108,7 +114,7 @@ def train():
108114
else:
109115

110116
trainer = pl.Trainer(
111-
strategy=config.accelerator,
117+
accelerator=config.accelerator,
112118
logger=config.logger,
113119
gpus=config.gpus,
114120
max_epochs=config.max_epochs,
@@ -118,12 +124,19 @@ def train():
118124
#Train the model
119125
trainer.fit(model, train_loader.train_dataloader(), val_loader.val_dataloader())
120126

121-
def test_denovo():
127+
def test_denovo(test_data_path, model_path, config_path):
122128
"""Test a pre-trained Casanovo model with options specified in config.py."""
123129

130+
#Use custom config file if specified
131+
if config_path == None:
132+
from casanovo import config
133+
else:
134+
importlib.machinery.SourceFileLoader('config', config_path).load_module()
135+
import config
136+
124137
# Initialize the pre-trained model
125138
model_trained = Spec2Pep().load_from_checkpoint(
126-
config.model_full_path,
139+
model_path,
127140
dim_model=config.dim_model,
128141
n_head=config.n_head,
129142
dim_feedforward=config.dim_feedforward,
@@ -137,7 +150,7 @@ def test_denovo():
137150
n_log=config.n_log,
138151
)
139152
#Index test data
140-
mgf_file = Path(config.test_data_path)
153+
mgf_file = Path(test_data_path)
141154
index = AnnotatedSpectrumIndex(config.test_annot_spec_idx_path, mgf_file, overwrite=config.test_spec_idx_overwrite)
142155

143156
#Initialize the data loader
@@ -154,7 +167,7 @@ def test_denovo():
154167

155168
#Create Trainer object
156169
trainer = pl.Trainer(
157-
strategy=config.accelerator,
170+
accelerator=config.accelerator,
158171
logger=config.logger,
159172
gpus=config.gpus,
160173
max_epochs=config.max_epochs,
@@ -163,5 +176,3 @@ def test_denovo():
163176

164177
#Run test
165178
trainer.validate(model_trained, loaders.test_dataloader())
166-
167-

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ install_requires =
2828
torch
2929
pytorch-lightning
3030
spectrum_utils
31+
click
3132

3233
[options.extras_require]
3334
docs =

0 commit comments

Comments
 (0)