Skip to content

Commit 1243565

Browse files
Fix the unknown-class (UNK) return format when k > 1: wrap the result in a list
1 parent c15375a commit 1243565

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

fastlangid/langid.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,12 @@ def predict(self, text, full_clean=False, supplement_threshold=0.9, k=1, prob=Fa
9494
if isinstance(text, unicode):
9595
text = self.clean_up(text)
9696
if len(text) == 0 or only_punctuations(text[:50]):
97-
return (UNK_CLS, 1.0) if prob else UNK_CLS
97+
output = UNK_CLS
98+
if prob:
99+
output = (UNK_CLS, 1.0)
100+
if k > 1:
101+
output = [output]
102+
return output
98103
return self._predict_text(text, supplement_threshold=supplement_threshold, k=k, prob=prob, force_second=force_second)
99104
else:
100105
batch = [ self.clean_up(i, full_clean=full_clean) for i in text ]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
setup(
2929
name='fastlangid',
30-
version='1.0.8',
30+
version='1.0.9',
3131
description='Language detection for news powered by fasttext',
3232
long_description=readme,
3333
long_description_content_type="text/markdown",

tests/testcases.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# coding=utf-8
22
from __future__ import division, unicode_literals
33
from fastlangid.langid import LID
4+
from fastlangid.utils import UNK_CLS
45
from os import path
56
import unittest
67

@@ -126,8 +127,18 @@ def test_batches(self):
126127
])
127128
self.assertEqual(lang_codes, ['zh-hant', 'ru', 'de', 'en','it', 'fr', 'ja', 'pt'])
128129

130+
def test_unknown(self):
131+
lang_code = self.langid.predict('', force_second=True)
132+
self.assertEqual(lang_code, UNK_CLS)
129133

134+
lang_code = self.langid.predict('???。???。???。???。???。???。')
135+
self.assertEqual(lang_code, UNK_CLS)
130136

137+
lang_code = self.langid.predict('???。???。???。???。???。???。', prob=True)
138+
self.assertEqual(lang_code, (UNK_CLS, 1.0))
139+
140+
lang_code = self.langid.predict('???。???。???。???。???。???。', prob=True, k=15)
141+
self.assertEqual(lang_code, [(UNK_CLS, 1.0)])
131142

132143
def suite():
133144
suite = unittest.TestSuite()

0 commit comments

Comments
 (0)