This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 0fad290

nshazeer authored and Ryan Sepassi committed
updated image transformer. now combines channels to have only 1024 positions for rev-cifar instead of 3072.

PiperOrigin-RevId: 159754350
1 parent a8463f5 commit 0fad290

File tree

tensor2tensor/models/common_attention.py
tensor2tensor/models/modalities.py

2 files changed: +27 -17 lines

2 files changed

+27
-17
lines changed

tensor2tensor/models/common_attention.py

Lines changed: 24 additions & 14 deletions
@@ -271,10 +271,14 @@ def attention_image_summary(attn, image_shapes=None):
 
   Args:
     attn: a Tensor with shape [batch, num_heads, query_length, memory_length]
-    image_shapes: optional quadruple of integer scalars.
+    image_shapes: optional tuple of integer scalars.
       If the query positions and memory positions represent the
-      pixels of a flattened image, then pass in their dimensions:
+      pixels of flattened images, then pass in their dimensions:
         (query_rows, query_cols, memory_rows, memory_cols).
+      If the query positions and memory positions represent the
+      pixels x channels of flattened images, then pass in their dimensions:
+        (query_rows, query_cols, query_channels,
+         memory_rows, memory_cols, memory_channels).
   """
   num_heads = attn.get_shape().as_list()[1]
   # [batch, query_length, memory_length, num_heads]
@@ -286,10 +290,20 @@ def attention_image_summary(attn, image_shapes=None):
   image = split_last_dimension(image, 3)
   image = tf.reduce_max(image, 4)
   if image_shapes is not None:
-    q_rows, q_cols, m_rows, m_cols = list(image_shapes)
-    image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3])
-    image = tf.transpose(image, [0, 1, 3, 2, 4, 5])
-    image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3])
+    if len(image_shapes) == 4:
+      q_rows, q_cols, m_rows, m_cols = list(image_shapes)
+      image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3])
+      image = tf.transpose(image, [0, 1, 3, 2, 4, 5])
+      image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3])
+    else:
+      assert len(image_shapes) == 6
+      q_rows, q_cols, q_channels, m_rows, m_cols, m_channels = list(
+          image_shapes)
+      image = tf.reshape(image, [-1, q_rows, q_cols, q_channels,
+                                 m_rows, m_cols, m_channels, 3])
+      image = tf.transpose(image, [0, 1, 4, 3, 2, 5, 6, 7])
+      image = tf.reshape(image, [-1, q_rows * m_rows * q_channels,
+                                 q_cols * m_cols * m_channels, 3])
   tf.summary.image("attention", image, max_outputs=1)
 
 
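As a sanity check on the new six-value branch, a minimal NumPy sketch with assumed toy dimensions (2x2 images, 3 channels); the random array stands in for the [batch, query_length, memory_length, 3] image the function has built by this point:

import numpy as np

# Assumed toy dimensions, chosen only so the products stay small.
q_rows, q_cols, q_channels = 2, 2, 3
m_rows, m_cols, m_channels = 2, 2, 3
q_len = q_rows * q_cols * q_channels  # 12
m_len = m_rows * m_cols * m_channels  # 12

# Stand-in for the attention image: [batch, query_length, memory_length, 3].
image = np.random.rand(1, q_len, m_len, 3)

# Unflatten both position axes into (rows, cols, channels) factors.
image = image.reshape(-1, q_rows, q_cols, q_channels,
                      m_rows, m_cols, m_channels, 3)
# The [0, 1, 4, 3, 2, 5, 6, 7] permutation from the diff interleaves query
# and memory axes, so each query position's attention over memory renders
# as a contiguous tile in the final summary image.
image = image.transpose(0, 1, 4, 3, 2, 5, 6, 7)
image = image.reshape(-1, q_rows * m_rows * q_channels,
                      q_cols * m_cols * m_channels, 3)
print(image.shape)  # (1, 12, 12, 3)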

@@ -310,10 +324,8 @@ def dot_product_attention(q,
     bias: bias Tensor (see attention_bias())
     dropout_rate: a floating point number
     summaries: a boolean
-    image_shapes: optional quadruple of integer scalars for image summary.
-      If the query positions and memory positions represent the
-      pixels of a flattened image, then pass in their dimensions:
-        (query_rows, query_cols, memory_rows, memory_cols).
+    image_shapes: optional tuple of integer scalars.
+      see comments for attention_image_summary()
     name: an optional string
 
   Returns:
@@ -356,10 +368,8 @@ def multihead_attention(query_antecedent,
     num_heads: an integer dividing total_key_depth and total_value_depth
     dropout_rate: a floating point number
     summaries: a boolean
-    image_shapes: optional quadruple of integer scalars for image summary.
-      If the query positions and memory positions represent the
-      pixels of a flattened image, then pass in their dimensions:
-        (query_rows, query_cols, memory_rows, memory_cols).
+    image_shapes: optional tuple of integer scalars.
+      see comments for attention_image_summary()
     name: an optional string
 
   Returns:
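Both functions simply forward image_shapes to attention_image_summary(). A hedged usage sketch follows; the argument names come from the docstrings in this file, but the full signature at this revision is an assumption, and the depth values are illustrative:

import tensorflow as tf
from tensor2tensor.models import common_attention

x = tf.random_normal([1, 32 * 32 * 3, 512])  # [batch, length, depth], assumed

y = common_attention.multihead_attention(
    query_antecedent=x,
    memory_antecedent=None,  # self-attention (assumed convention)
    bias=None,
    total_key_depth=512,
    total_value_depth=512,
    output_depth=512,
    num_heads=8,
    dropout_rate=0.1,
    summaries=True,
    image_shapes=(32, 32, 3, 32, 32, 3),  # pixels-x-channels form
    name="self_attention")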

tensor2tensor/models/modalities.py

Lines changed: 3 additions & 3 deletions
@@ -441,8 +441,8 @@ class IdentityModality(modality.Modality):
   def targets_dimensionality(self):
     return self._vocab_size
 
-  def inputs_bottom_simple(self, inputs):
-    return tf.to_float(inputs)
+  def bottom(self, x):
+    return tf.to_float(x)
 
-  def targets_top_simple(self, body_output, _):
+  def top(self, body_output, _):
     return body_output
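The modalities.py change is a pure rename of the modality hooks, inputs_bottom_simple to bottom and targets_top_simple to top, with the bodies unchanged. A minimal sketch of the class after this commit (base class and the other methods elided):

class IdentityModality(modality.Modality):

  def bottom(self, x):
    # Features entering the model body: identity up to a float cast.
    return tf.to_float(x)

  def top(self, body_output, _):
    # Body output leaves the model unchanged.
    return body_output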
