# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trax convolution layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from jax import lax

import numpy as onp
from tensor2tensor.trax.layers import base
from tensor2tensor.trax.layers import initializers as init


def PadtypeToPads(in_shape, window_shape, window_strides, padding):
  """Convert padding string to list of pairs of pad values."""
  padding = padding.upper()
  if padding == 'SAME':
    out_shape = onp.ceil(
        onp.true_divide(in_shape, window_strides)).astype(int)
    # Use a distinct loop variable name so the comprehension does not shadow
    # the window_shape argument it iterates over.
    pad_sizes = [max((out_size - 1) * stride + window_dim - in_size, 0)
                 for out_size, stride, window_dim, in_size
                 in zip(out_shape, window_strides, window_shape, in_shape)]
    return [(pad_size // 2, pad_size - pad_size // 2)
            for pad_size in pad_sizes]
  elif padding == 'VALID':
    return [(0, 0)] * len(in_shape)
  else:
    msg = 'Unknown padding type: {}.'
    raise TypeError(msg.format(padding))

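# A quick worked example (hypothetical shapes, for illustration only): for a
# 1-d input of length 5, a window of size 3, and stride 2, 'SAME' padding
# gives ceil(5 / 2) = 3 output positions, so the total pad needed is
# max((3 - 1) * 2 + 3 - 5, 0) = 2, split evenly between the two ends:
#
#   PadtypeToPads((5,), (3,), (2,), 'same')  # -> [(1, 1)]
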

class Conv(base.Layer):
  """Layer constructor function for a general convolution layer."""

  def __init__(self, filters, kernel_size, strides=None, padding='VALID',
               dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
               kernel_initializer=None,
               bias_initializer=init.RandomNormalInitializer(1e-6)):
    super(Conv, self).__init__()
    self._filters = filters
    self._kernel_size = kernel_size
    self._padding = padding
    self._dimension_numbers = dimension_numbers
    self._lhs_spec, self._rhs_spec, self._out_spec = dimension_numbers
    self._one = (1,) * len(kernel_size)
    self._strides = strides or self._one
    self._bias_initializer = bias_initializer
    rhs_spec = self._rhs_spec
    self._kernel_initializer = kernel_initializer
    if kernel_initializer is None:
      self._kernel_initializer = init.GlorotNormalInitializer(
          rhs_spec.index('O'), rhs_spec.index('I'))

  def call(self, x, params=(), **kwargs):
    del kwargs
    w, b = params
    return lax.conv_general_dilated(
        x, w, self._strides, self._padding, self._one, self._one,
        self._dimension_numbers) + b

  def _kernel_shape(self, input_shape):
    """Helper to calculate the kernel shape."""
    kernel_size_iter = iter(self._kernel_size)
    return [self._filters if c == 'O' else
            input_shape[self._lhs_spec.index('C')] if c == 'I' else
            next(kernel_size_iter) for c in self._rhs_spec]
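
  # For illustration (hypothetical shapes): with the default dimension_numbers
  # ('NHWC', 'HWIO', 'NHWC'), filters=32, kernel_size=(3, 3), and
  # input_shape=(1, 28, 28, 3), the 'I' position picks up the input channel
  # count and 'O' the filter count, so _kernel_shape returns [3, 3, 3, 32].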

  def _conv_shape_tuple(self, lhs_shape, rhs_shape, strides, pads):
    """Compute the shape of a conv given input shapes in canonical order."""
    if isinstance(pads, str):
      pads = PadtypeToPads(lhs_shape[2:], rhs_shape[2:], strides, pads)
    if len(pads) != len(lhs_shape) - 2:
      msg = 'Wrong number of explicit pads for conv: expected {}, got {}.'
      raise TypeError(msg.format(len(lhs_shape) - 2, len(pads)))
    lhs_padded = onp.add(lhs_shape[2:], onp.add(*zip(*pads)))
    out_space = onp.floor_divide(
        onp.subtract(lhs_padded, rhs_shape[2:]), strides) + 1
    out_space = onp.maximum(0, out_space)
    out_shape = (lhs_shape[0], rhs_shape[0]) + tuple(out_space)
    return tuple(out_shape)
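
  # For illustration (hypothetical shapes, in the canonical NCHW/OIHW order
  # this helper expects): with lhs_shape=(1, 3, 28, 28),
  # rhs_shape=(32, 3, 3, 3), strides=(1, 1), and 'VALID' padding, each spatial
  # dim is floor((28 - 3) / 1) + 1 = 26, so the result is (1, 32, 26, 26).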

  def _conv_general_permutations(self, dimension_numbers):
    """Utility for convolution dimension permutations relative to Conv HLO."""
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    lhs_char, rhs_char, out_char = ('N', 'C'), ('O', 'I'), ('N', 'C')
    charpairs = (lhs_char, rhs_char, out_char)
    for i, (a, b) in enumerate(charpairs):
      if not (dimension_numbers[i].count(a) == 1 and
              dimension_numbers[i].count(b) == 1):
        msg = ('convolution dimension_numbers[{}] must contain the characters '
               '"{}" and "{}" exactly once, got {}.')
        raise TypeError(msg.format(i, a, b, dimension_numbers[i]))
      if len(dimension_numbers[i]) != len(set(dimension_numbers[i])):
        msg = ('convolution dimension_numbers[{}] cannot have duplicate '
               'characters, got {}.')
        raise TypeError(msg.format(i, dimension_numbers[i]))
    if not (set(lhs_spec) - set(lhs_char) == set(rhs_spec) - set(rhs_char) ==
            set(out_spec) - set(out_char)):
      msg = ('convolution dimension_numbers elements must each have the same '
             'set of spatial characters, got {}.')
      raise TypeError(msg.format(dimension_numbers))

    def GetPerm(spec, charpair):
      spatial = (i for i, c in enumerate(spec) if c not in charpair)
      if spec is not rhs_spec:
        spatial = sorted(spatial, key=lambda i: rhs_spec.index(spec[i]))
      return (spec.index(charpair[0]), spec.index(charpair[1])) + tuple(spatial)

    lhs_perm, rhs_perm, out_perm = map(GetPerm, dimension_numbers, charpairs)
    return lhs_perm, rhs_perm, out_perm
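
  # For illustration: with the default dimension_numbers
  # ('NHWC', 'HWIO', 'NHWC'), this returns lhs_perm=(0, 3, 1, 2),
  # rhs_perm=(3, 2, 0, 1), and out_perm=(0, 3, 1, 2), i.e. the permutations
  # taking each spec to the canonical (N, C, H, W) / (O, I, H, W) layout.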

  def _conv_general_shape_tuple(self, lhs_shape, rhs_shape, window_strides,
                                padding, dimension_numbers):
    """Generalized computation of conv shape."""
    lhs_perm, rhs_perm, out_perm = self._conv_general_permutations(
        dimension_numbers)
    lhs_trans = onp.take(lhs_shape, lhs_perm)
    rhs_trans = onp.take(rhs_shape, rhs_perm)
    out_trans = self._conv_shape_tuple(
        lhs_trans, rhs_trans, window_strides, padding)
    return tuple(onp.take(out_trans, onp.argsort(out_perm)))

  def output_shape(self, input_shape):
    kernel_shape = self._kernel_shape(input_shape)
    return self._conv_general_shape_tuple(
        input_shape, kernel_shape,
        self._strides, self._padding, self._dimension_numbers)
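
  # For illustration (hypothetical shapes): a Conv with filters=32,
  # kernel_size=(3, 3), 'VALID' padding, and the default NHWC layout maps
  # input_shape=(1, 28, 28, 3) to output shape (1, 26, 26, 32).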

  def new_parameters(self, input_shape, rng):
    kernel_shape = self._kernel_shape(input_shape)
    bias_shape = [self._filters if c == 'C' else 1 for c in self._out_spec]
    bias_shape = tuple(itertools.dropwhile(lambda x: x == 1, bias_shape))
    w = self._kernel_initializer(kernel_shape, rng)
    b = self._bias_initializer(bias_shape, rng)
    return (w, b)
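

# A minimal usage sketch (hypothetical shapes; assumes rng is a jax PRNG key,
# as used elsewhere in trax):
#
#   from jax import random
#   conv = Conv(filters=32, kernel_size=(3, 3))
#   input_shape = (1, 28, 28, 3)
#   params = conv.new_parameters(input_shape, random.PRNGKey(0))
#   y = conv.call(x, params=params)  # x: float array of shape input_shape
#
# With 'VALID' padding, y has shape (1, 26, 26, 32); the bias shape computed
# in new_parameters is (32,), which broadcasts over the leading dimensions.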