This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 3410bea

vthorsteinsson authored and rsepassi committed
Python3 compatibility; better Unicode support (#22)
* Python3 compatibility fixes
* Removed print statements
* Fixes suggested by @rsepassi
* Made Python3 ByteTextEncoder compatible with Python2
* Python3 compatibility fixes
1 parent 204b359 commit 3410bea

File tree

4 files changed: +66 -50 lines changed

tensor2tensor/data_generators/generator_utils.py

Lines changed: 1 addition & 1 deletion
@@ -22,12 +22,12 @@
 import io
 import os
 import tarfile
-import urllib

 # Dependency imports

 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
+import six.moves.urllib_request

 from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder
 from tensor2tensor.data_generators.tokenizer import Tokenizer
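
For context: Python 3 split urllib into urllib.request, urllib.parse and friends, so the removed plain "import urllib" no longer provides functions such as urlretrieve there. six.moves.urllib_request resolves to the right module on either interpreter. A minimal standalone sketch (the helper name and its arguments are illustrative, not taken from this patch):

import six.moves.urllib_request

def download(url, target_path):
  # Resolves to urllib.urlretrieve on Python 2 and to
  # urllib.request.urlretrieve on Python 3.
  local_path, _ = six.moves.urllib_request.urlretrieve(url, target_path)
  return local_path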

tensor2tensor/data_generators/image.py

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function

-import cPickle
 import gzip
 import io
 import json
@@ -32,6 +31,8 @@
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from six.moves import zip  # pylint: disable=redefined-builtin
+from six.moves import cPickle
+
 from tensor2tensor.data_generators import generator_utils

 import tensorflow as tf
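
For context: cPickle is the Python 2-only C implementation of pickling; in Python 3 it was folded into the standard pickle module, so six.moves.cPickle picks the appropriate one. A small standalone sketch (the helper and file handling are illustrative, not from this commit):

from six.moves import cPickle

def save_and_load(obj, path):
  # Write with protocol 2, which both Python 2 and Python 3 can read.
  with open(path, "wb") as f:
    cPickle.dump(obj, f, protocol=2)
  with open(path, "rb") as f:
    return cPickle.load(f)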

tensor2tensor/data_generators/text_encoder.py

File mode changed from 100644 to 100755
Lines changed: 55 additions & 32 deletions

@@ -27,6 +27,7 @@

 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
+from collections import defaultdict
 from tensor2tensor.data_generators import tokenizer

 import tensorflow as tf
@@ -35,7 +36,10 @@
 PAD = '<pad>'
 EOS = '<EOS>'
 RESERVED_TOKENS = [PAD, EOS]
-
+if six.PY2:
+  RESERVED_TOKENS_BYTES = RESERVED_TOKENS
+else:
+  RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]

 class TextEncoder(object):
   """Base class for converting from ints to/from human readable strings."""
@@ -87,17 +91,25 @@ class ByteTextEncoder(TextEncoder):
   """Encodes each byte to an id. For 8-bit strings only."""

   def encode(self, s):
-    return [ord(c) + self._num_reserved_ids for c in s]
+    numres = self._num_reserved_ids
+    if six.PY2:
+      return [ord(c) + numres for c in s]
+    # Python3: explicitly convert to UTF-8
+    return [c + numres for c in s.encode("utf-8")]

   def decode(self, ids):
+    numres = self._num_reserved_ids
     decoded_ids = []
+    int2byte = six.int2byte
     for id_ in ids:
-      if 0 <= id_ < self._num_reserved_ids:
-        decoded_ids.append(RESERVED_TOKENS[int(id_)])
+      if 0 <= id_ < numres:
+        decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
       else:
-        decoded_ids.append(chr(id_))
-
-    return ''.join(decoded_ids)
+        decoded_ids.append(int2byte(id_ - numres))
+    if six.PY2:
+      return ''.join(decoded_ids)
+    # Python3: join byte arrays and then decode string
+    return b''.join(decoded_ids).decode("utf-8")

   @property
   def vocab_size(self):
@@ -111,20 +123,16 @@ def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2):
     """Initialize from a file, one token per line."""
     super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
     self._reverse = reverse
-    if vocab_filename is not None:
-      self._load_vocab_from_file(vocab_filename)
+    self._load_vocab_from_file(vocab_filename)

   def encode(self, sentence):
     """Converts a space-separated string of tokens to a list of ids."""
     ret = [self._token_to_id[tok] for tok in sentence.strip().split()]
-    if self._reverse:
-      ret = ret[::-1]
-    return ret
+    return ret[::-1] if self._reverse else ret

   def decode(self, ids):
-    if self._reverse:
-      ids = ids[::-1]
-    return ' '.join([self._safe_id_to_token(i) for i in ids])
+    seq = reversed(ids) if self._reverse else ids
+    return ' '.join([self._safe_id_to_token(i) for i in seq])

   @property
   def vocab_size(self):
@@ -243,15 +251,22 @@ def _escaped_token_to_subtokens(self, escaped_token):
     """
     ret = []
     pos = 0
-    while pos < len(escaped_token):
-      end = len(escaped_token)
-      while True:
+    lesc = len(escaped_token)
+    while pos < lesc:
+      end = lesc
+      while end > pos:
         subtoken = self._subtoken_string_to_id.get(escaped_token[pos:end], -1)
         if subtoken != -1:
           break
         end -= 1
       ret.append(subtoken)
-      pos = end
+      if end > pos:
+        pos = end
+      else:
+        # This kinda should not happen, but it does. Cop out by skipping the
+        # nonexistent subtoken from the returned list.
+        # print("Unable to find subtoken in string '{0}'".format(escaped_token))
+        pos += 1
     return ret

   @classmethod
@@ -322,13 +337,13 @@ def build_from_token_counts(self,
     # then count the resulting potential subtokens, keeping the ones
     # with high enough counts for our new vocabulary.
     for i in xrange(num_iterations):
-      counts = {}
+      counts = defaultdict(int)
       for token, count in six.iteritems(token_counts):
         escaped_token = self._escape_token(token)
         # we will count all tails of the escaped_token, starting from boundaries
         # determined by our current segmentation.
         if i == 0:
-          starts = list(range(len(escaped_token)))
+          starts = xrange(len(escaped_token))
         else:
           subtokens = self._escaped_token_to_subtokens(escaped_token)
           pos = 0
@@ -337,31 +352,33 @@
             starts.append(pos)
             pos += len(self.subtoken_to_subtoken_string(subtoken))
         for start in starts:
-          for end in xrange(start + 1, len(escaped_token) + 1):
+          for end in xrange(start + 1, len(escaped_token)):
             subtoken_string = escaped_token[start:end]
-            counts[subtoken_string] = counts.get(subtoken_string, 0) + count
+            counts[subtoken_string] += count
       # array of lists of candidate subtoken strings, by length
       len_to_subtoken_strings = []
       for subtoken_string, count in six.iteritems(counts):
-        if count < min_count or len(subtoken_string) <= 1:
+        lsub = len(subtoken_string)
+        # all subtoken strings of length 1 are included regardless of count
+        if count < min_count and lsub != 1:
           continue
-        while len(len_to_subtoken_strings) <= len(subtoken_string):
+        while len(len_to_subtoken_strings) <= lsub:
           len_to_subtoken_strings.append([])
-        len_to_subtoken_strings[len(subtoken_string)].append(subtoken_string)
+        len_to_subtoken_strings[lsub].append(subtoken_string)
       new_subtoken_strings = []
       # consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
       for subtoken_strings in len_to_subtoken_strings[::-1]:
         for subtoken_string in subtoken_strings:
           count = counts[subtoken_string]
-          if count < min_count:
+          if count < min_count and len(subtoken_string) != 1:
+            # subtoken strings of length 1 are included regardless of count
             continue
           new_subtoken_strings.append((-count, subtoken_string))
           for l in xrange(1, len(subtoken_string)):
             counts[subtoken_string[:l]] -= count
-      # make sure we have all single characters.
-      new_subtoken_strings.extend([(-counts.get(chr(i), 0), chr(i))
-                                   for i in xrange(2**8)])
+      # Make sure to include the underscore as a subtoken string
+      new_subtoken_strings.append((0, '_'))
       new_subtoken_strings.sort()
       self._init_from_list([''] * self._num_reserved_ids +
                            [p[1] for p in new_subtoken_strings])
@@ -390,13 +407,19 @@ def _load_from_file(self, filename):
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
-        subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
+        if six.PY2:
+          subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
+        else:
+          subtoken_strings.append(line.strip()[1:-1])
     self._init_from_list(subtoken_strings)

   def _store_to_file(self, filename):
     with tf.gfile.Open(filename, 'w') as f:
       for subtoken_string in self._all_subtoken_strings:
-        f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
+        if six.PY2:
+          f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
+        else:
+          f.write('\'' + subtoken_string + '\'\n')

   def _escape_token(self, token):
     r"""Translate '\'->'\\' and '_'->'\u', then append '_'.

tensor2tensor/data_generators/tokenizer.py

File mode changed from 100644 to 100755
Lines changed: 8 additions & 16 deletions

@@ -45,29 +45,21 @@
 from __future__ import division
 from __future__ import print_function

-import array
 import string

 # Dependency imports

 from six.moves import xrange  # pylint: disable=redefined-builtin
-
+from collections import defaultdict

 class Tokenizer(object):
   """Vocab for breaking words into wordpieces.
   """

-  def __init__(self):
-    self._separator_chars = string.punctuation + string.whitespace
-    self._separator_char_mask = array.array(
-        "l", [chr(i) in self._separator_chars for i in xrange(256)])
-    self.token_counts = dict()
+  _SEPARATOR_CHAR_SET = set(string.punctuation + string.whitespace)

-  def _increment_token_count(self, token):
-    if token in self.token_counts:
-      self.token_counts[token] += 1
-    else:
-      self.token_counts[token] = 1
+  def __init__(self):
+    self.token_counts = defaultdict(int)

   def encode(self, raw_text):
     """Encode a raw string as a list of tokens.
@@ -87,11 +79,11 @@ def encode(self, raw_text):
         token = raw_text[token_start:pos]
         if token != " " or token_start == 0:
           ret.append(token)
-          self._increment_token_count(token)
+          self.token_counts[token] += 1
         token_start = pos
     final_token = raw_text[token_start:]
     ret.append(final_token)
-    self._increment_token_count(final_token)
+    self.token_counts[final_token] += 1
     return ret

   def decode(self, tokens):
@@ -111,7 +103,7 @@ def decode(self, tokens):
     return ret

   def _is_separator_char(self, c):
-    return self._separator_char_mask[ord(c)]
+    return c in self._SEPARATOR_CHAR_SET

   def _is_word_char(self, c):
-    return not self._is_separator_char(c)
+    return c not in self._SEPARATOR_CHAR_SET
