Commit ea1c771

Residual Shuffle-Exchange network (#1805)
* Fix issue #1802: 'tensorflow._api.v1.compat.v1.compat' has no attribute 'v1'. tf is already imported as tf.compat.v1, so there is no need to reference .compat.v1 explicitly.
* T2T implementation of Residual Shuffle-Exchange networks.
  Publication: https://arxiv.org/abs/2004.04662
  Original code: https://github.com/LUMII-Syslab/RSE
1 parent 4e172ee commit ea1c771

File tree

2 files changed: 280 additions & 0 deletions

tensor2tensor/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@
 from tensor2tensor.models.research import neural_stack
 from tensor2tensor.models.research import rl
 from tensor2tensor.models.research import shuffle_network
+from tensor2tensor.models.research import residual_shuffle_exchange
 from tensor2tensor.models.research import similarity_transformer
 from tensor2tensor.models.research import super_lm
 from tensor2tensor.models.research import transformer_moe
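
Importing the new module here is what triggers the @registry.register_model decorator in the file below, making the model retrievable by its snake_case name. A minimal check of the registration (a sketch, assuming a standard tensor2tensor install):

import tensor2tensor.models  # noqa: F401 -- importing runs the registrations
from tensor2tensor.utils import registry

model_cls = registry.model("residual_shuffle_exchange")
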
tensor2tensor/models/research/residual_shuffle_exchange.py

Lines changed: 279 additions & 0 deletions
@@ -0,0 +1,279 @@
# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Residual Shuffle-Exchange Network.

Implementation of the paper
"Residual Shuffle-Exchange Networks for Fast Processing of Long Sequences"
by A. Draguns, E. Ozolins, A. Sostaks, M. Apinis and K. Freivalds.

Paper: https://arxiv.org/abs/2004.04662
Original code: https://github.com/LUMII-Syslab/RSE
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.models.research.shuffle_network import ShuffleNetwork
from tensor2tensor.models.research.shuffle_network import shuffle_layer
from tensor2tensor.models.research.shuffle_network import reverse_shuffle_layer
from tensor2tensor.layers.common_layers import gelu
from tensor2tensor.utils import registry

import numpy as np
import tensorflow.compat.v1 as tf


class LayerNormalization(tf.keras.layers.Layer):
  """Layer Normalization (LayerNorm) without output bias and gain."""

  def __init__(self, axis=1, epsilon=1e-10, **kwargs):
    """Initialize Layer Normalization layer.

    Args:
      axis: Tuple or number of axes over which to compute mean and variance
      epsilon: Small epsilon to avoid division by zero
      **kwargs: more arguments (unused)
    """
    self.axis = axis
    self.epsilon = epsilon
    self.bias = None
    super(LayerNormalization, self).__init__(**kwargs)

  def build(self, input_shape):
    """Initialize bias weights for layer normalization.

    Args:
      input_shape: shape of input tensor
    """
    num_units = input_shape.as_list()[-1]
    self.bias = self.add_weight("bias", [1, 1, num_units],
                                initializer=tf.zeros_initializer)
    super(LayerNormalization, self).build(input_shape)

  def call(self, inputs, **kwargs):
    """Apply Layer Normalization without output bias and gain.

    Args:
      inputs: tensor to be normalized. Axis should be smaller than input
        tensor dimensions.
      **kwargs: more arguments (unused)

    Returns:
      tf.Tensor: normalized tensor of the same shape as inputs
    """
    inputs -= tf.reduce_mean(inputs, axis=self.axis, keepdims=True)
    inputs += self.bias
    variance = tf.reduce_mean(tf.square(inputs), self.axis, keepdims=True)
    return inputs * tf.math.rsqrt(variance + self.epsilon)
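
# Note: unlike standard LayerNorm, this variant has no learned output gain,
# and the learned bias is added before rescaling; for input x it computes
#   (x - mean(x) + bias) * rsqrt(mean((x - mean(x) + bias)^2) + epsilon).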


def inv_sigmoid(y):
  """Inverse sigmoid (logit) function.

  Args:
    y: float in the open interval (0, 1)

  Returns:
    x such that sigmoid(x) == y
  """
  return np.log(y / (1 - y))
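
# Example: inv_sigmoid(0.9) ~= 2.197 and tf.sigmoid(2.197) ~= 0.9, so a gate
# weight initialized to inv_sigmoid(r) comes out of tf.sigmoid() as exactly r.
# RSU below uses this to start its residual gate at residual_weight = 0.9.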


class RSU(tf.keras.layers.Layer):
  """Residual Switch Unit of Residual Shuffle-Exchange network."""

  def __init__(self, prefix, dropout, mode, **kwargs):
    """Initialize Switch Layer.

    Args:
      prefix: Name prefix for switch layer
      dropout: Dropout rate
      mode: Training mode
      **kwargs: more arguments (unused)
    """
    super(RSU, self).__init__(**kwargs)
    self.prefix = prefix
    self.dropout = dropout
    self.mode = mode
    self.first_linear = None
    self.second_linear = None
    self.layer_norm = None
    self.residual_scale = None

    residual_weight = 0.9
    self.candidate_weight = np.sqrt(1 - residual_weight**2) * 0.25
    self.init_value = inv_sigmoid(residual_weight)

  def build(self, input_shape):
    """Initialize layer weights and sublayers.

    Args:
      input_shape: shape of inputs
    """
    in_units = input_shape[-1]
    middle_units = in_units * 4
    out_units = in_units * 2
    init = tf.variance_scaling_initializer(scale=1.0, mode="fan_avg",
                                           distribution="uniform")

    self.first_linear = tf.keras.layers.Dense(middle_units,
                                              use_bias=False,
                                              kernel_initializer=init,
                                              name=self.prefix + "/cand1")

    self.second_linear = tf.keras.layers.Dense(out_units,
                                               kernel_initializer=init,
                                               name=self.prefix + "/cand2")
    self.layer_norm = LayerNormalization()

    init = tf.constant_initializer(self.init_value)
    self.residual_scale = self.add_weight(self.prefix + "/residual",
                                          [out_units], initializer=init)
    super(RSU, self).build(input_shape)

  def call(self, inputs, **kwargs):
    """Apply Residual Switch Layer to inputs.

    Args:
      inputs: Input tensor
      **kwargs: more arguments (unused)

    Returns:
      tf.Tensor: New candidate value
    """
    input_shape = tf.shape(inputs)
    batch_size = input_shape[0]
    length = input_shape[1]
    num_units = inputs.shape.as_list()[2]

    n_bits = tf.log(tf.cast(length - 1, tf.float32)) / tf.log(2.0)
    n_bits = tf.floor(n_bits) + 1

    reshape_shape = [batch_size, length // 2, num_units * 2]
    reshaped_inputs = tf.reshape(inputs, reshape_shape)

    first_linear = self.first_linear(reshaped_inputs)
    first_linear = self.layer_norm(first_linear)
    first_linear = gelu(first_linear)
    candidate = self.second_linear(first_linear)

    residual = tf.sigmoid(self.residual_scale) * reshaped_inputs
    candidate = residual + candidate * self.candidate_weight
    candidate = tf.reshape(candidate, input_shape)

    if self.dropout > 0:
      candidate = tf.nn.dropout(candidate, rate=self.dropout / n_bits)
    if self.dropout != 0.0 and self.mode == tf.estimator.ModeKeys.TRAIN:
      noise = tf.random_normal(tf.shape(candidate), mean=1.0, stddev=0.001)
      candidate = candidate * noise

    return candidate
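
# Shape walk-through of RSU.call for inputs of shape [batch, length, units]:
#   reshape -> [batch, length / 2, 2 * units]  (adjacent positions paired)
#   cand1   -> [batch, length / 2, 4 * units]  (then LayerNorm and GELU)
#   cand2   -> [batch, length / 2, 2 * units]
#   output  -> sigmoid(residual_scale) * pairs + candidate_weight * candidate,
#              reshaped back to [batch, length, units].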


def residual_shuffle_network(inputs, hparams):
  """Residual Shuffle-Exchange network with weight sharing.

  Args:
    inputs: inputs to the Shuffle-Exchange network. Sequence length should be
      a power of 2.
    hparams: Model configuration

  Returns:
    tf.Tensor: Outputs of the last Shuffle-Exchange layer
  """
  input_shape = tf.shape(inputs)
  n_bits = tf.log(tf.cast(input_shape[1] - 1, tf.float32)) / tf.log(2.0)
  n_bits = tf.cast(n_bits, tf.int32) + 1

  block_out = inputs

  for k in range(hparams.num_hidden_layers):
    with tf.variable_scope("benes_block_" + str(k), reuse=tf.AUTO_REUSE):
      forward_output = forward_part(block_out, hparams, n_bits)
      block_out = reverse_part(forward_output, hparams, n_bits)

  return RSU("last_layer", hparams.dropout, hparams.mode)(block_out)
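
# Each Beneš block above runs n_bits forward (RSU + shuffle) steps and then
# n_bits mirrored reverse steps, so information from any position can reach
# any other; for sequence length n this costs O(n log n) operations in total.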


def reverse_part(inputs, hparams, n_bits):
  """Reverse part of Beneš block.

  Repeatedly applies interleaved Residual Switch layers and Reverse Shuffle
  layers. One set of weights is used for all Switch layers.

  Args:
    inputs: inputs for reverse part. Should be outputs from forward part.
    hparams: params of the network.
    n_bits: count of repeated layer applications.

  Returns:
    tf.Tensor: output of reverse part.
  """
  reverse_rsu = RSU("reverse_switch", hparams.dropout, hparams.mode)

  def reverse_step(state, _):
    with tf.variable_scope("reverse"):
      new_state = reverse_rsu(state)
      return reverse_shuffle_layer(new_state)

  reverse_outputs = tf.scan(
      reverse_step,
      tf.range(n_bits, n_bits * 2),
      initializer=inputs,
      parallel_iterations=1,
      swap_memory=True)

  return reverse_outputs[-1, :, :, :]


def forward_part(block_out, hparams, n_bits):
  """Forward part of Beneš block.

  Repeatedly applies interleaved Residual Switch layers and Shuffle
  layers. One set of weights is used for all Switch layers.

  Args:
    block_out: inputs for forward part. Should be inputs from previous layers
      or Beneš block.
    hparams: params of the network.
    n_bits: count of repeated layer applications.

  Returns:
    tf.Tensor: output of forward part.
  """
  forward_rsu = RSU("switch", hparams.dropout, hparams.mode)

  def forward_step(state, _):
    with tf.variable_scope("forward"):
      new_state = forward_rsu(state)
      return shuffle_layer(new_state)

  forward_outputs = tf.scan(
      forward_step,
      tf.range(0, n_bits),
      initializer=block_out,
      parallel_iterations=1,
      swap_memory=True)

  return forward_outputs[-1, :, :, :]
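
# Note: each scan above reuses a single RSU instance, so all n_bits steps of a
# part share one set of switch weights; parallel_iterations=1 together with
# swap_memory=True trades some speed for a smaller activation-memory footprint.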


@registry.register_model
class ResidualShuffleExchange(ShuffleNetwork):
  """T2T implementation of Residual Shuffle-Exchange network."""

  def body(self, features):
    """Body of Residual Shuffle-Exchange network.

    Args:
      features: dictionary of inputs and targets

    Returns:
      tf.Tensor: network output (logits)
    """
    inputs = tf.squeeze(features["inputs"], axis=2)
    logits = residual_shuffle_network(inputs, self._hparams)
    return tf.expand_dims(logits, axis=2)
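
For reference, a minimal sketch of exercising the network function directly. The values below are illustrative assumptions rather than a registered hyperparameter set (the model is normally driven through the T2T trainer), and tensor2tensor.utils.hparam is assumed to provide the HParams container:

import tensorflow.compat.v1 as tf
from tensor2tensor.models.research import residual_shuffle_exchange as rse
from tensor2tensor.utils import hparam

# Illustrative settings only -- not the hparams set shipped with the model.
hparams = hparam.HParams(num_hidden_layers=2, dropout=0.1,
                         mode=tf.estimator.ModeKeys.TRAIN)

# Sequence length must be a power of 2 (here 8) for the shuffle layers.
inputs = tf.zeros([4, 8, 16])  # [batch, length, hidden units]
outputs = rse.residual_shuffle_network(inputs, hparams)
# outputs is a tensor with the same [batch, length, hidden units] shape.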
