
Commit 98be812

nshazeer authored and lukaszkaiser committed
Fix subword_text_tokenizer to make it invertible again. This breaks existing models and vocabularies. Change the criteria for which characters are parts of words and which are separators: we now consider Unicode letters and numbers to be parts of words.
PiperOrigin-RevId: 160690718
1 parent f3e5859 commit 98be812

4 files changed: +171 −151 lines changed

tensor2tensor/data_generators/text_encoder.py

Lines changed: 127 additions & 80 deletions
@@ -28,11 +28,20 @@
 # Dependency imports
 
 import six
+from six import PY2
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensor2tensor.data_generators import tokenizer
 
 import tensorflow as tf
 
+
+# Conversion between Unicode and UTF-8, if required (on Python2)
+_native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)
+
+
+_unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)
+
+
 # Reserved tokens for things like padding and EOS symbols.
 PAD = "<pad>"
 EOS = "<EOS>"
@@ -162,15 +171,36 @@ def _load_vocab_from_file(self, filename):
 
 
 class SubwordTextEncoder(TextEncoder):
-  """Class for breaking tokens into subtokens.
+  """Class for invertibly encoding text using a limited vocabulary.
 
-  Invertibly encodes a string as a sequence of subtokens from a limited
+  Invertibly encodes a native string as a sequence of subtokens from a limited
   vocabulary.
 
   A SubwordTextEncoder is built from a corpus (so it is tailored to the text in
   the corpus), and stored to a file. See text_encoder_build_subword.py.
 
   It can then be loaded and used to encode/decode any text.
+
+  Encoding has four phases:
+
+  1. Tokenize into a list of tokens. Each token is a unicode string of either
+     all alphanumeric characters or all non-alphanumeric characters. We drop
+     tokens consisting of a single space that are between two alphanumeric
+     tokens.
+
+  2. Escape each token. This escapes away special and out-of-vocabulary
+     characters, and makes sure that each token ends with an underscore, and
+     has no other underscores.
+
+  3. Represent each escaped token as the concatenation of a list of subtokens
+     from the limited vocabulary. Subtoken selection is done greedily from
+     beginning to end. That is, we construct the list in order, always picking
+     the longest subtoken in our vocabulary that matches a prefix of the
+     remaining portion of the encoded token.
+
+  4. Concatenate these lists. This concatenation is invertible due to the
+     fact that the trailing underscores indicate when one list is finished.
   """
 
   def __init__(self, filename=None, num_reserved_ids=2):
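
The four phases in the new docstring are easiest to see on a toy example. The sketch below is a rough, self-contained illustration only; the regex, the single-space rule, and the toy vocabulary are assumptions for demonstration and do not reproduce the project's tokenizer module:

```python
# Toy walk-through of the four encoding phases (illustration only).
import re

def toy_tokenize(text):
  # Phase 1: split into runs of alphanumeric / non-alphanumeric characters,
  # then drop single spaces that sit between two alphanumeric tokens.
  tokens = re.findall(r"\w+|\W+", text, re.UNICODE)
  out = []
  for i, tok in enumerate(tokens):
    if (tok == " " and 0 < i < len(tokens) - 1
        and tokens[i - 1][-1].isalnum() and tokens[i + 1][0].isalnum()):
      continue
    out.append(tok)
  return out

def toy_escape(token):
  # Phase 2: escape backslashes and underscores, then append the "_" sentinel.
  return token.replace("\\", "\\\\").replace("_", "\\u") + "_"

def toy_greedy_split(escaped, vocab):
  # Phase 3: greedily take the longest vocabulary entry matching a prefix.
  pieces, pos = [], 0
  while pos < len(escaped):
    for end in range(len(escaped), pos, -1):
      if escaped[pos:end] in vocab:
        pieces.append(escaped[pos:end])
        pos = end
        break
    else:
      raise ValueError("no matching subtoken; alphabet must cover all chars")
  return pieces

vocab = {"hel", "lo", "he", "l", "o", "_", "wor", "ld", "d", "w", "r"}
tokens = toy_tokenize("hello world")           # ["hello", "world"]
escaped = [toy_escape(t) for t in tokens]      # ["hello_", "world_"]
# Phase 4: concatenate the per-token lists; the trailing "_" marks each end.
subtokens = sum((toy_greedy_split(e, vocab) for e in escaped), [])
print(subtokens)                               # ['hel', 'lo', '_', 'wor', 'ld', '_']
```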
@@ -182,24 +212,26 @@ def __init__(self, filename=None, num_reserved_ids=2):
     super(SubwordTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
 
   def encode(self, raw_text):
-    """Converts a string to a list of subtoken ids.
+    """Converts a native string to a list of subtoken ids.
 
     Args:
-      raw_text: a string.
+      raw_text: a native string.
     Returns:
       a list of integers in the range [0, vocab_size)
     """
-    return self._tokens_to_subtokens(self._tokenizer.encode(raw_text))
+    return self._tokens_to_subtokens(self._tokenizer.encode(
+        _native_to_unicode(raw_text)))
 
   def decode(self, subtokens):
-    """Converts a sequence of subtoken ids to a string.
+    """Converts a sequence of subtoken ids to a native string.
 
     Args:
       subtokens: a list of integers in the range [0, vocab_size)
     Returns:
-      a string
+      a native string
     """
-    return self._tokenizer.decode(self._subtokens_to_tokens(subtokens))
+    return _unicode_to_native(self._tokenizer.decode(
+        self._subtokens_to_tokens(subtokens)))
 
   @property
   def vocab_size(self):
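
A hedged usage sketch of the updated public API: encode() now accepts and decode() now returns native strings on both Python 2 and Python 3, and the pair is invertible. The vocabulary file path is a placeholder:

```python
# Usage sketch; "/path/to/vocab.subwords" is a hypothetical vocabulary file.
from tensor2tensor.data_generators import text_encoder

encoder = text_encoder.SubwordTextEncoder("/path/to/vocab.subwords")
sentence = "This sentence was encoded by the SubwordTextEncoder."
ids = encoder.encode(sentence)        # a list of ints in [0, vocab_size)
assert encoder.decode(ids) == sentence  # encoding is invertible
```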
@@ -239,8 +271,8 @@ def subtoken_to_subtoken_string(self, subtoken):
     if subtoken_string:
       return subtoken_string
     if 0 <= subtoken < self._num_reserved_ids:
-      return "%s_" % RESERVED_TOKENS[subtoken]
-    return "ID%d_" % subtoken
+      return u"%s_" % RESERVED_TOKENS[subtoken]
+    return u"ID%d_" % subtoken
 
   def _escaped_token_to_subtokens(self, escaped_token):
     """Converts an escaped token string to a list of subtokens.
@@ -260,27 +292,11 @@ def _escaped_token_to_subtokens(self, escaped_token):
         if subtoken != -1:
           break
         end -= 1
-      if end > pos:
-        ret.append(subtoken)
-        pos = end
-      else:
-        # No subtoken in the vocabulary matches escaped_token[pos].
-        # This can happen if the token contains a Unicode character
-        # that did not occur in the vocabulary training set.
-        # The id self.vocab_size - 1 is decoded as Unicode uFFFD,
-        # REPLACEMENT_CHARACTER.
-        ret.append(self.vocab_size - 1)
-        # Ensure that the outer loop continues
-        pos += 1
-    return ret
+      assert end > pos
+      ret.append(subtoken)
+      pos = end
 
-  @classmethod
-  def alphabet(cls, token_counts):
-    """Return the set of Unicode characters that appear in the tokens."""
-    alphabet_set = set()
-    for token in six.iterkeys(token_counts):
-      alphabet_set |= set(token)
-    return alphabet_set
+    return ret
 
   @classmethod
   def build_to_target_size(cls,
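
Because every character of an escaped token is now guaranteed to be in the vocabulary (the alphabet is kept as the set of length-1 subtokens), the greedy scan always makes progress, so the old REPLACEMENT CHARACTER fallback becomes an assertion. A standalone sketch of that loop, with a toy vocabulary assumed for illustration:

```python
# Standalone sketch of the greedy longest-match loop (toy vocabulary).
def greedy_subtokens(escaped_token, subtoken_to_id):
  ret, pos = [], 0
  while pos < len(escaped_token):
    end = len(escaped_token)
    while end > pos and escaped_token[pos:end] not in subtoken_to_id:
      end -= 1
    assert end > pos, "alphabet must contain every character"
    ret.append(subtoken_to_id[escaped_token[pos:end]])
    pos = end
  return ret

vocab = {s: i for i, s in enumerate(["a", "b", "c", "_", "ab", "abc_"])}
print(greedy_subtokens("ababc_", vocab))  # [4, 5] -> "ab" + "abc_"
```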
@@ -304,17 +320,12 @@ def build_to_target_size(cls,
     Returns:
       a SubwordTextEncoder instance.
     """
-    # Calculate the alphabet, i.e. the set of all Unicode characters
-    # that appear in the tokens.
-    alphabet_set = cls.alphabet(token_counts)
-    tf.logging.info("Alphabet contains %d characters" % len(alphabet_set))
-
     def bisect(min_val, max_val):
       """Bisection to find the right size."""
       present_count = (max_val + min_val) // 2
       tf.logging.info("Trying min_count %d" % present_count)
       subtokenizer = cls()
-      subtokenizer.build_from_token_counts(token_counts, alphabet_set,
+      subtokenizer.build_from_token_counts(token_counts,
                                            present_count, num_iterations)
       if min_val >= max_val or subtokenizer.vocab_size == target_size:
         return subtokenizer
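
build_to_target_size() searches for a min_count whose resulting vocabulary lands on the target size, exploiting the fact that a higher min_count yields a smaller vocabulary. A hedged standalone sketch of the bisection; build_vocab() below is a toy stand-in, not the real build_from_token_counts():

```python
# Bisection sketch on a toy vocabulary builder (assumption for illustration).
def build_vocab(token_counts, min_count):
  # Toy stand-in: keep every character whose weighted count >= min_count.
  char_counts = {}
  for token, count in token_counts.items():
    for c in token:
      char_counts[c] = char_counts.get(c, 0) + count
  return {c for c, n in char_counts.items() if n >= min_count}

def bisect_to_target(token_counts, target_size, min_val=1, max_val=1000):
  present_count = (max_val + min_val) // 2
  vocab = build_vocab(token_counts, present_count)
  if min_val >= max_val or len(vocab) == target_size:
    return vocab
  if len(vocab) > target_size:   # too many entries: try a larger min_count
    return bisect_to_target(token_counts, target_size, present_count + 1, max_val)
  return bisect_to_target(token_counts, target_size, min_val, present_count - 1)

counts = {"the": 50, "cat": 20, "sat": 20, "mat": 5}
print(sorted(bisect_to_target(counts, target_size=6)))  # ['a', 'c', 'e', 'h', 's', 't']
```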
@@ -333,17 +344,29 @@ def bisect(min_val, max_val):
 
   def build_from_token_counts(self,
                               token_counts,
-                              alphabet_set,
                               min_count,
                               num_iterations=4):
     """Train a SubwordTextEncoder based on a dictionary of word counts.
 
     Args:
       token_counts: a dictionary of Unicode strings to int.
-      alphabet_set: the set of Unicode characters that appear in the tokens.
       min_count: an integer - discard subtokens with lower counts.
       num_iterations: an integer. how many iterations of refinement.
     """
+    # first determine the alphabet to include all characters with count at
+    # least min_count in the dataset.
+    char_counts = defaultdict(int)
+    for token, count in six.iteritems(token_counts):
+      for c in token:
+        char_counts[c] += count
+    self._alphabet = set()
+    for c, count in six.iteritems(char_counts):
+      if count >= min_count:
+        self._alphabet.add(c)
+    # Make sure all characters needed for escaping are included
+    for c in u"\\_;0123456789":
+      self._alphabet.add(c)
+
     # We build iteratively. On each iteration, we segment all the words,
     # then count the resulting potential subtokens, keeping the ones
     # with high enough counts for our new vocabulary.
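
The new code above derives the alphabet from the token counts themselves: any character whose weighted count reaches min_count is kept, plus the characters the escape scheme needs. A self-contained sketch with toy counts:

```python
# Alphabet selection sketch (toy token counts, illustration only).
from collections import defaultdict

def build_alphabet(token_counts, min_count):
  char_counts = defaultdict(int)
  for token, count in token_counts.items():
    for c in token:
      char_counts[c] += count
  alphabet = {c for c, n in char_counts.items() if n >= min_count}
  alphabet |= set(u"\\_;0123456789")  # chars needed by the escaping scheme
  return alphabet

alphabet = build_alphabet({u"caf\u00e9": 3, u"cafe": 100}, min_count=10)
assert u"\u00e9" not in alphabet  # rare char left out; it will be escaped as "\\233;"
assert u"c" in alphabet and u"_" in alphabet
```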
@@ -367,43 +390,36 @@ def build_from_token_counts(self,
         for end in xrange(start + 1, len(escaped_token) + 1):
           subtoken_string = escaped_token[start:end]
           counts[subtoken_string] += count
+      # Make sure all characters needed for escaping are included
+      for c in self._alphabet:
+        counts[c] += min_count
       # Array of sets of candidate subtoken strings, by length
       len_to_subtoken_strings = []
       for subtoken_string, count in six.iteritems(counts):
         lsub = len(subtoken_string)
-        # All subtoken strings of length 1 are automatically included
-        # later, so we don't need to consider them here
-        if count < min_count or lsub <= 1:
-          continue
-        # Add this subtoken string to its length set
-        while len(len_to_subtoken_strings) <= lsub:
-          len_to_subtoken_strings.append(set())
-        len_to_subtoken_strings[lsub].add(subtoken_string)
+        if count >= min_count:
+          # Add this subtoken string to its length set
+          while len(len_to_subtoken_strings) <= lsub:
+            len_to_subtoken_strings.append(set())
+          len_to_subtoken_strings[lsub].add(subtoken_string)
       new_subtoken_strings = []
       # consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
-      for subtoken_strings in reversed(len_to_subtoken_strings[2:]):
+      for lsub in reversed(range(1, len(len_to_subtoken_strings))):
+        subtoken_strings = len_to_subtoken_strings[lsub]
         for subtoken_string in subtoken_strings:
           count = counts[subtoken_string]
-          if count < min_count:
-            continue
-          new_subtoken_strings.append((count, subtoken_string))
-          for l in xrange(1, len(subtoken_string)):
-            counts[subtoken_string[:l]] -= count
-      # Sort what we've got so far in decreasing order by count
+          if count >= min_count:
+            new_subtoken_strings.append((count, subtoken_string))
+            for l in xrange(1, lsub):
+              counts[subtoken_string[:l]] -= count
+      # Sort in decreasing order by count
       new_subtoken_strings.sort(reverse=True)
-      # Add the alphabet set at the end of the vocabulary list
-      for char in alphabet_set:
-        new_subtoken_strings.append((0, char))
-      # Also include the Unicode REPLACEMENT CHARACTER to use
-      # when encountering previously unseen Unicode characters
-      # in the input (i.e. input external to the tokenizer training
-      # set, which may thus contain characters not in the alphabet_set).
-      # This must be the last entry in the subtoken vocabulary list.
-      new_subtoken_strings.append((0, u"\uFFFD"))
       # Now we have a candidate vocabulary
+      old_alphabet = self._alphabet
       self._init_from_list([u""] * self._num_reserved_ids +
                            [p[1] for p in new_subtoken_strings])
+      assert old_alphabet == self._alphabet
       tf.logging.info("vocab_size = %d" % self.vocab_size)
 
     original = "This sentence was encoded by the SubwordTextEncoder."
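
Each refinement iteration counts all substrings of the escaped tokens (weighted by token count), keeps those that clear min_count, and walks candidates longest-first so that accepting a long subtoken discounts the counts of its prefixes. A compact, self-contained sketch of that inner step on toy data; it omits the segmentation and alphabet bookkeeping of the real loop:

```python
# Sketch of one candidate-counting iteration (toy escaped-token counts).
from collections import defaultdict

def one_iteration(escaped_token_counts, min_count):
  counts = defaultdict(int)
  for token, count in escaped_token_counts.items():
    for start in range(len(token)):
      for end in range(start + 1, len(token) + 1):
        counts[token[start:end]] += count

  by_length = defaultdict(set)
  for s, c in counts.items():
    if c >= min_count:
      by_length[len(s)].add(s)

  new_subtokens = []
  for length in sorted(by_length, reverse=True):   # longest first
    for s in by_length[length]:
      c = counts[s]
      if c >= min_count:
        new_subtokens.append((c, s))
        for l in range(1, length):                 # discount all prefixes
          counts[s[:l]] -= c
  new_subtokens.sort(reverse=True)
  return new_subtokens

print(one_iteration({u"the_": 50, u"them_": 10}, min_count=30))
# [(60, '_'), (50, 'the_'), (50, 'he_'), (50, 'e_')]
```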
@@ -426,46 +442,77 @@ def _init_from_list(self, subtoken_strings):
     self._all_subtoken_strings = subtoken_strings
     self._subtoken_string_to_id = {
         s: i for i, s in enumerate(subtoken_strings) if s}
+    self._alphabet = set([c for c in subtoken_strings if len(c) == 1])
 
   def _load_from_file(self, filename):
     """Load from a file."""
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
-        if six.PY2:
-          subtoken_strings.append(line.strip()[1:-1].decode("utf-8"))
-        else:
-          subtoken_strings.append(line.strip()[1:-1])
+        subtoken_strings.append(_native_to_unicode(line.strip()[1:-1]))
     self._init_from_list(subtoken_strings)
 
   def store_to_file(self, filename):
     with tf.gfile.Open(filename, "w") as f:
       for subtoken_string in self._all_subtoken_strings:
-        if six.PY2:
-          f.write("'" + subtoken_string.encode("utf-8") + "'\n")
-        else:
-          f.write("'" + subtoken_string + "'\n")
+        f.write("'" + _unicode_to_native(subtoken_string) + "'\n")
 
   def _escape_token(self, token):
-    r"""Translate '\'->'\\' and '_'->'\u', then append '_'.
+    r"""Escape away underscores and OOV characters and append '_'.
+
+    This allows the token to be expressed as the concatenation of a list
+    of subtokens from the vocabulary. The underscore acts as a sentinel
+    which allows us to invertibly concatenate multiple such lists.
 
     Args:
-      token: a string
+      token: a unicode string
     Returns:
-      escaped_token: a string
+      escaped_token: a unicode string
     """
-    return token.replace("\\", "\\\\").replace("_", "\\u") + "_"
+    token = token.replace("\\", "\\\\").replace("_", "\\u") + "_"
+    ret = u""
+    for c in token:
+      if c in self._alphabet:
+        ret += c
+      else:
+        ret += u"\\%d;" % ord(c)
+    return ret
 
   def _unescape_token(self, escaped_token):
-    r"""Remove '_' from end, then translate '\\'->'\' and '\u'->'_'.
+    r"""Inverse of _escape_token().
 
     Args:
-      escaped_token: a string
+      escaped_token: a unicode string
     Returns:
-      token: a string
+      token: a unicode string
     """
-    assert escaped_token[-1] == "_"
-    return escaped_token[:-1].replace("\\u", "_").replace("\\\\", "\\")
+    ret = u""
+    escaped_token = escaped_token[:-1]
+    pos = 0
+    while pos < len(escaped_token):
+      c = escaped_token[pos]
+      if c == "\\":
+        pos += 1
+        c = escaped_token[pos]
+        if c == u"u":
+          ret += u"_"
+          pos += 1
+        elif c == "\\":
+          ret += u"_"
+          pos += 1
+        else:
+          semicolon_pos = escaped_token.find(u";", pos)
+          if semicolon_pos == -1:
+            continue
+          try:
+            ret += unichr(int(escaped_token[pos:semicolon_pos]))
+            pos = semicolon_pos + 1
+          except (ValueError, OverflowError) as _:
+            pass
+      else:
+        ret += c
+        pos += 1
+    return ret
 
   @classmethod
   def get_token_counts(cls, text_filepattern, corpus_max_lines):
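
The new escaping scheme maps a backslash to a double backslash, an underscore to "\u", and any character outside the alphabet to a "\&lt;decimal&gt;;" escape, then appends a trailing underscore sentinel. The sketch below is a simplified, self-contained re-implementation of the pair for illustration (Python 3, fixed toy alphabet), not the class methods themselves:

```python
# Round-trip sketch of the escaping scheme; ALPHABET is a toy assumption.
ALPHABET = set(u"abcdefghijklmnopqrstuvwxyz \\_;0123456789")

def escape_token(token):
  token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u") + u"_"
  return u"".join(c if c in ALPHABET else u"\\%d;" % ord(c) for c in token)

def unescape_token(escaped):
  ret, pos, escaped = u"", 0, escaped[:-1]   # drop the trailing "_" sentinel
  while pos < len(escaped):
    if escaped[pos] == u"\\":
      if escaped[pos + 1] == u"u":           # "\u"  -> "_"
        ret += u"_"
        pos += 2
      elif escaped[pos + 1] == u"\\":        # "\\"  -> "\"
        ret += u"\\"
        pos += 2
      else:                                  # "\<ord>;" -> chr(ord)
        semicolon = escaped.index(u";", pos)
        ret += chr(int(escaped[pos + 1:semicolon]))
        pos = semicolon + 1
    else:
      ret += escaped[pos]
      pos += 1
  return ret

token = u"snow_man \u2603"
assert escape_token(token) == u"snow\\uman \\9731;_"
assert unescape_token(escape_token(token)) == token
```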
@@ -477,7 +524,7 @@ def get_token_counts(cls, text_filepattern, corpus_max_lines):
       with tf.gfile.Open(text_filename) as f:
         for line in f:
           # The tokenizer updates token_counts in encode()
-          tok.encode(line.strip())
+          tok.encode(_native_to_unicode(line.strip()))
           lines_read += 1
           if corpus_max_lines > 0 and lines_read > corpus_max_lines:
             return tok.token_counts

tensor2tensor/data_generators/text_encoder_build_subword.py

Lines changed: 1 addition & 2 deletions
@@ -59,8 +59,7 @@ def main(unused_argv):
     raise ValueError('Must provide --corpus_filepattern')
   token_counts = text_encoder.SubwordTextEncoder.get_token_counts(
       FLAGS.corpus_filepattern, FLAGS.corpus_max_lines)
-  alphabet_set = text_encoder.SubwordTextEncoder.alphabet(token_counts)
-  gs.build_from_token_counts(token_counts, alphabet_set,
+  gs.build_from_token_counts(token_counts,
                              FLAGS.min_count,
                              FLAGS.num_iterations)
   gs.store_to_file(FLAGS.output_fn)
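
With the alphabet argument removed, the vocabulary-building flow becomes the two-step call shown in the updated script above. A hedged end-to-end usage sketch; the corpus pattern and output path are placeholders:

```python
# End-to-end usage sketch; file paths below are hypothetical.
from tensor2tensor.data_generators import text_encoder

token_counts = text_encoder.SubwordTextEncoder.get_token_counts(
    "/path/to/corpus-*.txt", corpus_max_lines=10000)
encoder = text_encoder.SubwordTextEncoder()
encoder.build_from_token_counts(token_counts, min_count=5, num_iterations=4)
encoder.store_to_file("/tmp/my_vocab.subwords")
```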
