This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 280d2f7

Lukasz Kaiser authored and copybara-github committed
Trax: split layers into smaller files, move chunked Transformer to research, add tests.
PiperOrigin-RevId: 247111854
1 parent ed6343a commit 280d2f7

15 files changed: +770 -489 lines

tensor2tensor/trax/layers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -24,5 +24,9 @@
 from tensor2tensor.trax.layers.attention import *
 from tensor2tensor.trax.layers.base import *
 from tensor2tensor.trax.layers.combinators import *
+from tensor2tensor.trax.layers.convolution import *
 from tensor2tensor.trax.layers.core import *
+from tensor2tensor.trax.layers.initializers import *
+from tensor2tensor.trax.layers.normalization import *
+from tensor2tensor.trax.layers.pooling import *
 from tensor2tensor.trax.layers.rnn import *
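
Because __init__.py re-exports every submodule with wildcard imports, code that reaches layers through the package keeps working after the split. A minimal sketch of that usage (not part of the commit), assuming this revision is installed and that LayerNorm now lives in the new normalization module -- its removal from attention.py below plus the new import above suggest this:

import tensor2tensor.trax.layers as tl

# Both the moved and the brand-new symbols stay reachable at the package level.
norm = tl.LayerNorm()       # assumed to come from the new normalization module
conv = tl.Conv(30, (3, 3))  # defined in the new convolution module shown below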

tensor2tensor/trax/layers/attention.py

Lines changed: 0 additions & 127 deletions
@@ -59,25 +59,6 @@ def EncoderDecoderMask(x, **unused_kwargs):
   return padding_mask + np.zeros((1, 1, decoder_input.shape[1], 1))
 
 
-# Layer normalization.
-def _layer_norm_new_params(input_shape, rng, epsilon=1e-6):  # pylint: disable=invalid-name
-  """Helper: create layer norm parameters."""
-  del rng, epsilon
-  features = input_shape[-1]
-  scale = np.ones(features)
-  bias = np.zeros(features)
-  return (scale, bias)
-
-
-@base.layer(new_parameters=_layer_norm_new_params)
-def LayerNorm(x, params, epsilon=1e-6, **unused_kwargs):
-  (scale, bias) = params
-  mean = np.mean(x, axis=-1, keepdims=True)
-  variance = np.mean((x - mean)**2, axis=-1, keepdims=True)
-  norm_inputs = (x - mean) / np.sqrt(variance + epsilon)
-  return norm_inputs * scale + bias
-
-
 # Positional encoding.
 def _positional_encoding_new_params(input_shape, rng, max_len=2048):  # pylint: disable=invalid-name
   """Helper: create positional encoding parameters."""
@@ -271,114 +252,6 @@ def MultiHeadedAttention(
   )
 
 
-# Chunked attention.
-def _chunked_selector_output_shape(  # pylint: disable=invalid-name
-    input_shapes, selector=None, **unused_kwargs):
-  """Helper: calculate output shape for chunked key selector (see below)."""
-  # Read the main function below first, the shape logic just follows the ops.
-  selector = selector or (lambda x: [] if x < 1 else [x-1])
-  triples, _ = zip(*input_shapes)
-  (query_shapes, key_shapes, value_shapes) = zip(*triples)
-  result = []
-  for i in range(len(input_shapes)):
-    selected = selector(i)
-    cur_key_shape, cur_value_shape = key_shapes[i], value_shapes[i]
-    # Since keys and values are [batch, length, depth] we concatenate on axis=1.
-    new_key_len = sum([key_shapes[j][1] for j in selected]) + cur_key_shape[1]
-    new_key_shape = (cur_key_shape[0], new_key_len, cur_key_shape[2])
-    new_value_len = sum(
-        [value_shapes[j][1] for j in selected]) + cur_value_shape[1]
-    new_value_shape = (cur_value_shape[0], new_value_len, cur_value_shape[2])
-    # Masks are (1, query-len, key-len).
-    new_mask_shape = (1, query_shapes[i][1], new_key_len)
-    new_shape = ((query_shapes[i], new_key_shape, new_value_shape),
-                 new_mask_shape)
-    result.append(new_shape)
-  return tuple(result)
-
-
-@base.layer(output_shape=_chunked_selector_output_shape)
-def ChunkedAttentionSelector(x, params, selector=None, **kwargs):
-  """Select which chunks to attend to in chunked attention.
-
-  Args:
-    x: inputs, a list of elements of the form (q, k, v), mask for each chunk.
-    params: parameters (unused).
-    selector: a function from chunk_number -> list of chunk numbers that says
-      which other chunks should be appended to the given one (previous if None).
-    **kwargs: unused other arguments.
-
-  Returns:
-    a list of elements of the form (q, k', v'), mask' where k', v' and mask' are
-    concatenations of k, v and identity-extended masks from selected chunks.
-  """
-  del params, kwargs
-  selector = selector or (lambda x: [] if x < 1 else [x-1])
-  triples, masks = zip(*x)
-  (queries, keys, values) = zip(*triples)
-  result = []
-  for i in range(len(x)):
-    selected = selector(i)
-    # Since keys and values are [batch, length, depth] we concatenate on axis=1.
-    # We also always include the current key or value at the end.
-    new_key_list = [keys[j] for j in selected]
-    new_key = np.concatenate(new_key_list + [keys[i]], axis=1)
-    new_value = np.concatenate(
-        [values[j] for j in selected] + [values[i]], axis=1)
-    # Masks are (1, query-len, key-len) so we concatenate on axis=2.
-    new_mask_shapes = [(1, queries[i].shape[1], key.shape[1])
-                       for key in new_key_list]
-    cur_mask = masks[i]
-    # Masks are all-1 for the added chunks (no masking).
-    new_mask_list = [np.ones(s, dtype=cur_mask.dtype) for s in new_mask_shapes]
-    # We still use the current (often causal) mask for the final chunk.
-    new_mask = np.concatenate(new_mask_list + [cur_mask], axis=2)
-    result.append(((queries[i], new_key, new_value), new_mask))
-  return tuple(result)
-
-
-def ChunkedCausalMultiHeadedAttention(
-    feature_depth, num_heads=8, dropout=0.0, chunk_selector=None, mode='train'):
-  """Transformer-style causal multi-headed attention operating on chunks.
-
-  Accepts inputs that are a list of chunks and applies causal attention.
-
-  Args:
-    feature_depth: int: depth of embedding
-    num_heads: int: number of attention heads
-    dropout: float: dropout rate
-    chunk_selector: a function from chunk number to list of chunks to attend.
-    mode: str: 'train' or 'eval'
-
-  Returns:
-    Multi-headed self-attention layer.
-  """
-  prepare_attention_input = combinators.Serial(
-      combinators.Branch(
-          combinators.Branch(  # q = k = v = first input
-              combinators.Copy(), combinators.Copy(), combinators.Copy()),
-          CausalMask(axis=-2),  # pylint: disable=no-value-for-parameter
-      ),
-      combinators.Parallel(
-          combinators.Parallel(
-              core.Dense(feature_depth),
-              core.Dense(feature_depth),
-              core.Dense(feature_depth),
-          ),
-          combinators.Copy()
-      )
-  )
-  return combinators.Serial(
-      combinators.Map(prepare_attention_input),
-      ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
-      combinators.Map(PureMultiHeadedAttention(  # pylint: disable=no-value-for-parameter
-          feature_depth=feature_depth, num_heads=num_heads,
-          dropout=dropout, mode=mode), check_shapes=False),
-      combinators.Map(combinators.Select(0), check_shapes=False),  # drop masks
-      combinators.Map(core.Dense(feature_depth))
-  )
-
-
 @base.layer()
 def ShiftRight(x, **unused_kwargs):
   """Layer to shift the tensor to the right by padding on axis 1."""
tensor2tensor/trax/layers/convolution.py

Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trax convolution layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from jax import lax

import numpy as onp
from tensor2tensor.trax.layers import base
from tensor2tensor.trax.layers import initializers as init


def PadtypeToPads(in_shape, window_shape, window_strides, padding):
  """Convert padding string to list of pairs of pad values."""
  padding = padding.upper()
  if padding == 'SAME':
    out_shape = onp.ceil(
        onp.true_divide(in_shape, window_strides)).astype(int)
    pad_sizes = [max((out_size - 1) * stride + window_shape - in_size, 0)
                 for out_size, stride, window_shape, in_size
                 in zip(out_shape, window_strides, window_shape, in_shape)]
    return [(pad_size // 2, pad_size - pad_size // 2)
            for pad_size in pad_sizes]
  elif padding == 'VALID':
    return [(0, 0)] * len(in_shape)
  else:
    msg = 'Unknown padding type: {}.'
    raise TypeError(msg.format(padding))


class Conv(base.Layer):
  """Layer constructor function for a general convolution layer."""

  def __init__(self, filters, kernel_size, strides=None, padding='VALID',
               dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
               kernel_initializer=None,
               bias_initializer=init.RandomNormalInitializer(1e-6)):
    super(Conv, self).__init__()
    self._filters = filters
    self._kernel_size = kernel_size
    self._padding = padding
    self._dimension_numbers = dimension_numbers
    self._lhs_spec, self._rhs_spec, self._out_spec = dimension_numbers
    self._one = (1,) * len(kernel_size)
    self._strides = strides or self._one
    self._bias_initializer = bias_initializer
    rhs_spec = self._rhs_spec
    self._kernel_initializer = kernel_initializer
    if kernel_initializer is None:
      self._kernel_initializer = init.GlorotNormalInitializer(
          rhs_spec.index('O'), rhs_spec.index('I'))

  def call(self, x, params=(), **kwargs):
    del kwargs
    w, b = params
    return lax.conv_general_dilated(
        x, w, self._strides, self._padding, self._one, self._one,
        self._dimension_numbers) + b

  def _kernel_shape(self, input_shape):
    """Helper to calculate the kernel shape."""
    kernel_size_iter = iter(self._kernel_size)
    return [self._filters if c == 'O' else
            input_shape[self._lhs_spec.index('C')] if c == 'I' else
            next(kernel_size_iter) for c in self._rhs_spec]

  def _conv_shape_tuple(self, lhs_shape, rhs_shape, strides, pads):
    """Compute the shape of a conv given input shapes in canonical order."""
    if isinstance(pads, str):
      pads = PadtypeToPads(lhs_shape[2:], rhs_shape[2:], strides, pads)
    if len(pads) != len(lhs_shape) - 2:
      msg = 'Wrong number of explicit pads for conv: expected {}, got {}.'
      raise TypeError(msg.format(len(lhs_shape) - 2, len(pads)))
    lhs_padded = onp.add(lhs_shape[2:], onp.add(*zip(*pads)))
    out_space = onp.floor_divide(
        onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
    out_space = onp.maximum(0, out_space)
    out_shape = (lhs_shape[0], rhs_shape[0]) + tuple(out_space)
    return tuple(out_shape)

  def _conv_general_permutations(self, dimension_numbers):
    """Utility for convolution dimension permutations relative to Conv HLO."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    lhs_char, rhs_char, out_char = ('N', 'C'), ('O', 'I'), ('N', 'C')
    charpairs = (lhs_char, rhs_char, out_char)
    for i, (a, b) in enumerate(charpairs):
      if not (dimension_numbers[i].count(a) == 1 and
              dimension_numbers[i].count(b) == 1):
        msg = ('convolution dimension_numbers[{}] must contain the characters '
               '"{}" and "{}" exactly once, got {}.')
        raise TypeError(msg.format(i, a, b, dimension_numbers[i]))
      if len(dimension_numbers[i]) != len(set(dimension_numbers[i])):
        msg = ('convolution dimension_numbers[{}] cannot have duplicate '
               'characters, got {}.')
        raise TypeError(msg.format(i, dimension_numbers[i]))
    if not (set(lhs_spec) - set(lhs_char) == set(rhs_spec) - set(rhs_char) ==
            set(out_spec) - set(out_char)):
      msg = ('convolution dimension_numbers elements must each have the same '
             'set of spatial characters, got {}.')
      raise TypeError(msg.format(dimension_numbers))

    def GetPerm(spec, charpair):
      spatial = (i for i, c in enumerate(spec) if c not in charpair)
      if spec is not rhs_spec:
        spatial = sorted(spatial, key=lambda i: rhs_spec.index(spec[i]))
      return (spec.index(charpair[0]), spec.index(charpair[1])) + tuple(spatial)

    lhs_perm, rhs_perm, out_perm = map(GetPerm, dimension_numbers, charpairs)
    return lhs_perm, rhs_perm, out_perm

  def _conv_general_shape_tuple(self, lhs_shape, rhs_shape, window_strides,
                                padding, dimension_numbers):
    """Generalized computation of conv shape."""
    lhs_perm, rhs_perm, out_perm = self._conv_general_permutations(
        dimension_numbers)
    lhs_trans = onp.take(lhs_shape, lhs_perm)
    rhs_trans = onp.take(rhs_shape, rhs_perm)
    out_trans = self._conv_shape_tuple(
        lhs_trans, rhs_trans, window_strides, padding)
    return tuple(onp.take(out_trans, onp.argsort(out_perm)))

  def output_shape(self, input_shape):
    kernel_shape = self._kernel_shape(input_shape)
    return self._conv_general_shape_tuple(
        input_shape, kernel_shape,
        self._strides, self._padding, self._dimension_numbers)

  def new_parameters(self, input_shape, rng):
    kernel_shape = self._kernel_shape(input_shape)
    bias_shape = [self._filters if c == 'C' else 1 for c in self._out_spec]
    bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
    w = self._kernel_initializer(kernel_shape, rng)
    b = self._bias_initializer(bias_shape, rng)
    return (w, b)
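
As a quick sanity check on the shape logic above (an illustrative recomputation, not part of the commit): with 'VALID' padding the output spatial size is floor((in - kernel) / stride) + 1, which gives exactly the shape the new test below expects for a (29, 5, 5, 20) NHWC input and Conv(30, (3, 3)):

def valid_out_size(in_size, kernel, stride=1):
  # Mirrors _conv_shape_tuple with 'VALID' padding: no pads are added.
  return max(0, (in_size - kernel) // stride + 1)

batch, height, width, channels = 29, 5, 5, 20  # NHWC input from the test below
filters, kernel = 30, (3, 3)
out_hw = [valid_out_size(s, k) for s, k in zip((height, width), kernel)]
print((batch,) + tuple(out_hw) + (filters,))   # -> (29, 3, 3, 30)
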
tensor2tensor/trax/layers/convolution_test.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for convolution layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import absltest
from tensor2tensor.trax.layers import base
from tensor2tensor.trax.layers import convolution


class ConvolutionLayerTest(absltest.TestCase):

  def test_conv(self):
    input_shape = (29, 5, 5, 20)
    result_shape = base.check_shape_agreement(
        convolution.Conv(30, (3, 3)), input_shape)
    self.assertEqual(result_shape, (29, 3, 3, 30))


if __name__ == "__main__":
  absltest.main()
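
A hedged follow-on check (not in the commit): with padding='SAME' and stride 1 the spatial dimensions should be preserved, so the same helper ought to report (29, 5, 5, 30) for the input above, a natural extra case if more convolution tests get added.

from tensor2tensor.trax.layers import base
from tensor2tensor.trax.layers import convolution

# Hypothetical extra check, mirroring test_conv but with 'SAME' padding.
result_shape = base.check_shape_agreement(
    convolution.Conv(30, (3, 3), padding='SAME'), (29, 5, 5, 20))
assert tuple(result_shape) == (29, 5, 5, 30)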
