# Dependency imports

import six
+from six import PY2
from six.moves import xrange  # pylint: disable=redefined-builtin
from tensor2tensor.data_generators import tokenizer

import tensorflow as tf

+
+# Conversion between Unicode and UTF-8, if required (on Python2)
+_native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)
+
+
+_unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)
+
+
# Reserved tokens for things like padding and EOS symbols.
PAD = "<pad>"
EOS = "<EOS>"
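Note on the two helpers above: on Python 3 they are identity functions, while on Python 2 they convert between UTF-8 byte strings and unicode, so call sites can pass a native str on either interpreter. A minimal round-trip sketch, assuming PY2 and the two lambdas are in scope (the sample string is arbitrary):

s = "caf\xc3\xa9" if PY2 else u"caf\u00e9"  # a native string on either Python
u = _native_to_unicode(s)                   # always a unicode string
assert _unicode_to_native(u) == s           # converts back to a native string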
@@ -162,15 +171,36 @@ def _load_vocab_from_file(self, filename):


class SubwordTextEncoder(TextEncoder):
-  """Class for breaking tokens into subtokens.
+  """Class for invertibly encoding text using a limited vocabulary.

-  Invertibly encodes a string as a sequence of subtokens from a limited
+  Invertibly encodes a native string as a sequence of subtokens from a limited
  vocabulary.

  A SubwordTextEncoder is built from a corpus (so it is tailored to the text in
  the corpus), and stored to a file. See text_encoder_build_subword.py.

  It can then be loaded and used to encode/decode any text.
+
+  Encoding has four phases:
+
+  1. Tokenize into a list of tokens. Each token is a unicode string of either
+     all alphanumeric characters or all non-alphanumeric characters. We drop
+     tokens consisting of a single space that are between two alphanumeric
+     tokens.
+
+  2. Escape each token. This escapes away special and out-of-vocabulary
+     characters, and makes sure that each token ends with an underscore and
+     has no other underscores.
+
+  3. Represent each escaped token as the concatenation of a list of subtokens
+     from the limited vocabulary. Subtoken selection is done greedily from
+     beginning to end. That is, we construct the list in order, always picking
+     the longest subtoken in our vocabulary that matches a prefix of the
+     remaining portion of the escaped token.
+
+  4. Concatenate these lists. This concatenation is invertible because the
+     trailing underscores indicate where each list ends.
+
  """

  def __init__(self, filename=None, num_reserved_ids=2):
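To make phase 3 above concrete, here is a small self-contained sketch of greedy longest-prefix segmentation over a toy vocabulary; the function name, toy vocabulary, and example string are illustrative only and do not appear in this change:

def greedy_subtokens(escaped_token, vocab):
  """Greedily pick the longest vocab entry matching a prefix, left to right."""
  subtokens = []
  pos = 0
  while pos < len(escaped_token):
    end = len(escaped_token)
    while end > pos and escaped_token[pos:end] not in vocab:
      end -= 1
    assert end > pos, "every single character must be in the vocabulary"
    subtokens.append(escaped_token[pos:end])
    pos = end
  return subtokens

toy_vocab = {u"basket", u"ball_", u"b", u"a", u"s", u"k", u"e", u"t", u"l", u"_"}
print(greedy_subtokens(u"basketball_", toy_vocab))  # ['basket', 'ball_']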
@@ -182,24 +212,26 @@ def __init__(self, filename=None, num_reserved_ids=2):
    super(SubwordTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)

  def encode(self, raw_text):
-    """Converts a string to a list of subtoken ids.
+    """Converts a native string to a list of subtoken ids.

    Args:
-      raw_text: a string.
+      raw_text: a native string.
    Returns:
      a list of integers in the range [0, vocab_size)
    """
-    return self._tokens_to_subtokens(self._tokenizer.encode(raw_text))
+    return self._tokens_to_subtokens(self._tokenizer.encode(
+        _native_to_unicode(raw_text)))

  def decode(self, subtokens):
-    """Converts a sequence of subtoken ids to a string.
+    """Converts a sequence of subtoken ids to a native string.

    Args:
      subtokens: a list of integers in the range [0, vocab_size)
    Returns:
-      a string
+      a native string
    """
-    return self._tokenizer.decode(self._subtokens_to_tokens(subtokens))
+    return _unicode_to_native(self._tokenizer.decode(
+        self._subtokens_to_tokens(subtokens)))

  @property
  def vocab_size(self):
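A short usage sketch of the native-string interface above; the vocabulary file name is a placeholder and must point to a previously built subword vocabulary (see text_encoder_build_subword.py):

encoder = SubwordTextEncoder("my_vocab.subwords")  # placeholder filename
ids = encoder.encode("Subword encoding round-trips text.")
print(ids)  # a list of ints in [0, vocab_size)
assert encoder.decode(ids) == "Subword encoding round-trips text."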
@@ -239,8 +271,8 @@ def subtoken_to_subtoken_string(self, subtoken):
    if subtoken_string:
      return subtoken_string
    if 0 <= subtoken < self._num_reserved_ids:
-      return "%s_" % RESERVED_TOKENS[subtoken]
-    return "ID%d_" % subtoken
+      return u"%s_" % RESERVED_TOKENS[subtoken]
+    return u"ID%d_" % subtoken

  def _escaped_token_to_subtokens(self, escaped_token):
    """Converts an escaped token string to a list of subtokens.
@@ -260,27 +292,11 @@ def _escaped_token_to_subtokens(self, escaped_token):
        if subtoken != -1:
          break
        end -= 1
-      if end > pos:
-        ret.append(subtoken)
-        pos = end
-      else:
-        # No subtoken in the vocabulary matches escaped_token[pos].
-        # This can happen if the token contains a Unicode character
-        # that did not occur in the vocabulary training set.
-        # The id self.vocab_size - 1 is decoded as Unicode uFFFD,
-        # REPLACEMENT_CHARACTER.
-        ret.append(self.vocab_size - 1)
-        # Ensure that the outer loop continues
-        pos += 1
-    return ret
+      assert end > pos
+      ret.append(subtoken)
+      pos = end

-  @classmethod
-  def alphabet(cls, token_counts):
-    """Return the set of Unicode characters that appear in the tokens."""
-    alphabet_set = set()
-    for token in six.iterkeys(token_counts):
-      alphabet_set |= set(token)
-    return alphabet_set
+    return ret

  @classmethod
  def build_to_target_size(cls,
@@ -304,17 +320,12 @@ def build_to_target_size(cls,
    Returns:
      a SubwordTextEncoder instance.
    """
-    # Calculate the alphabet, i.e. the set of all Unicode characters
-    # that appear in the tokens.
-    alphabet_set = cls.alphabet(token_counts)
-    tf.logging.info("Alphabet contains %d characters" % len(alphabet_set))
-
    def bisect(min_val, max_val):
      """Bisection to find the right size."""
      present_count = (max_val + min_val) // 2
      tf.logging.info("Trying min_count %d" % present_count)
      subtokenizer = cls()
-      subtokenizer.build_from_token_counts(token_counts, alphabet_set,
+      subtokenizer.build_from_token_counts(token_counts,
                                           present_count, num_iterations)
      if min_val >= max_val or subtokenizer.vocab_size == target_size:
        return subtokenizer
@@ -333,17 +344,29 @@ def bisect(min_val, max_val):

  def build_from_token_counts(self,
                              token_counts,
-                              alphabet_set,
                              min_count,
                              num_iterations=4):
    """Train a SubwordTextEncoder based on a dictionary of word counts.

    Args:
      token_counts: a dictionary of Unicode strings to int.
-      alphabet_set: the set of Unicode characters that appear in the tokens.
      min_count: an integer - discard subtokens with lower counts.
      num_iterations: an integer. how many iterations of refinement.
    """
+    # First, determine the alphabet: all characters that occur at least
+    # min_count times in the dataset.
+    char_counts = defaultdict(int)
+    for token, count in six.iteritems(token_counts):
+      for c in token:
+        char_counts[c] += count
+    self._alphabet = set()
+    for c, count in six.iteritems(char_counts):
+      if count >= min_count:
+        self._alphabet.add(c)
+    # Make sure all characters needed for escaping are included.
+    for c in u"\\_;0123456789":
+      self._alphabet.add(c)
+
    # We build iteratively. On each iteration, we segment all the words,
    # then count the resulting potential subtokens, keeping the ones
    # with high enough counts for our new vocabulary.
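As a standalone illustration of the alphabet computation just added (the token counts and min_count here are made up for the example, and the snippet mirrors, rather than calls, the code above):

from collections import defaultdict

token_counts = {u"the": 10, u"cat": 3, u"caf\u00e9": 1}
min_count = 2

char_counts = defaultdict(int)
for token, count in token_counts.items():
  for c in token:
    char_counts[c] += count
alphabet = set(c for c, n in char_counts.items() if n >= min_count)
alphabet |= set(u"\\_;0123456789")  # characters needed for escaping

# u"\u00e9" only occurs once, so it stays out of the alphabet and will later
# be written by _escape_token as the escape sequence u"\\233;".
assert u"\u00e9" not in alphabet and u"t" in alphabet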
@@ -367,43 +390,36 @@ def build_from_token_counts(self,
          for end in xrange(start + 1, len(escaped_token) + 1):
            subtoken_string = escaped_token[start:end]
            counts[subtoken_string] += count
+      # Make sure every character in the alphabet survives the min_count cutoff.
+      for c in self._alphabet:
+        counts[c] += min_count
      # Array of sets of candidate subtoken strings, by length
      len_to_subtoken_strings = []
      for subtoken_string, count in six.iteritems(counts):
        lsub = len(subtoken_string)
-        # All subtoken strings of length 1 are automatically included
-        # later, so we don't need to consider them here
-        if count < min_count or lsub <= 1:
-          continue
-        # Add this subtoken string to its length set
-        while len(len_to_subtoken_strings) <= lsub:
-          len_to_subtoken_strings.append(set())
-        len_to_subtoken_strings[lsub].add(subtoken_string)
+        if count >= min_count:
+          # Add this subtoken string to its length set
+          while len(len_to_subtoken_strings) <= lsub:
+            len_to_subtoken_strings.append(set())
+          len_to_subtoken_strings[lsub].add(subtoken_string)
      new_subtoken_strings = []
      # consider the candidates longest to shortest, so that if we accept
      # a longer subtoken string, we can decrement the counts of its prefixes.
-      for subtoken_strings in reversed(len_to_subtoken_strings[2:]):
+      for lsub in reversed(range(1, len(len_to_subtoken_strings))):
+        subtoken_strings = len_to_subtoken_strings[lsub]
        for subtoken_string in subtoken_strings:
          count = counts[subtoken_string]
-          if count < min_count:
-            continue
-          new_subtoken_strings.append((count, subtoken_string))
-          for l in xrange(1, len(subtoken_string)):
-            counts[subtoken_string[:l]] -= count
-      # Sort what we've got so far in decreasing order by count
+          if count >= min_count:
+            new_subtoken_strings.append((count, subtoken_string))
+            for l in xrange(1, lsub):
+              counts[subtoken_string[:l]] -= count
+      # Sort in decreasing order by count
      new_subtoken_strings.sort(reverse=True)
-      # Add the alphabet set at the end of the vocabulary list
-      for char in alphabet_set:
-        new_subtoken_strings.append((0, char))
-      # Also include the Unicode REPLACEMENT CHARACTER to use
-      # when encountering previously unseen Unicode characters
-      # in the input (i.e. input external to the tokenizer training
-      # set, which may thus contain characters not in the alphabet_set).
-      # This must be the last entry in the subtoken vocabulary list.
-      new_subtoken_strings.append((0, u"\uFFFD"))
      # Now we have a candidate vocabulary
+      old_alphabet = self._alphabet
      self._init_from_list([u""] * self._num_reserved_ids +
                           [p[1] for p in new_subtoken_strings])
+      assert old_alphabet == self._alphabet
      tf.logging.info("vocab_size = %d" % self.vocab_size)

      original = "This sentence was encoded by the SubwordTextEncoder."
@@ -426,46 +442,77 @@ def _init_from_list(self, subtoken_strings):
    self._all_subtoken_strings = subtoken_strings
    self._subtoken_string_to_id = {
        s: i for i, s in enumerate(subtoken_strings) if s}
+    self._alphabet = set([c for c in subtoken_strings if len(c) == 1])

  def _load_from_file(self, filename):
    """Load from a file."""
    subtoken_strings = []
    with tf.gfile.Open(filename) as f:
      for line in f:
-        if six.PY2:
-          subtoken_strings.append(line.strip()[1:-1].decode("utf-8"))
-        else:
-          subtoken_strings.append(line.strip()[1:-1])
+        subtoken_strings.append(_native_to_unicode(line.strip()[1:-1]))
    self._init_from_list(subtoken_strings)

  def store_to_file(self, filename):
    with tf.gfile.Open(filename, "w") as f:
      for subtoken_string in self._all_subtoken_strings:
-        if six.PY2:
-          f.write("'" + subtoken_string.encode("utf-8") + "'\n")
-        else:
-          f.write("'" + subtoken_string + "'\n")
+        f.write("'" + _unicode_to_native(subtoken_string) + "'\n")

  def _escape_token(self, token):
-    r"""Translate '\'->'\\' and '_'->'\u', then append '_'.
+    r"""Escape away underscores and OOV characters and append '_'.
+
+    This allows the token to be expressed as the concatenation of a list
+    of subtokens from the vocabulary. The underscore acts as a sentinel
+    which allows us to invertibly concatenate multiple such lists.

    Args:
-      token: a string
+      token: a unicode string
    Returns:
-      escaped_token: a string
+      escaped_token: a unicode string
    """
-    return token.replace("\\", "\\\\").replace("_", "\\u") + "_"
+    token = token.replace("\\", "\\\\").replace("_", "\\u") + "_"
+    ret = u""
+    for c in token:
+      if c in self._alphabet:
+        ret += c
+      else:
+        ret += u"\\%d;" % ord(c)
+    return ret

  def _unescape_token(self, escaped_token):
-    r"""Remove '_' from end, then translate '\\'->'\' and '\u'->'_'.
+    r"""Inverse of _escape_token().

    Args:
-      escaped_token: a string
+      escaped_token: a unicode string
    Returns:
-      token: a string
+      token: a unicode string
    """
-    assert escaped_token[-1] == "_"
-    return escaped_token[:-1].replace("\\u", "_").replace("\\\\", "\\")
+    ret = u""
+    escaped_token = escaped_token[:-1]
+    pos = 0
+    while pos < len(escaped_token):
+      c = escaped_token[pos]
+      if c == "\\":
+        pos += 1
+        c = escaped_token[pos]
+        if c == u"u":
+          ret += u"_"
+          pos += 1
+        elif c == "\\":
+          ret += u"\\"
+          pos += 1
+        else:
+          semicolon_pos = escaped_token.find(u";", pos)
+          if semicolon_pos == -1:
+            continue
+          try:
+            ret += six.unichr(int(escaped_token[pos:semicolon_pos]))
+            pos = semicolon_pos + 1
+          except (ValueError, OverflowError) as _:
+            pass
+      else:
+        ret += c
+        pos += 1
+    return ret

  @classmethod
  def get_token_counts(cls, text_filepattern, corpus_max_lines):
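The escaping scheme in _escape_token/_unescape_token can be exercised in isolation. The sketch below re-implements the same idea outside the class against a fixed toy alphabet (the function names, alphabet, and sample token are illustrative, and malformed escapes are not handled), mainly to show that underscores, backslashes, and out-of-alphabet characters survive a round trip:

import six

_ALPHABET = set(u"abcdefghijklmnopqrstuvwxyz") | set(u"\\_;0123456789")

def escape_token(token):
  # '\' -> '\\', '_' -> '\u', append the sentinel '_', then escape any
  # character outside the alphabet as '\<decimal code point>;'.
  token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u") + u"_"
  return u"".join(c if c in _ALPHABET else u"\\%d;" % ord(c) for c in token)

def unescape_token(escaped_token):
  ret = u""
  escaped_token = escaped_token[:-1]  # drop the trailing sentinel underscore
  pos = 0
  while pos < len(escaped_token):
    c = escaped_token[pos]
    if c == u"\\":
      pos += 1
      c = escaped_token[pos]
      if c == u"u":          # "\u" encodes an underscore
        ret += u"_"
        pos += 1
      elif c == u"\\":       # "\\" encodes a backslash
        ret += u"\\"
        pos += 1
      else:                  # "\<decimal>;" encodes an out-of-alphabet character
        semicolon_pos = escaped_token.find(u";", pos)
        ret += six.unichr(int(escaped_token[pos:semicolon_pos]))
        pos = semicolon_pos + 1
    else:
      ret += c
      pos += 1
  return ret

token = u"snow_man\\\u2603"
escaped = escape_token(token)  # "_" becomes "\u", "\" becomes "\\", snowman becomes "\9731;"
assert unescape_token(escaped) == token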
@@ -477,7 +524,7 @@ def get_token_counts(cls, text_filepattern, corpus_max_lines):
      with tf.gfile.Open(text_filename) as f:
        for line in f:
          # The tokenizer updates token_counts in encode()
-          tok.encode(line.strip())
+          tok.encode(_native_to_unicode(line.strip()))
          lines_read += 1
          if corpus_max_lines > 0 and lines_read > corpus_max_lines:
            return tok.token_counts
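Taken together, the pieces in this file support an end-to-end vocabulary build along these lines (the corpus pattern, min_count value, and output path are placeholders; passing 0 for corpus_max_lines means no line limit, per the check above):

# Count tokens across a text corpus.
token_counts = SubwordTextEncoder.get_token_counts("corpus/*.txt", 0)

# Build a subword vocabulary, keeping subtokens seen at least 5 times,
# then store it so it can be reloaded later with SubwordTextEncoder(filename).
encoder = SubwordTextEncoder()
encoder.build_from_token_counts(token_counts, min_count=5)
encoder.store_to_file("corpus.subwords")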