
Commit a053196

Merge branch 'conv'
* conv: (39 commits)
  Refactoring. References #105
  SortaGrad resumption optimization and refactoring. References #105
  SortaGrad resumption optimization and refactoring. References #105
  Fixed inference.py
  Inference.py prototype
  Corrected status print format
  Corrected status print format
  Named eval_dir after the evaluation target. closes #103
  Refactoring
  Added SD and mean values
  Runconfig
  Updated mel feature calculation #100
  Layer dimensions added. closes #89
  Reduced samples to evaluate in sd estimation
  train.txt sorted, bucketed, params etc. #97
  Sorted and removed long seq from train.txt #97
  Removed wav_length estimator function
  Renamed utils folder to util
  estimate_bucket_sizes can now remove longer examples from .txt file
  Added README.md for used datasets #97
  ...
2 parents: b1cb25d + f0f3bbf

17 files changed: +709 −344 lines

python/evaluate.py

Lines changed: 7 additions & 6 deletions
@@ -18,7 +18,7 @@
 import python.model as model


-# Which dataset *.txt file to use for evaluation. 'train' or 'validate'.
+# Which dataset *.txt file to use for evaluation. 'test' or 'validate'.
 EVALUATION_TARGET = 'test'


@@ -83,9 +83,9 @@ def evaluate_once(loss_op, mean_ed_op, wer_op, summary_op, summary_writer):
         wer_sum += wer_batch
         step += 1

-        print('{:%Y-%m-%d %H:%M:%S}: Step {:5,d} results: loss={:7.3f}; '
+        print('{:%Y-%m-%d %H:%M:%S}: Step {:,d} of {:,d}; Results: loss={:7.3f}; '
               'mean_edit_distance={:5.3f}; WER={:5.3f}'
-              .format(datetime.now(), step, loss_batch, mean_ed_batch, wer_batch))
+              .format(datetime.now(), step, num_iter, loss_batch, mean_ed_batch, wer_batch))

     # Compute error rates.
     avg_loss = loss_sum / num_iter
@@ -126,11 +126,11 @@ def evaluate(eval_dir):
     with tf.Graph().as_default() as graph:
         # Get evaluation sequences and ground truth.
         with tf.device('/cpu:0'):
-            sequences, seq_length, labels, label_length, originals = model.inputs(
+            sequences, _, labels, label_length, originals = model.inputs(
                 target=EVALUATION_TARGET)

         # Build a graph that computes the logits predictions from the inference model.
-        logits = model.inference(sequences, seq_length, training=False)
+        logits, seq_length = model.inference(sequences, training=False)

         with tf.variable_scope('loss', reuse=tf.AUTO_REUSE):
             # Calculate error rates
@@ -156,7 +156,8 @@ def main(argv=None):
     """TensorFlow starting routine."""

     # Determine evaluation log directory.
-    eval_dir = FLAGS.eval_dir if len(FLAGS.eval_dir) > 0 else '{}_eval'.format(FLAGS.train_dir)
+    eval_dir = FLAGS.eval_dir if len(FLAGS.eval_dir) > 0 else '{}_{}'\
+        .format(FLAGS.train_dir, EVALUATION_TARGET)

     # Delete old evaluation data if requested.
     if tf.gfile.Exists(eval_dir) and FLAGS.delete:
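Note the signature change: `model.inference` now returns the sequence lengths itself instead of taking them from the input pipeline, presumably because the 'conv' branch introduces convolutional layers whose strides shrink the time dimension, so CTC needs the post-convolution lengths. A minimal sketch of that kind of recomputation (the helper `conv_output_length` is hypothetical, not the project's code):

    def conv_output_length(input_length, filter_size, stride, padding='SAME'):
        """Output time dimension of a 1D convolution, per the standard formula."""
        if padding == 'SAME':
            return (input_length + stride - 1) // stride
        # 'VALID' padding discards frames that do not fill a complete window.
        return (input_length - filter_size + stride) // stride

    # Example: a 2048-frame feature sequence through one stride-2 layer
    # leaves 1024 frames for the CTC decoder.
    print(conv_output_length(2048, filter_size=11, stride=2))  # 1024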

python/inference.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+"""Transcribe a given audio file."""
+
+import os
+
+import tensorflow as tf
+
+from python.params import FLAGS, TF_FLOAT
+from python.loader.load_sample import load_sample, NUM_FEATURES
+
+# WarpCTC crashes during evaluation, even if it is only imported and not actually being used.
+if FLAGS.use_warp_ctc:
+    FLAGS.use_warp_ctc = False
+    import python.model as model
+else:
+    import python.model as model
+
+
+# File to transcribe.
+WAV_FILE = '/home/marc/workspace/datasets/speech_data/timit/TIMIT/TRAIN/DR4/FALR0/SA1.WAV'
+
+
+def transcribe_once(logits_op, decoded_op, plaintext_op, sequences, sequences_ph):
+    """Restore the model from the latest checkpoint and run inference for the provided `sequences`.
+
+    Args:
+        logits_op (tf.Tensor):
+            Logits operator.
+        decoded_op (tf.Tensor):
+            Decoded operator.
+        plaintext_op (tf.Tensor):
+            Plaintext operator.
+        sequences (List[np.ndarray]):
+            Python list of 2D numpy arrays, each containing audio features.
+        sequences_ph (tf.Tensor):
+            Placeholder for the input sequences.
+
+    Returns:
+        Nothing.
+    """
+    # Session configuration.
+    session_config = tf.ConfigProto(
+        log_device_placement=False,
+        gpu_options=tf.GPUOptions(allow_growth=True)
+    )
+
+    with tf.Session(config=session_config) as sess:
+        checkpoint = tf.train.get_checkpoint_state(FLAGS.train_dir)
+        if checkpoint and checkpoint.model_checkpoint_path:
+            saver = tf.train.Saver()
+
+            # Restore from checkpoint.
+            saver.restore(sess, checkpoint.model_checkpoint_path)
+            # Extract global step from checkpoint.
+            global_step = checkpoint.model_checkpoint_path.split('/')[-1].split('-')[-1]
+            global_step = str(global_step)
+            print('Loaded global step: {}, from checkpoint: {}'
+                  .format(global_step, FLAGS.train_dir))
+        else:
+            print('No checkpoint file found.')
+            return
+
+        # Start the queue runners.
+        coord = tf.train.Coordinator()
+        threads = []
+        try:
+            for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
+                threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))
+
+            if not coord.should_stop():
+                logits, decoded, plaintext = sess.run([logits_op, decoded_op, plaintext_op],
+                                                      feed_dict={sequences_ph: sequences})
+
+                print('Transcriptions {}:\n{}'.format(plaintext.shape, plaintext))
+
+        except Exception as e:
+            print('EXCEPTION:', e, ', type:', type(e))
+            coord.request_stop(e)
+
+        coord.request_stop()
+        coord.join(threads, stop_grace_period_secs=120)
+
+
+def transcribe():
+    """Load an audio file and prepare the TensorFlow graph for inference.
+
+    Returns:
+        Nothing.
+    """
+    assert os.path.isfile(WAV_FILE)
+
+    with tf.Graph().as_default():
+        # Get evaluation sequences and ground truth.
+        with tf.device('/cpu:0'):
+            # Load audio file into tensor.
+            sequences, _ = load_sample(WAV_FILE)
+            sequences = [sequences] * FLAGS.batch_size
+            sequences_ph = tf.placeholder(dtype=TF_FLOAT,
+                                          shape=[FLAGS.batch_size, None, NUM_FEATURES])
+
+        # Build a graph that computes the logits predictions from the inference model.
+        logits_op, seq_length = model.inference(sequences_ph, training=False)
+
+        decoded_op, plaintext_op, _ = model.decode(logits_op, seq_length, originals=None)
+
+        transcribe_once(logits_op, decoded_op, plaintext_op, sequences, sequences_ph)
+
+
+# noinspection PyUnusedLocal
+def main(argv=None):
+    """TensorFlow starting routine."""
+    transcribe()
+
+
+if __name__ == '__main__':
+    main()
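One design note on `transcribe()`: the graph is built for a fixed `FLAGS.batch_size`, so the single utterance is tiled across the whole batch and every batch slot decodes to the same transcription. A toy illustration of the shape handling (the 80-feature width and batch size 8 are assumptions, not necessarily the project's `NUM_FEATURES` or `FLAGS.batch_size`):

    import numpy as np

    features = np.random.rand(520, 80)   # one utterance: (time, num_features)
    batch = np.stack([features] * 8)     # tile to (batch_size, time, num_features)
    assert batch.shape == (8, 520, 80)   # matches the placeholder [batch_size, None, NUM_FEATURES]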

python/loader/audio_sample_info.py

Lines changed: 29 additions & 4 deletions
@@ -3,8 +3,6 @@
 Note that the network does not use `librosa`_ anymore, because it has problems
 with concurrent sample loading. This module has not been updated yet.

-L8ER: Move away from librosa, use python_speech_features.
-
 .. _librosa:
     https://librosa.github.io/librosa/index.html
 """
@@ -16,6 +14,8 @@
 from librosa import display
 from matplotlib import pyplot as plt

+from python.loader import load_sample as ls
+

 DATASETS_PATH = '/home/marc/workspace/datasets/speech_data'

@@ -36,7 +36,6 @@ def display_sample_info(file_path, label=''):
         raise ValueError('{} does not exist.'.format(file_path))

     # By default, all audio is mixed to mono and resampled to 22050 Hz at load time.
-    # y, sr = rosa.load(file_path, sr=None, mono=True)
     y, sr = rosa.load(file_path, sr=None, mono=True)

     # At 16000 Hz, 512 samples ~= 32ms. At 16000 Hz, 200 samples = 12ms. 16 samples = 1ms @ 16kHz.
@@ -141,6 +140,31 @@ def display_sample_info(file_path, label=''):
     plt.colorbar(format='%+2.0f dB')
     plt.title('Mel spectrogram')

+    # Import project used features (python speech features).
+    normalize_features = 'global'
+    mfcc = ls.load_sample(file_path, feature_type='mfcc', normalize_features=normalize_features,
+                          normalize_signal=False)[0]
+    mfcc = np.swapaxes(mfcc, 0, 1)
+
+    mel = ls.load_sample(file_path, feature_type='mel', normalize_features=normalize_features,
+                         normalize_signal=False)[0]
+    mel = np.swapaxes(mel, 0, 1)
+
+    plt.figure(figsize=(12, 8))
+    plt.subplot(2, 1, 1)
+    display.specshow(mfcc, sr=16000, x_axis='time', y_axis='linear', hop_length=ls.WIN_STEP * 16000)
+    # plt.set_cmap('magma')
+    plt.xticks(rotation=295)
+    plt.colorbar(format='%+2.0f')
+    plt.title('MFCC')
+
+    plt.subplot(2, 1, 2)
+    display.specshow(mel, sr=16000, x_axis='time', y_axis='linear', hop_length=ls.WIN_STEP * 16000)
+    # plt.set_cmap('magma')
+    plt.xticks(rotation=295)
+    plt.colorbar(format='%+2.0f')
+    plt.title('Mel')
+
     plt.tight_layout()
     plt.show()

@@ -151,7 +175,8 @@ def display_sample_info(file_path, label=''):
     # Display specific sample info's.
     with open(_test_txt_path, 'r') as f:
         _lines = f.readlines()
-    _line = _lines[0]
+    _line = _lines[len(_lines) // 5]
+    # _line = _lines[1]
     _wav_path, txt = _line.split(' ', 1)
     _wav_path = os.path.join('/home/marc/workspace/datasets/speech_data', _wav_path)
     _txt = txt.strip()
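The added plots compare librosa's spectrograms with the project's own loader output; per the module docstring and the commit message ("Updated mel feature calculation #100"), the loader is based on `python_speech_features`. A hedged sketch of computing comparable mel and MFCC matrices directly with that library; the file name, window sizes, and filter counts below are assumptions, not the values in `python/loader/load_sample.py`:

    import scipy.io.wavfile as wav
    from python_speech_features import mfcc, logfbank

    sample_rate, signal = wav.read('sample.wav')  # hypothetical 16 kHz mono file
    assert sample_rate == 16000

    # 25 ms windows every 10 ms; these mirror common defaults.
    mel_features = logfbank(signal, samplerate=sample_rate,
                            winlen=0.025, winstep=0.010, nfilt=80)
    mfcc_features = mfcc(signal, samplerate=sample_rate,
                         winlen=0.025, winstep=0.010, numcep=13)

    print(mel_features.shape, mfcc_features.shape)  # (time, 80), (time, 13)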

python/loader/bucket_estimator.py

Lines changed: 35 additions & 5 deletions
@@ -9,7 +9,8 @@
 import numpy as np
 from matplotlib import pyplot as plt

-from python.loader.load_sample import wav_length
+from python.loader.load_sample import load_sample
+from python.util.storage import delete_file_if_exists


 # Path to train.txt file.
@@ -18,31 +19,60 @@
 DATASET_PATH = '/home/marc/workspace/datasets/speech_data/'


-def estimate_bucket_sizes(num_buckets=32):
+def estimate_bucket_sizes(num_buckets=32, max_length=1750):
     """Estimate optimal bucket sizes based on the samples in `train.txt` file.
     Results are printed out or plotted.
+    Optional, if `max_length` is greater than `0`, audio examples with feature vectors longer than
+    `max_length` are being removed from the .txt file.

     Args:
         num_buckets (int): Number of buckets.
             Note that TensorFlow bucketing adds a smallest and largest bucket to the list.
+        max_length (int): Maximum feature vector length of a preprocessed audio example.
+            Longer ones are being removed from the .txt file.
+            Set to `0` to disable removal.

     Returns:
         Nothing.
     """
     with open(TRAIN_TXT_PATH, 'r') as f:
         lines = f.readlines()

+    overlength_counter = 0
     lengths = []
+    tmp_lines = []

     # Progressbar
     for line in tqdm(lines, desc='Reading audio files', total=len(lines), file=sys.stdout,
                      unit='files', dynamic_ncols=True):
-        wav_path = line.split(' ', 1)[0]
+        wav_path, label = line.split(' ', 1)
         wav_path = os.path.join(DATASET_PATH, wav_path)
-        sample_len = wav_length(wav_path)
-        lengths.append(sample_len)
+        _, sample_len = load_sample(wav_path, feature_type='mel',
+                                    normalize_features=False, normalize_signal=False)
+
+        if max_length > 0:
+            if sample_len < max_length:
+                lengths.append(sample_len)
+                tmp_lines.append(line)
+            else:
+                overlength_counter += 1
+
+        else:
+            lengths.append(sample_len)
+            tmp_lines.append(line)
+
     print()  # Clear line from tqdm progressbar.

+    # Write reduced data back to .txt file, if selected.
+    if max_length > 0:
+        print('{} examples have a length greater than {} and have been removed from .txt file.'
+              .format(overlength_counter, max_length))
+
+        delete_file_if_exists(TRAIN_TXT_PATH)
+        with open(TRAIN_TXT_PATH, 'w') as f:
+            f.writelines(tmp_lines)
+
+    print('Evaluated {} examples.'.format(len(lengths)))
     lengths = np.array(lengths)
     lengths = np.sort(lengths)
     step = len(lengths) // num_buckets
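For context, boundaries like the ones this script estimates are typically consumed by TensorFlow's queue-based bucketing in a TF 1.x pipeline. A hedged sketch of that hand-off; the tensors and boundary values below are illustrative, not taken from the repository:

    import tensorflow as tf

    # Hypothetical single-example tensors, e.g. produced by a file reader.
    sequence = tf.placeholder(tf.float32, shape=[None, 80])  # (time, num_features)
    label = tf.placeholder(tf.int32, shape=[None])
    seq_len = tf.shape(sequence)[0]

    # Illustrative boundary list; as the docstring above notes, TensorFlow adds
    # an implicit smallest and largest bucket around it.
    bucket_boundaries = [312, 421, 508, 587, 664, 742, 827, 926, 1054, 1250]

    _, (sequences, labels) = tf.contrib.training.bucket_by_sequence_length(
        input_length=seq_len,
        tensors=[sequence, label],
        batch_size=16,
        bucket_boundaries=bucket_boundaries,
        dynamic_pad=True,  # pad each batch to its bucket's longest example
        allow_smaller_final_batch=False)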
