
Commit d578f52

Author: Ryan Sepassi
Parent: b53d6df

internal merge

PiperOrigin-RevId: 160018490

File tree

4 files changed: +66 additions, -47 deletions

tensor2tensor/data_generators/generator_utils.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
 
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
+import six.moves.urllib_request
 
 from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder
 from tensor2tensor.data_generators.tokenizer import Tokenizer
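The added import resolves to urllib2 on Python 2 and to urllib.request on Python 3. As an illustration only (this is an editor's sketch, not the repository's download helper; the URL and local path below are made-up placeholders), a cross-version download call looks like this:

import six.moves.urllib_request  # urllib2 on Python 2, urllib.request on Python 3

# Placeholder values for illustration; not taken from the commit.
url = "http://example.com/data/train.txt"
local_path = "/tmp/train.txt"
local_path, _ = six.moves.urllib_request.urlretrieve(url, local_path)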

tensor2tensor/data_generators/image.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import cPickle
 import gzip
 import io
 import json
@@ -30,6 +29,7 @@
 # Dependency imports
 
 import numpy as np
+from six.moves import cPickle
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from six.moves import zip  # pylint: disable=redefined-builtin
 from tensor2tensor.data_generators import generator_utils
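For context, an editor's sketch (not part of the commit) of why the six.moves alias works on both interpreters: it points at the C-accelerated cPickle module on Python 2 and at the plain pickle module on Python 3, so the same round trip runs unchanged:

import io

from six.moves import cPickle  # cPickle on Python 2, pickle on Python 3

buf = io.BytesIO()
cPickle.dump({"label": 3}, buf)
buf.seek(0)
assert cPickle.load(buf) == {"label": 3}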

tensor2tensor/data_generators/text_encoder.py

Lines changed: 56 additions & 31 deletions
@@ -23,6 +23,8 @@
 from __future__ import division
 from __future__ import print_function
 
+from collections import defaultdict
+
 # Dependency imports
 
 import six
@@ -35,6 +37,10 @@
 PAD = '<pad>'
 EOS = '<EOS>'
 RESERVED_TOKENS = [PAD, EOS]
+if six.PY2:
+  RESERVED_TOKENS_BYTES = RESERVED_TOKENS
+else:
+  RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]
 
 
 class TextEncoder(object):
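A quick illustration (editor's note, Python 3 only, not from the commit) of what the byte-string variants look like and why they can be joined with other byte output before decoding:

PAD, EOS = '<pad>', '<EOS>'
reserved_tokens_bytes = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]
assert reserved_tokens_bytes == [b'<pad>', b'<EOS>']
assert b''.join(reserved_tokens_bytes).decode('utf-8') == '<pad><EOS>'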
@@ -87,17 +93,25 @@ class ByteTextEncoder(TextEncoder):
   """Encodes each byte to an id. For 8-bit strings only."""
 
   def encode(self, s):
-    return [ord(c) + self._num_reserved_ids for c in s]
+    numres = self._num_reserved_ids
+    if six.PY2:
+      return [ord(c) + numres for c in s]
+    # Python3: explicitly convert to UTF-8
+    return [c + numres for c in s.encode('utf-8')]
 
   def decode(self, ids):
+    numres = self._num_reserved_ids
     decoded_ids = []
+    int2byte = six.int2byte
     for id_ in ids:
-      if 0 <= id_ < self._num_reserved_ids:
-        decoded_ids.append(RESERVED_TOKENS[int(id_)])
+      if 0 <= id_ < numres:
+        decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
       else:
-        decoded_ids.append(chr(id_))
-
-    return ''.join(decoded_ids)
+        decoded_ids.append(int2byte(id_ - numres))
+    if six.PY2:
+      return ''.join(decoded_ids)
+    # Python3: join byte arrays and then decode string
+    return b''.join(decoded_ids).decode('utf-8')
 
   @property
   def vocab_size(self):
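To see the new encode/decode branches in isolation, here is a self-contained sketch that mirrors the diff's logic; the helper names byte_encode/byte_decode and the fixed offset of two reserved ids are this editor's assumptions, not the class API:

import six

PAD, EOS = '<pad>', '<EOS>'
RESERVED_TOKENS = [PAD, EOS]
RESERVED_TOKENS_BYTES = (RESERVED_TOKENS if six.PY2
                         else [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')])
NUM_RESERVED = 2  # assumed value of _num_reserved_ids

def byte_encode(s):
  # Mirrors ByteTextEncoder.encode: shift each byte past the reserved ids.
  if six.PY2:
    return [ord(c) + NUM_RESERVED for c in s]
  return [c + NUM_RESERVED for c in s.encode('utf-8')]

def byte_decode(ids):
  # Mirrors ByteTextEncoder.decode: reserved ids map to tokens, others to bytes.
  out = []
  for id_ in ids:
    if 0 <= id_ < NUM_RESERVED:
      out.append(RESERVED_TOKENS_BYTES[int(id_)])
    else:
      out.append(six.int2byte(id_ - NUM_RESERVED))
  if six.PY2:
    return ''.join(out)
  return b''.join(out).decode('utf-8')

assert byte_decode(byte_encode('abc')) == 'abc'  # round trips on both Python versions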
@@ -111,20 +125,16 @@ def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2):
     """Initialize from a file, one token per line."""
     super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
     self._reverse = reverse
-    if vocab_filename is not None:
-      self._load_vocab_from_file(vocab_filename)
+    self._load_vocab_from_file(vocab_filename)
 
   def encode(self, sentence):
     """Converts a space-separated string of tokens to a list of ids."""
     ret = [self._token_to_id[tok] for tok in sentence.strip().split()]
-    if self._reverse:
-      ret = ret[::-1]
-    return ret
+    return ret[::-1] if self._reverse else ret
 
   def decode(self, ids):
-    if self._reverse:
-      ids = ids[::-1]
-    return ' '.join([self._safe_id_to_token(i) for i in ids])
+    seq = reversed(ids) if self._reverse else ids
+    return ' '.join([self._safe_id_to_token(i) for i in seq])
 
   @property
   def vocab_size(self):
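A small aside (editor's note): reversed() yields a lazy iterator rather than a reversed copy, which is sufficient here because decode consumes it immediately inside the join:

ids = [3, 7, 5]
assert list(reversed(ids)) == ids[::-1] == [5, 7, 3]  # same order, but reversed() builds no copy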
@@ -243,15 +253,22 @@ def _escaped_token_to_subtokens(self, escaped_token):
     """
     ret = []
     pos = 0
-    while pos < len(escaped_token):
-      end = len(escaped_token)
-      while True:
+    lesc = len(escaped_token)
+    while pos < lesc:
+      end = lesc
+      while end > pos:
         subtoken = self._subtoken_string_to_id.get(escaped_token[pos:end], -1)
         if subtoken != -1:
           break
         end -= 1
       ret.append(subtoken)
-      pos = end
+      if end > pos:
+        pos = end
+      else:
+        # This kinda should not happen, but it does. Cop out by skipping the
+        # nonexistent subtoken from the returned list.
+        # print("Unable to find subtoken in string '{0}'".format(escaped_token))
+        pos += 1
     return ret
 
   @classmethod
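To make the new loop's behavior concrete, here is an editor's sketch of the same greedy longest-match segmentation over a plain dict (the function name and toy vocabulary are illustrative, not repository code). Note that a miss still appends the -1 sentinel before the position is advanced by one character:

def greedy_subtokens(subtoken_string_to_id, escaped_token):
  """Greedy longest-prefix matching, mirroring the patched loop."""
  ret = []
  pos = 0
  lesc = len(escaped_token)
  while pos < lesc:
    end = lesc
    while end > pos:
      subtoken = subtoken_string_to_id.get(escaped_token[pos:end], -1)
      if subtoken != -1:
        break
      end -= 1
    ret.append(subtoken)
    if end > pos:
      pos = end   # consume the matched subtoken
    else:
      pos += 1    # nothing matched: skip one character instead of looping forever
  return ret

vocab = {'ab': 0, 'c_': 1, 'c': 2}
assert greedy_subtokens(vocab, 'abc_') == [0, 1]
assert greedy_subtokens(vocab, 'abxc_') == [0, -1, 1]  # the unknown 'x' is skipped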
@@ -322,13 +339,13 @@ def build_from_token_counts(self,
     # then count the resulting potential subtokens, keeping the ones
     # with high enough counts for our new vocabulary.
     for i in xrange(num_iterations):
-      counts = {}
+      counts = defaultdict(int)
       for token, count in six.iteritems(token_counts):
         escaped_token = self._escape_token(token)
         # we will count all tails of the escaped_token, starting from boundaries
         # determined by our current segmentation.
         if i == 0:
-          starts = list(range(len(escaped_token)))
+          starts = xrange(len(escaped_token))
         else:
           subtokens = self._escaped_token_to_subtokens(escaped_token)
           pos = 0
@@ -337,31 +354,33 @@
             starts.append(pos)
             pos += len(self.subtoken_to_subtoken_string(subtoken))
         for start in starts:
-          for end in xrange(start + 1, len(escaped_token) + 1):
+          for end in xrange(start + 1, len(escaped_token)):
             subtoken_string = escaped_token[start:end]
-            counts[subtoken_string] = counts.get(subtoken_string, 0) + count
+            counts[subtoken_string] += count
       # array of lists of candidate subtoken strings, by length
       len_to_subtoken_strings = []
       for subtoken_string, count in six.iteritems(counts):
-        if count < min_count or len(subtoken_string) <= 1:
+        lsub = len(subtoken_string)
+        # all subtoken strings of length 1 are included regardless of count
+        if count < min_count and lsub != 1:
           continue
-        while len(len_to_subtoken_strings) <= len(subtoken_string):
+        while len(len_to_subtoken_strings) <= lsub:
           len_to_subtoken_strings.append([])
-        len_to_subtoken_strings[len(subtoken_string)].append(subtoken_string)
+        len_to_subtoken_strings[lsub].append(subtoken_string)
       new_subtoken_strings = []
       # consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
       for subtoken_strings in len_to_subtoken_strings[::-1]:
         for subtoken_string in subtoken_strings:
           count = counts[subtoken_string]
-          if count < min_count:
+          if count < min_count and len(subtoken_string) != 1:
+            # subtoken strings of length 1 are included regardless of count
             continue
           new_subtoken_strings.append((-count, subtoken_string))
           for l in xrange(1, len(subtoken_string)):
             counts[subtoken_string[:l]] -= count
-      # make sure we have all single characters.
-      new_subtoken_strings.extend([(-counts.get(chr(i), 0), chr(i))
-                                   for i in xrange(2**8)])
+      # Make sure to include the underscore as a subtoken string
+      new_subtoken_strings.append((0, '_'))
       new_subtoken_strings.sort()
       self._init_from_list([''] * self._num_reserved_ids +
                            [p[1] for p in new_subtoken_strings])
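For readers less used to the collections module, an editor's sketch of what the defaultdict(int) swap buys: missing keys start at zero, so the counts.get(key, 0) pattern disappears:

from collections import defaultdict

counts = defaultdict(int)            # missing keys default to 0
for substring, weight in [('ab', 3), ('a', 1), ('ab', 2)]:
  counts[substring] += weight        # no counts.get(substring, 0) + weight needed
assert dict(counts) == {'ab': 5, 'a': 1}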
@@ -390,13 +409,19 @@ def _load_from_file(self, filename):
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
-        subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
+        if six.PY2:
+          subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
+        else:
+          subtoken_strings.append(line.strip()[1:-1])
     self._init_from_list(subtoken_strings)
 
   def _store_to_file(self, filename):
     with tf.gfile.Open(filename, 'w') as f:
       for subtoken_string in self._all_subtoken_strings:
-        f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
+        if six.PY2:
+          f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
+        else:
+          f.write('\'' + subtoken_string + '\'\n')
 
   def _escape_token(self, token):
     r"""Translate '\'->'\\' and '_'->'\u', then append '_'.

tensor2tensor/data_generators/tokenizer.py

Lines changed: 8 additions & 15 deletions
@@ -45,7 +45,7 @@
 from __future__ import division
 from __future__ import print_function
 
-import array
+from collections import defaultdict
 import string
 
 # Dependency imports
@@ -57,17 +57,10 @@ class Tokenizer(object):
   """Vocab for breaking words into wordpieces.
   """
 
-  def __init__(self):
-    self._separator_chars = string.punctuation + string.whitespace
-    self._separator_char_mask = array.array(
-        "l", [chr(i) in self._separator_chars for i in xrange(256)])
-    self.token_counts = dict()
+  _SEPARATOR_CHAR_SET = set(string.punctuation + string.whitespace)
 
-  def _increment_token_count(self, token):
-    if token in self.token_counts:
-      self.token_counts[token] += 1
-    else:
-      self.token_counts[token] = 1
+  def __init__(self):
+    self.token_counts = defaultdict(int)
 
   def encode(self, raw_text):
     """Encode a raw string as a list of tokens.
@@ -87,11 +80,11 @@ def encode(self, raw_text):
         token = raw_text[token_start:pos]
         if token != " " or token_start == 0:
           ret.append(token)
-          self._increment_token_count(token)
+          self.token_counts[token] += 1
         token_start = pos
     final_token = raw_text[token_start:]
     ret.append(final_token)
-    self._increment_token_count(final_token)
+    self.token_counts[final_token] += 1
     return ret
 
   def decode(self, tokens):
def decode(self, tokens):
@@ -111,7 +104,7 @@ def decode(self, tokens):
111104
return ret
112105

113106
def _is_separator_char(self, c):
114-
return self._separator_char_mask[ord(c)]
107+
return c in self._SEPARATOR_CHAR_SET
115108

116109
def _is_word_char(self, c):
117-
return not self._is_separator_char(c)
110+
return c not in self._SEPARATOR_CHAR_SET
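Finally, an editor's sketch pulling the two tokenizer changes together in a standalone class (MiniTokenizer is an illustrative name, not the repository's Tokenizer): a class-level character set replaces the array-based mask, and defaultdict(int) replaces the hand-rolled counter:

from collections import defaultdict
import string

class MiniTokenizer(object):
  """Splits text at word/separator boundaries and counts tokens (illustrative)."""

  _SEPARATOR_CHAR_SET = set(string.punctuation + string.whitespace)

  def __init__(self):
    self.token_counts = defaultdict(int)

  def _is_word_char(self, c):
    return c not in self._SEPARATOR_CHAR_SET

  def encode(self, raw_text):
    ret = []
    token_start = 0
    for pos in range(1, len(raw_text)):
      if self._is_word_char(raw_text[pos]) != self._is_word_char(raw_text[pos - 1]):
        token = raw_text[token_start:pos]
        if token != " " or token_start == 0:
          ret.append(token)
          self.token_counts[token] += 1
        token_start = pos
    final_token = raw_text[token_start:]
    ret.append(final_token)
    self.token_counts[final_token] += 1
    return ret

tok = MiniTokenizer()
assert tok.encode("hello, world") == ["hello", ", ", "world"]
assert tok.token_counts["world"] == 1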
