@@ -8,8 +8,6 @@
 import malaya_speech.train as train
 from malaya_speech.train.model.conformer.model import Model as ConformerModel
 from malaya_speech.train.model import hubert
-import tensorflow.keras as keras
-import tensorflow.keras.backend as K
 import numpy as np
 import string
 import json
@@ -27,71 +25,89 @@
 test_set = glob('/home/husein/youtube/voxceleb-wav/*.wav')

 sr = 16000
-maxlen = 18
-minlen = 3
-weight_decay = 1e-5
+maxlen = 15
+minlen = 2
+kmean = hubert.kmeans.ApplyKmeans_TF('kmean.km')


 def generate(files):
     while True:
         random.shuffle(files)
         for f in files:
             f = f.decode() if isinstance(f, bytes) else f
-            x, _ = malaya_speech.load(f)
+            wav_data, _ = malaya_speech.load(f)
             label = os.path.split(f)[1].replace('wav-', '').split('-')[1]
             y = int(ids[label])

-            len_x = len(x)
+            len_x = len(wav_data) / sr

-            if (len_x / sr) < minlen:
+            if len_x < minlen:
                 continue

-            if (len_x / sr) > maxlen:
-                x = augmentation.random_sampling(x, sr, random.randint(1000 * minlen, 1000 * maxlen))
+            if len_x > maxlen:
+                wav_data = augmentation.random_sampling(wav_data, sr, random.randint(1000 * minlen, 1000 * maxlen))

             yield {
-                'waveforms': x,
-                'waveforms_length': [len(x)],
+                'waveforms': wav_data,
+                'waveforms_length': [len(wav_data)],
                 'Y': [y],
             }


-def get_dataset(files, batch_size=4, shuffle_size=32, thread_count=24):
+def preprocess_inputs(example):
+    v = featurizer.vectorize(example['waveforms'])
+    deltas = malaya_speech.utils.tf_featurization.deltas(v)
+    ddeltas = malaya_speech.utils.tf_featurization.deltas(deltas)
+    concated = tf.concat([v, deltas, ddeltas], axis=1)
+    s = tf.compat.v1.numpy_function(kmean, [concated], tf.int64)
+    s = tf.cast(s, tf.int32)
+    kmean_tf = tf.reshape(s, (-1,)) + 3
+    example['targets'] = kmean_tf
+    return example
+
+
+def get_dataset(
+    file,
+    batch_size=4,
+    shuffle_size=20,
+    thread_count=24,
+    maxlen_feature=1800,
+):
     def get():
         dataset = tf.data.Dataset.from_generator(
             generate,
-            {
-                'waveforms': tf.float32,
-                'waveforms_length': tf.int32,
-                'Y': tf.int32,
-            },
+            {'waveforms': tf.float32,
+             'waveforms_length': tf.int32,
+             'Y': tf.int32,
+             },
             output_shapes={
                 'waveforms': tf.TensorShape([None]),
                 'waveforms_length': tf.TensorShape([None]),
                 'Y': tf.TensorShape([None]),
             },
-            args=(files,),
-        )
-        dataset = dataset.filter(
-            lambda x: tf.less(tf.shape(x['waveforms'])[0] / sr, maxlen)
+            args=(file,),
         )
-        dataset = dataset.filter(
-            lambda x: tf.greater(tf.shape(x['waveforms'])[0] / sr, minlen)
+        dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
+        dataset = dataset.map(
+            preprocess_inputs, num_parallel_calls=thread_count
         )
         dataset = dataset.padded_batch(
-            shuffle_size,
+            batch_size,
             padded_shapes={
                 'waveforms': tf.TensorShape([None]),
                 'waveforms_length': tf.TensorShape([None]),
+                'targets': tf.TensorShape([None]),
                 'Y': tf.TensorShape([None]),
             },
             padding_values={
                 'waveforms': tf.constant(0, dtype=tf.float32),
                 'waveforms_length': tf.constant(0, dtype=tf.int32),
+                'targets': tf.constant(0, dtype=tf.int32),
                 'Y': tf.constant(0, dtype=tf.int32),
             },
         )
         return dataset
+
     return get


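The new preprocess_inputs step replaces the old waveform-length filters: features plus their first- and second-order deltas are concatenated, each frame is assigned to its nearest k-means centroid by ApplyKmeans_TF, and the IDs are shifted by 3 so they land after the pad/eos/unk slots in the model vocabulary. A minimal NumPy sketch of that assignment (the class name, distance formula, and toy dimensions are illustrative, not the internals of ApplyKmeans_TF; the 100 clusters and the +3 offset come from the diff):

    import numpy as np

    class ApplyKmeansSketch:
        """Nearest-centroid assignment, the core of a k-means quantizer."""

        def __init__(self, centroids):
            self.C = centroids                      # (n_clusters, dim)
            self.Cnorm = (self.C ** 2).sum(axis=1)  # precomputed ||c||^2

        def __call__(self, x):
            # squared euclidean distance: ||x||^2 - 2 x C^T + ||C||^2
            dist = (x ** 2).sum(axis=1, keepdims=True) - 2 * x @ self.C.T + self.Cnorm
            return dist.argmin(axis=1).astype(np.int64)  # one cluster id per frame

    # toy data: 100 centroids over a 240-dim feature (base + deltas + double-deltas)
    kmean_sketch = ApplyKmeansSketch(np.random.randn(100, 240).astype(np.float32))
    frames = np.random.randn(50, 240).astype(np.float32)
    targets = kmean_sketch(frames) + 3  # +3 skips the pad/eos/unk vocabulary slots
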
@@ -107,12 +123,6 @@ def __call__(self, x, input_mask, training=True):
 total_steps = 3000000


-def amsoftmax_loss(y_true, y_pred, scale=30, margin=0.35):
-    y_pred = y_true * (y_pred - margin) + (1 - y_true) * y_pred
-    y_pred *= scale
-    return K.categorical_crossentropy(y_true, y_pred, from_logits=True)
-
-
 def model_fn(features, labels, mode, params):
     config_conformer = malaya_speech.config.conformer_base_encoder_config
     config_conformer['subsampling']['type'] = 'none'
@@ -130,40 +140,47 @@ def model_fn(features, labels, mode, params):
     model = hubert.Model(cfg, encoder, ['pad', 'eos', 'unk'] + [str(i) for i in range(100)])
     X = features['waveforms']
     X_len = features['waveforms_length'][:, 0]
+    Y = features['targets']
+    r = model(X, padding_mask=X_len, target_list=Y)
+
+    target_m = tf.zeros((tf.shape(r['logit_m_list'])[0],), dtype=tf.int32)
+    target_u = tf.zeros((tf.shape(r['logit_u_list'])[0],), dtype=tf.int32)
+
+    sample_size = tf.cast(tf.shape(target_m)[0], tf.float32)
+    entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_m, logits=r['logit_m_list'])
+    entropy_m = tf.reduce_sum(entropy) / sample_size
+
+    sample_size = tf.cast(tf.shape(target_u)[0], tf.float32)
+    entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_u, logits=r['logit_u_list'])
+    entropy_u = tf.reduce_sum(entropy) / sample_size
+
+    seq = r['x']
     Y = features['Y']
-    Y_onehot = tf.one_hot(Y, depth=num_class)
-
-    r = model(X, padding_mask=X_len, features_only=True, mask=False)
-    first_token_tensor = tf.squeeze(r['x'][:, 0:1, :], axis=1)
-    pooled_output = keras.layers.Dense(cfg.final_dim * 2, activation='tanh',
-                                       kernel_initializer='orthogonal',
-                                       use_bias=True, trainable=True,
-                                       kernel_regularizer=keras.regularizers.l2(weight_decay),
-                                       bias_regularizer=keras.regularizers.l2(weight_decay))(first_token_tensor)
-    logits = keras.layers.Dense(num_class,
-                                kernel_initializer='orthogonal',
-                                use_bias=False, trainable=True,
-                                kernel_constraint=keras.constraints.unit_norm(),
-                                kernel_regularizer=keras.regularizers.l2(weight_decay),
-                                bias_regularizer=keras.regularizers.l2(weight_decay),
-                                name='prediction')(pooled_output)
-    loss = tf.reduce_mean(amsoftmax_loss(Y_onehot, logits))
-    accuracy = tf.metrics.accuracy(
-        labels=Y, predictions=tf.argmax(logits, axis=1)
+    first_token_tensor = tf.squeeze(seq[:, 0:1, :], axis=1)
+    pooled_output = tf.keras.layers.Dense(embedding_dim, activation='tanh',
+                                          use_bias=True, trainable=True)(first_token_tensor)
+    logits = tf.keras.layers.Dense(num_class, trainable=True)(pooled_output)
+    entropy_speakers = tf.reduce_mean(
+        tf.nn.sparse_softmax_cross_entropy_with_logits(
+            logits=logits, labels=Y
+        )
     )

-    tf.identity(accuracy[1], name='train_accuracy')
+    loss = entropy_m * 0.95 + entropy_u * 0.05 + entropy_speakers

-    tf.identity(loss, 'train_loss')
+    tf.identity(entropy_m, 'entropy_m')
+    tf.summary.scalar('entropy_m', entropy_m)

-    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
-    init_checkpoint = 'hubert-conformer-base-output-3mixed/model.ckpt-2000000'
+    tf.identity(entropy_u, 'entropy_u')
+    tf.summary.scalar('entropy_u', entropy_u)

-    assignment_map, initialized_variable_names = train.get_assignment_map_from_checkpoint(
-        variables, init_checkpoint
+    tf.identity(loss, 'train_loss')
+
+    accuracy = tf.metrics.accuracy(
+        labels=Y, predictions=tf.argmax(logits, axis=1)
     )

-    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
+    tf.identity(accuracy[1], name='train_accuracy')

     if mode == tf.estimator.ModeKeys.TRAIN:
         train_op = train.optimizer.adamw.create_optimizer(
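The rewritten model_fn optimizes a joint objective: the HuBERT masked-prediction term entropy_m (weight 0.95), the unmasked term entropy_u (weight 0.05), and a speaker cross-entropy over the pooled first token. The all-zero target_m/target_u reflect that each row of logit_m_list/logit_u_list places the correct cluster candidate at index 0. A minimal NumPy sketch of how the three terms combine (the 0.95/0.05 weights and per-term normalization follow the diff; all shapes and values are toy):

    import numpy as np

    def sparse_softmax_xent(logits, labels):
        # numerically stable cross-entropy: -log softmax(logits)[label]
        z = logits - logits.max(axis=-1, keepdims=True)
        logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
        return -logp[np.arange(logits.shape[0]), labels]

    logit_m = np.random.randn(8, 100)        # masked positions x candidates, positive at index 0
    logit_u = np.random.randn(24, 100)       # unmasked positions x candidates
    speaker_logits = np.random.randn(4, 10)  # batch x num speakers
    speaker_y = np.array([1, 3, 0, 7])

    entropy_m = sparse_softmax_xent(logit_m, np.zeros(8, dtype=int)).sum() / 8
    entropy_u = sparse_softmax_xent(logit_u, np.zeros(24, dtype=int)).sum() / 24
    entropy_speakers = sparse_softmax_xent(speaker_logits, speaker_y).mean()

    loss = entropy_m * 0.95 + entropy_u * 0.05 + entropy_speakers
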
@@ -195,7 +212,7 @@ def model_fn(features, labels, mode, params):

     train_hooks = [
         tf.train.LoggingTensorHook(
-            ['train_accuracy', 'train_loss'], every_n_iter=1
+            ['entropy_m', 'entropy_u', 'entropy_speakers', 'train_accuracy', 'train_loss'], every_n_iter=1
         )
     ]

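tf.train.LoggingTensorHook resolves string entries as graph tensor names, which is why each logged scalar above is first registered through tf.identity. entropy_speakers appears in the hook list but no matching tf.identity shows up in these hunks; if it is not created elsewhere in the file, a line like the following sketch (not part of this commit) would be needed for the hook to find it:

    tf.identity(entropy_speakers, 'entropy_speakers')
    tf.summary.scalar('entropy_speakers', entropy_speakers)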