@@ -22,7 +22,9 @@ def __init__(self, mX, sTarget, nResidual, psTarget = [], pnResidual = [], alpha
22
22
self ._alpha = alpha
23
23
self ._method = method
24
24
self ._iterations = 200
25
- self ._lr = 2e-3
25
+ self ._lr = 3e-3 #2e-3
26
+ self ._hetaplus = 1.2
27
+ self ._hetaminus = 0.5
26
28
27
29
def __call__ (self , reverse = False ):
28
30
@@ -255,8 +257,8 @@ def phaseSensitive(self):
255
257
256
258
def optAlpha (self , initloss ):
257
259
"""
258
- A simple gradiend descent method, to find optimum power-spectral density exponents (alpha)
259
- for generalized wiener filtering.
260
+ A simple gradiend descent method using the RProp algorithm,
261
+ for finding optimum power-spectral density exponents (alpha) for generalized wiener filtering.
260
262
Args:
261
263
sTarget : (2D ndarray) Magnitude Spectrogram of the target component
262
264
nResidual: (2D ndarray) Magnitude Spectrogram of the residual component or a list
@@ -273,8 +275,8 @@ def optAlpha(self, initloss):
273
275
numElements = len (slist )
274
276
slist = np .asarray (slist )
275
277
276
- alpha = np .array ([1.2 ] * (numElements )) # Initialize an array of alpha values to be found.
277
- dloss = np .array ([0. ] * (numElements )) # Initialize an array of loss functions to be used.
278
+ alpha = np .array ([1.15 ] * (numElements )) # Initialize an array of alpha values to be found.
279
+ dloss = np .array ([0. ] * (numElements )) # Initialize an array of loss functions to be used.
278
280
lrs = np .array ([self ._lr ] * (numElements )) # Initialize an array of learning rates to be applied to each source.
279
281
280
282
# Begin of otpimization
@@ -291,7 +293,7 @@ def optAlpha(self, initloss):
291
293
292
294
alpha -= (lrs * dloss )
293
295
294
- # Make sure of un-wanted values
296
+ # Make sure the initial alpha are inside reasonable values
295
297
alpha = np .clip (alpha , a_min = 0.5 , a_max = 2. )
296
298
297
299
# Check IS Loss by computing Xhat
@@ -301,16 +303,25 @@ def optAlpha(self, initloss):
301
303
302
304
isloss .append (self ._IS (Xhat ))
303
305
if (iter > 2 ):
306
+ # Apply RProp
307
+ if (isloss [- 2 ] - isloss [- 1 ] > 0 ):
308
+ lrs *= self ._hetaplus
309
+
304
310
if (isloss [- 2 ] - isloss [- 1 ] < 0 ):
305
- print ('Local Minimum Found' )
306
- alpha += (lrs * dloss )
307
- break
311
+ lrs *= self ._hetaminus
312
+
313
+ if (iter > 4 ):
314
+ if (np .abs (isloss [- 2 ] - isloss [- 1 ]) < 1e-4 and np .abs (isloss [- 3 ] - isloss [- 2 ]) < 1e-4 ):
315
+ print ('Local Minimum Found' )
316
+ print ('Final Loss: ' + str (isloss [- 1 ]) + ' with characteristic exponent(s): ' + str (alpha ))
317
+ break
308
318
309
319
print ('Loss: ' + str (isloss [- 1 ]) + ' with characteristic exponent(s): ' + str (alpha ))
310
320
311
321
# Evaluate Xhat for the mask update
312
322
self ._mask = np .divide ((slist [0 , :, :] ** alpha [0 ] + self ._eps ), (self ._mX ** self ._alpha + self ._eps ))
313
- self ._closs = isloss
323
+ self ._closs = isloss [- 1 ]
324
+ self ._alpha = alpha
314
325
315
326
def MWF (self ):
316
327
""" Multi-channel Wiener filtering as appears in:
0 commit comments