# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| 16 | +"""Residual Shuffle-Exchange Network. |
| 17 | +
|
| 18 | +Implementation of |
| 19 | +"Residual Shuffle-Exchange Networks for Fast Processing of Long Sequences" |
| 20 | +paper by A.Draguns, E.Ozolins, A.Sostaks, M.Apinis, K.Freivalds. |
| 21 | +
|
| 22 | +Paper: https://arxiv.org/abs/2004.04662 |
| 23 | +Original code: https://github.com/LUMII-Syslab/RSE |
| 24 | +""" |

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.models.research.shuffle_network import ShuffleNetwork
from tensor2tensor.models.research.shuffle_network import shuffle_layer
from tensor2tensor.models.research.shuffle_network import reverse_shuffle_layer
from tensor2tensor.layers.common_layers import gelu
from tensor2tensor.utils import registry

import numpy as np
import tensorflow.compat.v1 as tf


class LayerNormalization(tf.keras.layers.Layer):
  """Layer Normalization (LayerNorm) without output gain and bias."""

  def __init__(self, axis=1, epsilon=1e-10, **kwargs):
    """Initialize Layer Normalization layer.

    Args:
      axis: Axis or tuple of axes along which to compute mean and variance.
      epsilon: Small epsilon to avoid division by zero.
    """
    self.axis = axis
    self.epsilon = epsilon
    self.bias = None
    super(LayerNormalization, self).__init__(**kwargs)

  def build(self, input_shape):
    """Initialize bias weights for layer normalization.

    Args:
      input_shape: shape of the input tensor.
    """
    num_units = input_shape.as_list()[-1]
    self.bias = self.add_weight("bias", [1, 1, num_units],
                                initializer=tf.zeros_initializer)
    super(LayerNormalization, self).build(input_shape)

  def call(self, inputs, **kwargs):
    """Apply Layer Normalization without output gain and bias.

    Args:
      inputs: tensor to be normalized. The normalization axes should be
        smaller than the input tensor rank.
      **kwargs: more arguments (unused)

    Returns:
      tf.Tensor: the normalized tensor.
    """
    # Center the inputs, add the learned bias, then rescale to unit variance.
    inputs -= tf.reduce_mean(inputs, axis=self.axis, keepdims=True)
    inputs += self.bias
    variance = tf.reduce_mean(tf.square(inputs), self.axis, keepdims=True)
    return inputs * tf.math.rsqrt(variance + self.epsilon)


def inv_sigmoid(y):
  """Inverse sigmoid (logit) function.

  Args:
    y: float in range 0 to 1.

  Returns:
    x such that sigmoid(x) == y.
  """
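  # For example, inv_sigmoid(0.9) ≈ 2.197, so a gate initialised with this
  # value starts at sigmoid(2.197) ≈ 0.9 (used for the RSU residual gates).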
  return np.log(y / (1 - y))


class RSU(tf.keras.layers.Layer):
  """Residual Switch Unit of the Residual Shuffle-Exchange network."""

  def __init__(self, prefix, dropout, mode, **kwargs):
    """Initialize the Residual Switch Unit.

    Args:
      prefix: Name prefix for the switch layer.
      dropout: Dropout rate.
      mode: Training mode (a tf.estimator.ModeKeys value).
      **kwargs: more arguments (unused)
    """
    super(RSU, self).__init__(**kwargs)
    self.prefix = prefix
    self.dropout = dropout
    self.mode = mode
    self.first_linear = None
    self.second_linear = None
    self.layer_norm = None
    self.residual_scale = None

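    # Residual gates are initialised so that sigmoid(residual_scale) starts at
    # 0.9, and the candidate branch is scaled down accordingly.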
    residual_weight = 0.9
    self.candidate_weight = np.sqrt(1 - residual_weight ** 2) * 0.25
    self.init_value = inv_sigmoid(residual_weight)

  def build(self, input_shape):
    """Initialize layer weights and sublayers.

    Args:
      input_shape: shape of the inputs.
    """
    in_units = input_shape[-1]
    # The unit operates on concatenated pairs of elements, so the output width
    # is twice the input width; the hidden layer is twice as wide again.
    middle_units = in_units * 4
    out_units = in_units * 2
    init = tf.variance_scaling_initializer(scale=1.0, mode="fan_avg",
                                           distribution="uniform")

    self.first_linear = tf.keras.layers.Dense(middle_units,
                                              use_bias=False,
                                              kernel_initializer=init,
                                              name=self.prefix + "/cand1")

    self.second_linear = tf.keras.layers.Dense(out_units,
                                               kernel_initializer=init,
                                               name=self.prefix + "/cand2")
    self.layer_norm = LayerNormalization()

    init = tf.constant_initializer(self.init_value)
    self.residual_scale = self.add_weight(self.prefix + "/residual",
                                          [out_units], initializer=init)
    super(RSU, self).build(input_shape)

  def call(self, inputs, **kwargs):
    """Apply the Residual Switch Unit to the inputs.

    Args:
      inputs: Input tensor of shape [batch, length, units].
      **kwargs: more arguments (unused)

    Returns:
      tf.Tensor: New candidate value of the same shape as the inputs.
    """
    input_shape = tf.shape(inputs)
    batch_size = input_shape[0]
    length = input_shape[1]
    num_units = inputs.shape.as_list()[2]

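    # n_bits = ceil(log2(length)): the number of switch levels in one half of
    # a Beneš block. The dropout rate below is divided by n_bits because the
    # RSU is applied at every level, so the per-level rate is kept small.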
    n_bits = tf.log(tf.cast(length - 1, tf.float32)) / tf.log(2.0)
    n_bits = tf.floor(n_bits) + 1

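    # Merge every two adjacent positions into a single vector so that the
    # switch unit operates on pairs of elements.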
    reshape_shape = [batch_size, length // 2, num_units * 2]
    reshaped_inputs = tf.reshape(inputs, reshape_shape)

    first_linear = self.first_linear(reshaped_inputs)
    first_linear = self.layer_norm(first_linear)
    first_linear = gelu(first_linear)
    candidate = self.second_linear(first_linear)

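    # Gated residual connection: scale the input pairs, add the down-weighted
    # candidate, then restore the original [batch, length, units] shape.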
    residual = tf.sigmoid(self.residual_scale) * reshaped_inputs
    candidate = residual + candidate * self.candidate_weight
    candidate = tf.reshape(candidate, input_shape)

    # Dropout and the multiplicative noise below are applied only in training.
    if self.dropout > 0 and self.mode == tf.estimator.ModeKeys.TRAIN:
      candidate = tf.nn.dropout(candidate, rate=self.dropout / n_bits)
    if self.dropout != 0.0 and self.mode == tf.estimator.ModeKeys.TRAIN:
      noise = tf.random_normal(tf.shape(candidate), mean=1.0, stddev=0.001)
      candidate = candidate * noise

    return candidate


def residual_shuffle_network(inputs, hparams):
  """Residual Shuffle-Exchange network with weight sharing.

  Args:
    inputs: inputs to the Shuffle-Exchange network. The sequence length should
      be a power of 2.
    hparams: Model configuration.

  Returns:
    tf.Tensor: Outputs of the last Shuffle-Exchange layer.
  """
  input_shape = tf.shape(inputs)
  n_bits = tf.log(tf.cast(input_shape[1] - 1, tf.float32)) / tf.log(2.0)
  n_bits = tf.cast(n_bits, tf.int32) + 1

  block_out = inputs

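  # Each Beneš block is a forward pass (Residual Switch Unit + shuffle,
  # repeated n_bits times) followed by a reverse pass (RSU + reverse shuffle);
  # weights are shared across repetitions within each pass.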
  for k in range(hparams.num_hidden_layers):
    with tf.variable_scope("benes_block_" + str(k), reuse=tf.AUTO_REUSE):
      forward_output = forward_part(block_out, hparams, n_bits)
      block_out = reverse_part(forward_output, hparams, n_bits)

  return RSU("last_layer", hparams.dropout, hparams.mode)(block_out)


def reverse_part(inputs, hparams, n_bits):
  """Reverse part of the Beneš block.

  Repeatedly applies interleaved Residual Switch layers and Reverse Shuffle
  layers. A single set of weights is shared by all the Switch layers.

  Args:
    inputs: inputs to the reverse part. Should be the outputs of the forward
      part.
    hparams: params of the network.
    n_bits: count of repeated layer applications.

  Returns:
    tf.Tensor: output of the reverse part.
  """
  reverse_rsu = RSU("reverse_switch", hparams.dropout, hparams.mode)

  def reverse_step(state, _):
    with tf.variable_scope("reverse"):
      new_state = reverse_rsu(state)
      return reverse_shuffle_layer(new_state)

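  # tf.scan applies the shared RSU followed by a reverse shuffle n_bits times;
  # parallel_iterations=1 and swap_memory=True trade speed for lower memory
  # use on long sequences.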
  reverse_outputs = tf.scan(
      reverse_step,
      tf.range(n_bits, n_bits * 2),
      initializer=inputs,
      parallel_iterations=1,
      swap_memory=True)

  return reverse_outputs[-1, :, :, :]


def forward_part(block_out, hparams, n_bits):
  """Forward part of the Beneš block.

  Repeatedly applies interleaved Residual Switch layers and Shuffle layers.
  A single set of weights is shared by all the Switch layers.

  Args:
    block_out: inputs to the forward part. Should be the inputs from the
      previous layers or Beneš block.
    hparams: params of the network.
    n_bits: count of repeated layer applications.

  Returns:
    tf.Tensor: output of the forward part.
  """
  forward_rsu = RSU("switch", hparams.dropout, hparams.mode)

  def forward_step(state, _):
    with tf.variable_scope("forward"):
      new_state = forward_rsu(state)
      return shuffle_layer(new_state)

  forward_outputs = tf.scan(
      forward_step,
      tf.range(0, n_bits),
      initializer=block_out,
      parallel_iterations=1,
      swap_memory=True)

  return forward_outputs[-1, :, :, :]


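# The model is registered under the snake_case name of the class, i.e.
# "residual_shuffle_exchange" (assuming the standard T2T registry naming), so
# it can be selected with `--model=residual_shuffle_exchange` in t2t-trainer.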
@registry.register_model
class ResidualShuffleExchange(ShuffleNetwork):
  """T2T implementation of the Residual Shuffle-Exchange network."""

  def body(self, features):
    """Body of the Residual Shuffle-Exchange network.

    Args:
      features: dictionary of inputs and targets.

    Returns:
      tf.Tensor: network output logits.
    """

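    # T2T provides inputs as [batch, length, 1, channels]; drop the dummy axis
    # for the network and restore it so the output matches what T2T expects.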
    inputs = tf.squeeze(features["inputs"], axis=2)
    logits = residual_shuffle_network(inputs, self._hparams)
    return tf.expand_dims(logits, axis=2)