1
1
"""A de novo peptide sequencing model"""
2
- import logging , time , random , os
2
+ import logging , time , random , os , csv
3
3
4
4
import torch
5
5
import numpy as np
@@ -402,7 +402,7 @@ def test_step(self, batch, *args):
402
402
#De novo sequence the batch
403
403
pred_seqs , scores = self .predict_step (batch )
404
404
spectrum_order_id = batch [- 1 ]
405
- self .denovo_seqs += [(spectrum_order_id , pred_seqs )]
405
+ self .denovo_seqs += [(spectrum_order_id , pred_seqs , scores )]
406
406
407
407
408
408
def on_train_epoch_end (self ):
@@ -428,17 +428,39 @@ def on_validation_epoch_end(self):
428
428
self ._history .append (metrics )
429
429
430
430
def on_test_epoch_end (self ):
431
- """Write de novo sequences to csv file.
431
+ """Write de novo sequences and confidence scores to csv file.
432
432
433
433
This is a pytorch-lightning hook.
434
434
"""
435
- with open (os .path .join (self .output_path ,'casanovo_output.csv' ), 'w' ) as f :
436
- f .write (f'spectrum_id,denovo_seq\n ' )
435
+ with open (os .path .join (self .output_path ,'casanovo_output.csv' ), 'w' ) as f :
436
+ writer = csv .writer (f )
437
+ writer .writerow (['spectrum_id' ,'denovo_seq' ,'peptide_score' ,'aa_scores' ])
438
+
437
439
for batch in self .denovo_seqs :
440
+ scores = batch [2 ].cpu () #transfer to cpu in case in gpu
441
+
438
442
for i in range (len (batch [0 ])):
439
- f .write (f'{ batch [0 ][i ]} ,{ batch [1 ][i ][1 :]} \n ' )
443
+ top_scores = torch .max (scores [i ],axis = 1 )[0 ] #take the score of most probable AA
444
+ empty_index = torch .where (top_scores == 0.04 )[0 ] #find the indices of positions after stop token
445
+
446
+ if len (empty_index )> 0 :#check if decoding was stopped
447
+ last_index = empty_index [0 ]- 1 #select index of the last AA
448
+
449
+ if last_index >= 1 : #check if peptide is at least one AA long
450
+ top_scores_list = top_scores [:last_index ].tolist () #omit the stop token
451
+ peptide_score = np .mean (top_scores_list )
452
+ aa_scores = list (reversed (top_scores_list ))
453
+
454
+ else :
455
+ peptide_score = None
456
+ aa_scores = None
457
+
458
+ else :
459
+ peptide_score = None
460
+ aa_scores = None
461
+
462
+ writer .writerow ([batch [0 ][i ],batch [1 ][i ][1 :],peptide_score ,aa_scores ])
440
463
441
-
442
464
def on_epoch_end (self ):
443
465
"""Print log to console, if requested."""
444
466
0 commit comments