
Commit 62a0ee7

Merge pull request #177 from cshanbo/pos_emb
fix positional embedding
2 parents: 8cc9a5e + ac038c5

1 file changed: 3 additions, 0 deletions

tensor2tensor/models/common_attention.py

@@ -65,6 +65,9 @@ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
       tf.to_float(tf.range(num_timescales)) * -log_timescale_increment)
   scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
   signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
+  signal = tf.reshape(signal, [length, 2, num_timescales])
+  signal = tf.transpose(signal, perm=[0, 2, 1])
+  signal = tf.reshape(signal, [length, channels])
   signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]])
   signal = tf.reshape(signal, [1, length, channels])
   return x + signal
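
What the three added lines do: before this change, the tf.concat left signal with shape [length, 2 * num_timescales], holding all sine components in the first half of the channel axis and all cosine components in the second half. The reshape/transpose/reshape sequence interleaves them so each timescale contributes an adjacent (sin, cos) pair: sin(t0), cos(t0), sin(t1), cos(t1), and so on. A minimal NumPy sketch of the same sequence (NumPy stands in for TensorFlow here; the length and channels values are made up for illustration, and channels is assumed even, which the reshape to [length, channels] requires):

import numpy as np

# Mirror of add_timing_signal_1d's signal construction, before the padding step.
length, channels = 4, 6                  # illustrative values (channels even)
num_timescales = channels // 2
min_timescale, max_timescale = 1.0, 1.0e4

position = np.arange(length, dtype=np.float32)
log_timescale_increment = (
    np.log(max_timescale / min_timescale) / (num_timescales - 1))
inv_timescales = min_timescale * np.exp(
    np.arange(num_timescales) * -log_timescale_increment)
scaled_time = position[:, None] * inv_timescales[None, :]   # [length, num_timescales]

# Pre-fix layout: all sines, then all cosines, along the channel axis.
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)

# The three lines this commit adds, interleaving per-timescale (sin, cos) pairs.
signal = signal.reshape(length, 2, num_timescales)
signal = signal.transpose(0, 2, 1)
signal = signal.reshape(length, channels)

print(signal[1])   # row for position 1: sin(t0), cos(t0), sin(t1), cos(t1), ...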
