@@ -23,6 +23,8 @@
 from __future__ import division
 from __future__ import print_function
 
+from collections import defaultdict
+
 # Dependency imports
 
 import six
@@ -35,6 +37,10 @@
 PAD = '<pad>'
 EOS = '<EOS>'
 RESERVED_TOKENS = [PAD, EOS]
+if six.PY2:
+  RESERVED_TOKENS_BYTES = RESERVED_TOKENS
+else:
+  RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]
 
 
 class TextEncoder(object):
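
The new RESERVED_TOKENS_BYTES list mirrors RESERVED_TOKENS as bytes so that Python 3 decoding can concatenate reserved tokens and raw bytes with b''.join. A quick standalone check of the definition (Python 3 semantics; under Python 2, bytes is an alias of str, hence the branch):

import six

PAD = '<pad>'
EOS = '<EOS>'

# Encode the reserved tokens once, up front, so decode() can later join them
# with six.int2byte(...) output without mixing str and bytes.
if six.PY2:
  RESERVED_TOKENS_BYTES = [PAD, EOS]
else:
  RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]

assert RESERVED_TOKENS_BYTES == [b'<pad>', b'<EOS>']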
@@ -87,17 +93,25 @@ class ByteTextEncoder(TextEncoder):
   """Encodes each byte to an id. For 8-bit strings only."""
 
   def encode(self, s):
-    return [ord(c) + self._num_reserved_ids for c in s]
+    numres = self._num_reserved_ids
+    if six.PY2:
+      return [ord(c) + numres for c in s]
+    # Python3: explicitly convert to UTF-8
+    return [c + numres for c in s.encode('utf-8')]
 
   def decode(self, ids):
+    numres = self._num_reserved_ids
     decoded_ids = []
+    int2byte = six.int2byte
     for id_ in ids:
-      if 0 <= id_ < self._num_reserved_ids:
-        decoded_ids.append(RESERVED_TOKENS[int(id_)])
+      if 0 <= id_ < numres:
+        decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
       else:
-        decoded_ids.append(chr(id_))
-
-    return ''.join(decoded_ids)
+        decoded_ids.append(int2byte(id_ - numres))
+    if six.PY2:
+      return ''.join(decoded_ids)
+    # Python3: join byte arrays and then decode string
+    return b''.join(decoded_ids).decode('utf-8')
 
   @property
   def vocab_size(self):
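
Taken together, the new encode/decode pair round-trips text through UTF-8 bytes shifted past the reserved ids; note that decode now subtracts the offset that encode adds, which the removed lines did not. A minimal sketch of the same logic, with the methods lifted out of the class and num_reserved_ids fixed at the default of 2 used elsewhere in this file:

import six

RESERVED_TOKENS_BYTES = [b'<pad>', b'<EOS>']

def byte_encode(s, numres=2):
  # Shift every UTF-8 byte up by the number of reserved ids.
  if six.PY2:
    return [ord(c) + numres for c in s]
  return [c + numres for c in s.encode('utf-8')]  # bytes iterate as ints

def byte_decode(ids, numres=2):
  decoded = []
  for id_ in ids:
    if 0 <= id_ < numres:
      decoded.append(RESERVED_TOKENS_BYTES[int(id_)])
    else:
      decoded.append(six.int2byte(id_ - numres))
  if six.PY2:
    return ''.join(decoded)
  return b''.join(decoded).decode('utf-8')

assert byte_decode(byte_encode('hello')) == 'hello'
assert byte_decode([0] + byte_encode('hi')) == '<pad>hi'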
@@ -111,20 +125,16 @@ def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2):
     """Initialize from a file, one token per line."""
     super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
     self._reverse = reverse
-    if vocab_filename is not None:
-      self._load_vocab_from_file(vocab_filename)
+    self._load_vocab_from_file(vocab_filename)
 
   def encode(self, sentence):
     """Converts a space-separated string of tokens to a list of ids."""
     ret = [self._token_to_id[tok] for tok in sentence.strip().split()]
-    if self._reverse:
-      ret = ret[::-1]
-    return ret
+    return ret[::-1] if self._reverse else ret
 
   def decode(self, ids):
-    if self._reverse:
-      ids = ids[::-1]
-    return ' '.join([self._safe_id_to_token(i) for i in ids])
+    seq = reversed(ids) if self._reverse else ids
+    return ' '.join([self._safe_id_to_token(i) for i in seq])
 
   @property
   def vocab_size(self):
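
The rewritten encode/decode simply flips the id sequence in both directions when reverse is set. A minimal sketch of that behavior, using a hypothetical two-token vocabulary in place of the one loaded from file:

vocab = {'hello': 2, 'world': 3}  # hypothetical token -> id table
inv_vocab = {i: t for t, i in vocab.items()}

def encode(sentence, reverse=False):
  ret = [vocab[tok] for tok in sentence.strip().split()]
  return ret[::-1] if reverse else ret

def decode(ids, reverse=False):
  # reversed() returns an iterator, which is all ' '.join needs.
  seq = reversed(ids) if reverse else ids
  return ' '.join([inv_vocab[i] for i in seq])

assert encode('hello world', reverse=True) == [3, 2]
assert decode([3, 2], reverse=True) == 'hello world'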
@@ -243,15 +253,22 @@ def _escaped_token_to_subtokens(self, escaped_token):
     """
     ret = []
     pos = 0
-    while pos < len(escaped_token):
-      end = len(escaped_token)
-      while True:
+    lesc = len(escaped_token)
+    while pos < lesc:
+      end = lesc
+      while end > pos:
         subtoken = self._subtoken_string_to_id.get(escaped_token[pos:end], -1)
         if subtoken != -1:
           break
         end -= 1
       ret.append(subtoken)
-      pos = end
+      if end > pos:
+        pos = end
+      else:
+        # This kinda should not happen, but it does. Cop out by skipping the
+        # nonexistent subtoken from the returned list.
+        # print("Unable to find subtoken in string '{0}'".format(escaped_token))
+        pos += 1
     return ret
 
   @classmethod
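
The reworked loop is a greedy longest-match segmenter: at each position it tries the longest remaining slice first, shrinks end until a known subtoken is found, and now advances by one character instead of looping forever when nothing matches. A standalone sketch of the same control flow, with a toy lookup table standing in for _subtoken_string_to_id (note that, as in the diff, an unmatched position still appends the -1 sentinel):

def greedy_subtokens(escaped_token, subtoken_to_id):
  ret = []
  pos = 0
  lesc = len(escaped_token)
  while pos < lesc:
    end = lesc
    subtoken = -1
    # Longest match first: shrink the candidate slice from the right.
    while end > pos:
      subtoken = subtoken_to_id.get(escaped_token[pos:end], -1)
      if subtoken != -1:
        break
      end -= 1
    ret.append(subtoken)
    if end > pos:
      pos = end
    else:
      pos += 1  # nothing matched here; step past this character
  return ret

# 'he' matches, the two 'l's match nothing, then 'o_' matches.
assert greedy_subtokens('hello_', {'he': 0, 'o_': 1}) == [0, -1, -1, 1]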
@@ -322,13 +339,13 @@ def build_from_token_counts(self,
     # then count the resulting potential subtokens, keeping the ones
     # with high enough counts for our new vocabulary.
     for i in xrange(num_iterations):
-      counts = {}
+      counts = defaultdict(int)
       for token, count in six.iteritems(token_counts):
         escaped_token = self._escape_token(token)
         # we will count all tails of the escaped_token, starting from boundaries
         # determined by our current segmentation.
         if i == 0:
-          starts = list(range(len(escaped_token)))
+          starts = xrange(len(escaped_token))
         else:
           subtokens = self._escaped_token_to_subtokens(escaped_token)
           pos = 0
@@ -337,31 +354,33 @@ def build_from_token_counts(self,
             starts.append(pos)
             pos += len(self.subtoken_to_subtoken_string(subtoken))
         for start in starts:
-          for end in xrange(start + 1, len(escaped_token) + 1):
+          for end in xrange(start + 1, len(escaped_token)):
             subtoken_string = escaped_token[start:end]
-            counts[subtoken_string] = counts.get(subtoken_string, 0) + count
+            counts[subtoken_string] += count
       # array of lists of candidate subtoken strings, by length
       len_to_subtoken_strings = []
       for subtoken_string, count in six.iteritems(counts):
-        if count < min_count or len(subtoken_string) <= 1:
+        lsub = len(subtoken_string)
+        # all subtoken strings of length 1 are included regardless of count
+        if count < min_count and lsub != 1:
           continue
-        while len(len_to_subtoken_strings) <= len(subtoken_string):
+        while len(len_to_subtoken_strings) <= lsub:
           len_to_subtoken_strings.append([])
-        len_to_subtoken_strings[len(subtoken_string)].append(subtoken_string)
+        len_to_subtoken_strings[lsub].append(subtoken_string)
       new_subtoken_strings = []
       # consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
       for subtoken_strings in len_to_subtoken_strings[::-1]:
         for subtoken_string in subtoken_strings:
           count = counts[subtoken_string]
-          if count < min_count:
+          if count < min_count and len(subtoken_string) != 1:
+            # subtoken strings of length 1 are included regardless of count
             continue
           new_subtoken_strings.append((-count, subtoken_string))
           for l in xrange(1, len(subtoken_string)):
             counts[subtoken_string[:l]] -= count
-      # make sure we have all single characters.
-      new_subtoken_strings.extend([(-counts.get(chr(i), 0), chr(i))
-                                   for i in xrange(2**8)])
+      # Make sure to include the underscore as a subtoken string
+      new_subtoken_strings.append((0, '_'))
       new_subtoken_strings.sort()
       self._init_from_list([''] * self._num_reserved_ids +
                            [p[1] for p in new_subtoken_strings])
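
Two things change in the counting pass above: counts becomes a defaultdict(int), and length-1 subtoken strings now survive the min_count filter unconditionally, so the vocabulary can always fall back to single characters. A compressed sketch of the first-iteration counting step on toy inputs (escape is a simplified stand-in for _escape_token):

from collections import defaultdict

token_counts = {'dog': 5, 'dot': 3}  # hypothetical corpus counts

def escape(token):
  # Simplified stand-in: the real _escape_token also rewrites '\' and '_'.
  return token + '_'

counts = defaultdict(int)
for token, count in token_counts.items():
  escaped = escape(token)
  # First iteration: every position is a start; count each candidate
  # substring escaped[start:end], with end stopping short of the full token.
  for start in range(len(escaped)):
    for end in range(start + 1, len(escaped)):
      counts[escaped[start:end]] += count

assert counts['do'] == 8  # 5 from 'dog_' plus 3 from 'dot_'

Because the end range now stops one position short of the full escaped token, the trailing '_' is never counted on its own, which is presumably why the loop now appends (0, '_') explicitly instead of seeding all 256 single-byte characters.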
@@ -390,13 +409,19 @@ def _load_from_file(self, filename):
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
-        subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
+        if six.PY2:
+          subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
+        else:
+          subtoken_strings.append(line.strip()[1:-1])
     self._init_from_list(subtoken_strings)
 
   def _store_to_file(self, filename):
     with tf.gfile.Open(filename, 'w') as f:
       for subtoken_string in self._all_subtoken_strings:
-        f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
+        if six.PY2:
+          f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
+        else:
+          f.write('\'' + subtoken_string + '\'\n')
 
   def _escape_token(self, token):
     r"""Translate '\'->'\\' and '_'->'\u', then append '_'.