Skip to content

Commit 23c02f6

Browse files
authored
Reversed peptide aa scores hotfix (#417)
* reverse aa scores hotfix * reverse aa scores hotfix
1 parent 9e3f3d1 commit 23c02f6

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

casanovo/denovo/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def _get_top_peptide(
684684
yield [
685685
(
686686
pep_score,
687-
aa_scores,
687+
aa_scores[::-1] if self.decoder.reverse else aa_scores,
688688
"".join(self.decoder.detokenize(pred_tokens)),
689689
)
690690
for pep_score, _, aa_scores, pred_tokens in heapq.nlargest(

tests/unit_tests/test_unit.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,25 @@ def test_beam_search_decode():
240240
[pep[-1] for pep in list(model._get_top_peptide(test_cache))[0]]
241241
) == {"PEPK", "PEPP"}
242242

243+
# Test reverse aa scores when decoder is reversed
244+
pred_cache = {
245+
0: [(1.0, 0.42, np.array([1.0, 0.0]), torch.Tensor([4, 14]))]
246+
}
247+
248+
model.decoder.reverse = True
249+
top_peptides = list(model._get_top_peptide(pred_cache))
250+
assert len(top_peptides) == 1
251+
assert len(top_peptides[0]) == 1
252+
assert np.allclose(top_peptides[0][0][1], np.array([0.0, 1.0]))
253+
assert top_peptides[0][0][2] == "EP"
254+
255+
model.decoder.reverse = False
256+
top_peptides = list(model._get_top_peptide(pred_cache))
257+
assert len(top_peptides) == 1
258+
assert len(top_peptides[0]) == 1
259+
assert np.allclose(top_peptides[0][0][1], np.array([1.0, 0.0]))
260+
assert top_peptides[0][0][2] == "PE"
261+
243262
# Test _get_topk_beams().
244263
# Set scores to proceed generating the unfinished beam.
245264
step = 4

0 commit comments

Comments
 (0)