@@ -271,10 +271,14 @@ def attention_image_summary(attn, image_shapes=None):
271
271
272
272
Args:
273
273
attn: a Tensor with shape [batch, num_heads, query_length, memory_length]
274
- image_shapes: optional quadruple of integer scalars.
274
+ image_shapes: optional tuple of integer scalars.
275
275
If the query positions and memory positions represent the
276
- pixels of a flattened image , then pass in their dimensions:
276
+ pixels of flattened images , then pass in their dimensions:
277
277
(query_rows, query_cols, memory_rows, memory_cols).
278
+ If the query positions and memory positions represent the
279
+ pixels x channels of flattened images, then pass in their dimensions:
280
+ (query_rows, query_cols, query_channels,
281
+ memory_rows, memory_cols, memory_channels).
278
282
"""
279
283
num_heads = attn .get_shape ().as_list ()[1 ]
280
284
# [batch, query_length, memory_length, num_heads]
@@ -286,10 +290,20 @@ def attention_image_summary(attn, image_shapes=None):
286
290
image = split_last_dimension (image , 3 )
287
291
image = tf .reduce_max (image , 4 )
288
292
if image_shapes is not None :
289
- q_rows , q_cols , m_rows , m_cols = list (image_shapes )
290
- image = tf .reshape (image , [- 1 , q_rows , q_cols , m_rows , m_cols , 3 ])
291
- image = tf .transpose (image , [0 , 1 , 3 , 2 , 4 , 5 ])
292
- image = tf .reshape (image , [- 1 , q_rows * m_rows , q_cols * m_cols , 3 ])
293
+ if len (image_shapes ) == 4 :
294
+ q_rows , q_cols , m_rows , m_cols = list (image_shapes )
295
+ image = tf .reshape (image , [- 1 , q_rows , q_cols , m_rows , m_cols , 3 ])
296
+ image = tf .transpose (image , [0 , 1 , 3 , 2 , 4 , 5 ])
297
+ image = tf .reshape (image , [- 1 , q_rows * m_rows , q_cols * m_cols , 3 ])
298
+ else :
299
+ assert len (image_shapes ) == 6
300
+ q_rows , q_cols , q_channels , m_rows , m_cols , m_channels = list (
301
+ image_shapes )
302
+ image = tf .reshape (image , [- 1 , q_rows , q_cols , q_channels ,
303
+ m_rows , m_cols , m_channels , 3 ])
304
+ image = tf .transpose (image , [0 , 1 , 4 , 3 , 2 , 5 , 6 , 7 ])
305
+ image = tf .reshape (image , [- 1 , q_rows * m_rows * q_channels ,
306
+ q_cols * m_cols * m_channels , 3 ])
293
307
tf .summary .image ("attention" , image , max_outputs = 1 )
294
308
295
309
@@ -310,10 +324,8 @@ def dot_product_attention(q,
310
324
bias: bias Tensor (see attention_bias())
311
325
dropout_rate: a floating point number
312
326
summaries: a boolean
313
- image_shapes: optional quadruple of integer scalars for image summary.
314
- If the query positions and memory positions represent the
315
- pixels of a flattened image, then pass in their dimensions:
316
- (query_rows, query_cols, memory_rows, memory_cols).
327
+ image_shapes: optional tuple of integer scalars.
328
+ see comments for attention_image_summary()
317
329
name: an optional string
318
330
319
331
Returns:
@@ -356,10 +368,8 @@ def multihead_attention(query_antecedent,
356
368
num_heads: an integer dividing total_key_depth and total_value_depth
357
369
dropout_rate: a floating point number
358
370
summaries: a boolean
359
- image_shapes: optional quadruple of integer scalars for image summary.
360
- If the query positions and memory positions represent the
361
- pixels of a flattened image, then pass in their dimensions:
362
- (query_rows, query_cols, memory_rows, memory_cols).
371
+ image_shapes: optional tuple of integer scalars.
372
+ see comments for attention_image_summary()
363
373
name: an optional string
364
374
365
375
Returns:
0 commit comments