Skip to content

Commit f0a00bb

Browse files
committed
Added conv dropout rate #105
1 parent a053196 commit f0a00bb

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

python/params.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
"""Weather to use Baidu's `warp_ctc_loss` or TensorFlow's `ctc_loss`.""")
4545

4646
# Dropout.
47+
tf.flags.DEFINE_float('conv_dropout_rate', 0.0,
48+
"""Dropout rate for convolutional layers.""")
4749
tf.flags.DEFINE_float('rnn_dropout_rate', 0.0,
4850
"""Dropout rate for the RNN cell layers.""")
4951
tf.flags.DEFINE_float('dense_dropout_rate', 0.1,

python/util/tf_contrib.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def _apply_dense(self, grad, var):
3939
def conv_layers(sequences, filters=FLAGS.conv_filters,
4040
kernel_sizes=((11, 41), (11, 21), (11, 21)),
4141
strides=((2, 2), (1, 2), (1, 2)),
42-
kernel_initializer=tf.glorot_normal_initializer(), kernel_regularizer=None):
42+
kernel_initializer=tf.glorot_normal_initializer(), kernel_regularizer=None,
43+
training=True):
4344
"""Add 2D convolutional layers to the network's graph. New sequence length are being calculated.
4445
4546
Convolutional layer output shapes:
@@ -72,6 +73,8 @@ def conv_layers(sequences, filters=FLAGS.conv_filters,
7273
TensorFlow kernel initializer.
7374
kernel_regularizer (tf.Tensor):
7475
TensorFlow kernel regularizer.
76+
training (bool):
77+
`FLAGS.conv_dropout_rate` is being applied during training only.
7578
7679
Returns:
7780
tf.Tensor: `output`
@@ -102,7 +105,7 @@ def conv_layers(sequences, filters=FLAGS.conv_filters,
102105
kernel_regularizer=kernel_regularizer)
103106

104107
output = tf.minimum(output, FLAGS.relu_cutoff)
105-
# output = tf.layers.dropout(output, rate=FLAGS.dense_dropout_rate, training=training)
108+
output = tf.layers.dropout(output, rate=FLAGS.conv_dropout_rate, training=training)
106109

107110
# Reshape to: conv3 = [batch_size, time, 10 * NUM_FILTERS], where 10 is the number of
108111
# frequencies left over from convolutions.

0 commit comments

Comments
 (0)