
Commit 029047a

Merge pull request #104 from lukaszkaiser/push

Release 1.0.11

2 parents (1692515 + b88c13b), commit 029047a

11 files changed: +403 additions, -316 deletions


tensor2tensor/data_generators/generator_utils.py

Lines changed: 10 additions & 7 deletions
@@ -29,7 +29,7 @@
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import six.moves.urllib_request as urllib  # Imports urllib on Python2, urllib.request on Python3

-from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder
+from tensor2tensor.data_generators import text_encoder
 from tensor2tensor.data_generators.tokenizer import Tokenizer

 import tensorflow as tf
@@ -218,15 +218,18 @@ def gunzip_file(gz_path, new_path):
 ]


-def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size):
-  """Generate a vocabulary from the datasets listed in _DATA_FILE_URLS."""
+def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
+  """Generate a vocabulary from the datasets in sources (_DATA_FILE_URLS)."""
   vocab_filepath = os.path.join(tmp_dir, vocab_filename)
   if os.path.exists(vocab_filepath):
-    vocab = SubwordTextEncoder(vocab_filepath)
+    tf.logging.info("Found vocab file: %s", vocab_filepath)
+    vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
     return vocab

+  sources = sources or _DATA_FILE_URLS
+  tf.logging.info("Generating vocab from: %s", str(sources))
   tokenizer = Tokenizer()
-  for source in _DATA_FILE_URLS:
+  for source in sources:
     url = source[0]
     filename = os.path.basename(url)
     read_type = "r:gz" if "tgz" in filename else "r"
@@ -259,9 +262,9 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size):
           break
         line = line.strip()
         file_byte_budget -= len(line)
-        _ = tokenizer.encode(line)
+        _ = tokenizer.encode(text_encoder.native_to_unicode(line))

-  vocab = SubwordTextEncoder.build_to_target_size(
+  vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
       vocab_size, tokenizer.token_counts, 1, 1e3)
   vocab.store_to_file(vocab_filepath)
   return vocab
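
The new sources argument lets a caller build the vocabulary from its own download list instead of the module-level _DATA_FILE_URLS. A minimal sketch of how that might be used; the two-element (url, filenames) entry layout, the placeholder URLs, the temp directory, and the vocab size are assumptions for illustration, since the diff only shows that source[0] is read as the URL:

    from tensor2tensor.data_generators import generator_utils

    # Hypothetical corpus list mirroring the _DATA_FILE_URLS layout;
    # entry[0] is the download URL, which is the only field the loop
    # above reads directly. The URLs are placeholders, not real datasets.
    my_sources = [
        ("http://example.com/corpus_a.tgz", ["corpus_a.en", "corpus_a.de"]),
        ("http://example.com/corpus_b.txt", ["corpus_b.txt"]),
    ]

    # Builds (or reloads) a SubwordTextEncoder vocab of roughly 32k
    # subtokens in tmp_dir, reading only the given sources.
    vocab = generator_utils.get_or_generate_vocab(
        "/tmp/t2t_datagen", "my_vocab.subwords", 2**15, sources=my_sources)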

tensor2tensor/data_generators/problem_hparams.py

Lines changed: 11 additions & 1 deletion
@@ -664,8 +664,17 @@ def image_mscoco_tokens(model_hparams, vocab_count):
   }
   p.batch_size_multiplier = 256
   p.max_expected_batch_size_per_shard = 2
+
+
+def img2img_imagenet(unused_model_hparams):
+  """Image 2 Image for imagenet dataset."""
+  p = default_problem_hparams()
+  p.input_modality = {"inputs": ("image:identity", None)}
+  p.target_modality = ("image:identity", None)
+  p.batch_size_multiplier = 256
+  p.max_expected_batch_size_per_shard = 4
   p.input_space_id = 1
-  p.target_space_id = 3
+  p.target_space_id = 1
   return p


@@ -732,4 +741,5 @@ def image_mscoco_tokens(model_hparams, vocab_count):
     "image_mscoco_tokens_128k_tune": lambda p: image_mscoco_tokens(p, 2**17),
     "image_mscoco_tokens_128k_test": lambda p: image_mscoco_tokens(p, 2**17),
     "image_imagenet": image_imagenet,
+    "img2img_imagenet": img2img_imagenet,
 }
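
The new img2img_imagenet entry wires an image-to-image ImageNet problem into the hparams table above. A rough sketch of what the registered callable produces; the direct call with None and the assertions are illustrative, not part of the commit, but the attribute values match the diff:

    from tensor2tensor.data_generators import problem_hparams

    # The table maps a problem name to a function of the model hparams;
    # img2img_imagenet ignores its argument (unused_model_hparams).
    p = problem_hparams.img2img_imagenet(None)

    # Both sides of the problem are raw images through an identity
    # modality, and input/target share space id 1.
    assert p.input_modality == {"inputs": ("image:identity", None)}
    assert p.target_modality == ("image:identity", None)
    assert p.input_space_id == p.target_space_id == 1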

tensor2tensor/data_generators/text_encoder.py

Lines changed: 10 additions & 7 deletions
@@ -36,10 +36,10 @@


 # Conversion between Unicode and UTF-8, if required (on Python2)
-_native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)
+native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)


-_unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)
+unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)


 # Reserved tokens for things like padding and EOS symbols.
@@ -220,7 +220,7 @@ def encode(self, raw_text):
       a list of integers in the range [0, vocab_size)
     """
     return self._tokens_to_subtokens(self._tokenizer.encode(
-        _native_to_unicode(raw_text)))
+        native_to_unicode(raw_text)))

   def decode(self, subtokens):
     """Converts a sequence of subtoken ids to a native string.
@@ -230,7 +230,7 @@ def decode(self, subtokens):
     Returns:
       a native string
     """
-    return _unicode_to_native(self._tokenizer.decode(
+    return unicode_to_native(self._tokenizer.decode(
         self._subtokens_to_tokens(subtokens)))

   @property
@@ -335,6 +335,9 @@ def bisect(min_val, max_val):
       else:
         other_subtokenizer = bisect(min_val, present_count - 1)

+      if other_subtokenizer is None:
+        return subtokenizer
+
       if (abs(other_subtokenizer.vocab_size - target_size) <
           abs(subtokenizer.vocab_size - target_size)):
         return other_subtokenizer
@@ -449,13 +452,13 @@ def _load_from_file(self, filename):
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
-        subtoken_strings.append(_native_to_unicode(line.strip()[1:-1]))
+        subtoken_strings.append(native_to_unicode(line.strip()[1:-1]))
     self._init_from_list(subtoken_strings)

   def store_to_file(self, filename):
     with tf.gfile.Open(filename, "w") as f:
       for subtoken_string in self._all_subtoken_strings:
-        f.write("'" + _unicode_to_native(subtoken_string) + "'\n")
+        f.write("'" + unicode_to_native(subtoken_string) + "'\n")

   def _escape_token(self, token):
     r"""Escape away underscores and OOV characters and append '_'.
@@ -524,7 +527,7 @@ def get_token_counts(cls, text_filepattern, corpus_max_lines):
       with tf.gfile.Open(text_filename) as f:
         for line in f:
           # The tokenizer updates token_counts in encode()
-          tok.encode(_native_to_unicode(line.strip()))
+          tok.encode(native_to_unicode(line.strip()))
           lines_read += 1
           if corpus_max_lines > 0 and lines_read > corpus_max_lines:
             return tok.token_counts
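
Dropping the leading underscores makes native_to_unicode and unicode_to_native part of text_encoder's public surface, which is what lets generator_utils call them above. A small sketch of the round trip they provide; the sample string is illustrative, and the Python 2 vs. 3 behavior follows directly from the lambdas in the diff:

    # -*- coding: utf-8 -*-
    from tensor2tensor.data_generators import text_encoder

    native = "déjà vu"  # str: UTF-8 bytes on Python 2, text on Python 3

    # On Python 2 this decodes UTF-8 bytes to unicode; on Python 3 it is a no-op.
    as_unicode = text_encoder.native_to_unicode(native)

    # unicode_to_native inverts the conversion, so the round trip is lossless.
    assert text_encoder.unicode_to_native(as_unicode) == native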
