Skip to content

Commit 3241eab

Browse files
committed
Add peptide and amino acid confidence scores to output file
1 parent 363f0f4 commit 3241eab

File tree

2 files changed

+2078
-7
lines changed

2 files changed

+2078
-7
lines changed

casanovo/denovo/model.py

Lines changed: 29 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
"""A de novo peptide sequencing model"""
2-
import logging, time, random, os
2+
import logging, time, random, os, csv
33

44
import torch
55
import numpy as np
@@ -402,7 +402,7 @@ def test_step(self, batch, *args):
402402
#De novo sequence the batch
403403
pred_seqs, scores = self.predict_step(batch)
404404
spectrum_order_id = batch[-1]
405-
self.denovo_seqs += [(spectrum_order_id, pred_seqs)]
405+
self.denovo_seqs += [(spectrum_order_id, pred_seqs,scores)]
406406

407407

408408
def on_train_epoch_end(self):
@@ -428,17 +428,39 @@ def on_validation_epoch_end(self):
428428
self._history.append(metrics)
429429

430430
def on_test_epoch_end(self):
    """Write de novo sequences and confidence scores to csv file.

    This is a pytorch-lightning hook.

    One row per spectrum stored in ``self.denovo_seqs`` is written to
    ``casanovo_output.csv`` inside ``self.output_path``, with the columns
    ``spectrum_id``, ``denovo_seq``, ``peptide_score`` and ``aa_scores``.

    ``peptide_score`` is the mean of the per-position top scores that
    precede the stop token, and ``aa_scores`` holds those same scores in
    reversed order. Both are left empty when no padded position is found
    or when nothing precedes the stop token.
    """
    # Top score that marks a position after the stop token.
    # NOTE(review): presumably the uniform probability of a position the
    # decoder never filled -- confirm against predict_step.
    padding_score = 0.04
    output_file = os.path.join(self.output_path, 'casanovo_output.csv')

    # newline='' is required by the csv module; without it, platforms that
    # translate '\n' end up with an extra blank line after every row.
    with open(output_file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(
            ['spectrum_id', 'denovo_seq', 'peptide_score', 'aa_scores']
        )

        for spectrum_ids, pred_seqs, batch_scores in self.denovo_seqs:
            # Transfer to cpu in case the scores live on the gpu.
            batch_scores = batch_scores.cpu()

            for i, spectrum_id in enumerate(spectrum_ids):
                # Score of the most probable entry at every position.
                top_scores = batch_scores[i].max(dim=1).values

                # Compare with a tolerance: an exact float equality test
                # misses padded positions that differ by a rounding error.
                is_padding = torch.isclose(
                    top_scores,
                    torch.full_like(top_scores, padding_score),
                )
                padded_positions = torch.where(is_padding)[0]

                peptide_score = None
                aa_scores = None

                # Scores are only reported when decoding was stopped.
                if len(padded_positions) > 0:
                    # The position right before the first padded one is
                    # treated as the stop token and used as an exclusive
                    # bound, so the stop token itself is omitted.
                    stop_index = int(padded_positions[0]) - 1

                    # Require at least one position before the stop token.
                    if stop_index >= 1:
                        top_scores_list = top_scores[:stop_index].tolist()
                        peptide_score = np.mean(top_scores_list)
                        # NOTE(review): reversed presumably because the
                        # peptide is decoded back to front -- confirm.
                        aa_scores = top_scores_list[::-1]

                # The first character of the predicted sequence is dropped,
                # exactly as before this change.
                writer.writerow(
                    [spectrum_id, pred_seqs[i][1:], peptide_score, aa_scores]
                )
440463

441-
442464
def on_epoch_end(self):
443465
"""Print log to console, if requested."""
444466

0 commit comments

Comments
 (0)