import six
from six.moves import xrange  # pylint: disable=redefined-builtin
+from collections import defaultdict

from tensor2tensor.data_generators import tokenizer

import tensorflow as tf

PAD = '<pad>'
EOS = '<EOS>'
RESERVED_TOKENS = [PAD, EOS]
-
+if six.PY2:
+  RESERVED_TOKENS_BYTES = RESERVED_TOKENS
+else:
+  RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]


class TextEncoder(object):
  """Base class for converting from ints to/from human readable strings."""
@@ -87,17 +91,25 @@ class ByteTextEncoder(TextEncoder):
  """Encodes each byte to an id. For 8-bit strings only."""

  def encode(self, s):
-    return [ord(c) + self._num_reserved_ids for c in s]
+    numres = self._num_reserved_ids
+    if six.PY2:
+      return [ord(c) + numres for c in s]
+    # Python3: explicitly convert to UTF-8
+    return [c + numres for c in s.encode("utf-8")]

  def decode(self, ids):
+    numres = self._num_reserved_ids
    decoded_ids = []
+    int2byte = six.int2byte
    for id_ in ids:
-      if 0 <= id_ < self._num_reserved_ids:
-        decoded_ids.append(RESERVED_TOKENS[int(id_)])
+      if 0 <= id_ < numres:
+        decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
      else:
-        decoded_ids.append(chr(id_))
-
-    return ''.join(decoded_ids)
+        decoded_ids.append(int2byte(id_ - numres))
+    if six.PY2:
+      return ''.join(decoded_ids)
+    # Python3: join byte arrays and then decode string
+    return b''.join(decoded_ids).decode("utf-8")

  @property
  def vocab_size(self):
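
A minimal round-trip sketch of the new byte-level path under Python 3 (standalone; num_reserved_ids and the reserved byte strings are hard-coded here rather than taken from the real class):

import six

num_reserved_ids = 2
reserved_tokens_bytes = [b'<pad>', b'<EOS>']

def byte_encode(s):
  # Each UTF-8 byte becomes an id, shifted past the reserved ids.
  return [c + num_reserved_ids for c in s.encode("utf-8")]

def byte_decode(ids):
  decoded = []
  for id_ in ids:
    if 0 <= id_ < num_reserved_ids:
      decoded.append(reserved_tokens_bytes[id_])
    else:
      decoded.append(six.int2byte(id_ - num_reserved_ids))
  # Join the single-byte pieces, then decode back to a unicode string.
  return b''.join(decoded).decode("utf-8")

assert byte_decode(byte_encode(u"héllo")) == u"héllo"
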
@@ -111,20 +123,16 @@ def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2):
    """Initialize from a file, one token per line."""
    super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
    self._reverse = reverse
-    if vocab_filename is not None:
-      self._load_vocab_from_file(vocab_filename)
+    self._load_vocab_from_file(vocab_filename)

  def encode(self, sentence):
    """Converts a space-separated string of tokens to a list of ids."""
    ret = [self._token_to_id[tok] for tok in sentence.strip().split()]
-    if self._reverse:
-      ret = ret[::-1]
-    return ret
+    return ret[::-1] if self._reverse else ret

  def decode(self, ids):
-    if self._reverse:
-      ids = ids[::-1]
-    return ' '.join([self._safe_id_to_token(i) for i in ids])
+    seq = reversed(ids) if self._reverse else ids
+    return ' '.join([self._safe_id_to_token(i) for i in seq])

  @property
  def vocab_size(self):
@@ -243,15 +251,22 @@ def _escaped_token_to_subtokens(self, escaped_token):
    """
    ret = []
    pos = 0
-    while pos < len(escaped_token):
-      end = len(escaped_token)
-      while True:
+    lesc = len(escaped_token)
+    while pos < lesc:
+      end = lesc
+      while end > pos:
        subtoken = self._subtoken_string_to_id.get(escaped_token[pos:end], -1)
        if subtoken != -1:
          break
        end -= 1
      ret.append(subtoken)
-      pos = end
+      if end > pos:
+        pos = end
+      else:
+        # This kinda should not happen, but it does. Cop out by skipping the
+        # nonexistent subtoken from the returned list.
+        # print("Unable to find subtoken in string '{0}'".format(escaped_token))
+        pos += 1
    return ret

  @classmethod
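
To illustrate the greedy longest-match behaviour of the loop above, a toy standalone version with a made-up vocabulary (the real mapping lives in self._subtoken_string_to_id):

subtoken_string_to_id = {'un': 0, 'related': 1, '_': 2}  # hypothetical vocab

def escaped_token_to_subtokens(escaped_token):
  ret = []
  pos = 0
  lesc = len(escaped_token)
  while pos < lesc:
    # Try the longest remaining substring first, shrinking until a match.
    end = lesc
    while end > pos:
      subtoken = subtoken_string_to_id.get(escaped_token[pos:end], -1)
      if subtoken != -1:
        break
      end -= 1
    ret.append(subtoken)
    if end > pos:
      pos = end
    else:
      pos += 1  # nothing matched; a -1 id was appended and we advance one char
  return ret

print(escaped_token_to_subtokens('unrelated_'))  # [0, 1, 2]
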
@@ -322,13 +337,13 @@ def build_from_token_counts(self,
    # then count the resulting potential subtokens, keeping the ones
    # with high enough counts for our new vocabulary.
    for i in xrange(num_iterations):
-      counts = {}
+      counts = defaultdict(int)
      for token, count in six.iteritems(token_counts):
        escaped_token = self._escape_token(token)
        # we will count all tails of the escaped_token, starting from boundaries
        # determined by our current segmentation.
        if i == 0:
-          starts = list(range(len(escaped_token)))
+          starts = xrange(len(escaped_token))
        else:
          subtokens = self._escaped_token_to_subtokens(escaped_token)
          pos = 0
@@ -337,31 +352,33 @@ def build_from_token_counts(self,
            starts.append(pos)
            pos += len(self.subtoken_to_subtoken_string(subtoken))
        for start in starts:
-          for end in xrange(start + 1, len(escaped_token) + 1):
+          for end in xrange(start + 1, len(escaped_token)):
            subtoken_string = escaped_token[start:end]
-            counts[subtoken_string] = counts.get(subtoken_string, 0) + count
+            counts[subtoken_string] += count
      # array of lists of candidate subtoken strings, by length
      len_to_subtoken_strings = []
      for subtoken_string, count in six.iteritems(counts):
-        if count < min_count or len(subtoken_string) <= 1:
+        lsub = len(subtoken_string)
+        # all subtoken strings of length 1 are included regardless of count
+        if count < min_count and lsub != 1:
          continue
-        while len(len_to_subtoken_strings) <= len(subtoken_string):
+        while len(len_to_subtoken_strings) <= lsub:
          len_to_subtoken_strings.append([])
-        len_to_subtoken_strings[len(subtoken_string)].append(subtoken_string)
+        len_to_subtoken_strings[lsub].append(subtoken_string)
      new_subtoken_strings = []
      # consider the candidates longest to shortest, so that if we accept
      # a longer subtoken string, we can decrement the counts of its prefixes.
      for subtoken_strings in len_to_subtoken_strings[::-1]:
        for subtoken_string in subtoken_strings:
          count = counts[subtoken_string]
-          if count < min_count:
+          if count < min_count and len(subtoken_string) != 1:
+            # subtoken strings of length 1 are included regardless of count
            continue
          new_subtoken_strings.append((-count, subtoken_string))
          for l in xrange(1, len(subtoken_string)):
            counts[subtoken_string[:l]] -= count
-      # make sure we have all single characters.
-      new_subtoken_strings.extend([(-counts.get(chr(i), 0), chr(i))
-                                   for i in xrange(2**8)])
+      # Make sure to include the underscore as a subtoken string
+      new_subtoken_strings.append((0, '_'))
      new_subtoken_strings.sort()
      self._init_from_list([''] * self._num_reserved_ids +
                           [p[1] for p in new_subtoken_strings])
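
As a toy illustration of the first-iteration counting pass above (made-up token counts, not code from this commit): with counts as a defaultdict(int) every candidate substring is tallied directly, and the new end bound stops one character short of the full escaped token.

from collections import defaultdict

token_counts = {'dog_': 3, 'dig_': 2}  # hypothetical escaped tokens -> counts
counts = defaultdict(int)
for escaped_token, count in token_counts.items():
  for start in range(len(escaped_token)):            # i == 0: every position
    for end in range(start + 1, len(escaped_token)):  # excludes the full token
      counts[escaped_token[start:end]] += count

print(counts['d'], counts['do'])  # 5 3
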
@@ -390,13 +407,19 @@ def _load_from_file(self, filename):
    subtoken_strings = []
    with tf.gfile.Open(filename) as f:
      for line in f:
-        subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
+        if six.PY2:
+          subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
+        else:
+          subtoken_strings.append(line.strip()[1:-1])
    self._init_from_list(subtoken_strings)

  def _store_to_file(self, filename):
    with tf.gfile.Open(filename, 'w') as f:
      for subtoken_string in self._all_subtoken_strings:
-        f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
+        if six.PY2:
+          f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
+        else:
+          f.write('\'' + subtoken_string + '\'\n')

  def _escape_token(self, token):
    r"""Translate '\'->'\\' and '_'->'\u', then append '_'.