Skip to content

Commit 64782c3

Browse files
committed
Added RPRop algorithm for accelerating the solver.
1 parent bf527d2 commit 64782c3

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

MaskingMethods.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ def __init__(self, mX, sTarget, nResidual, psTarget = [], pnResidual = [], alpha
2222
self._alpha = alpha
2323
self._method = method
2424
self._iterations = 200
25-
self._lr = 2e-3
25+
self._lr = 3e-3#2e-3
26+
self._hetaplus = 1.2
27+
self._hetaminus = 0.5
2628

2729
def __call__(self, reverse = False):
2830

@@ -255,8 +257,8 @@ def phaseSensitive(self):
255257

256258
def optAlpha(self, initloss):
257259
"""
258-
A simple gradiend descent method, to find optimum power-spectral density exponents (alpha)
259-
for generalized wiener filtering.
260+
A simple gradiend descent method using the RProp algorithm,
261+
for finding optimum power-spectral density exponents (alpha) for generalized wiener filtering.
260262
Args:
261263
sTarget : (2D ndarray) Magnitude Spectrogram of the target component
262264
nResidual: (2D ndarray) Magnitude Spectrogram of the residual component or a list
@@ -273,8 +275,8 @@ def optAlpha(self, initloss):
273275
numElements = len(slist)
274276
slist = np.asarray(slist)
275277

276-
alpha = np.array([1.2] * (numElements)) # Initialize an array of alpha values to be found.
277-
dloss = np.array([0.] * (numElements)) # Initialize an array of loss functions to be used.
278+
alpha = np.array([1.15] * (numElements)) # Initialize an array of alpha values to be found.
279+
dloss = np.array([0.] * (numElements)) # Initialize an array of loss functions to be used.
278280
lrs = np.array([self._lr] * (numElements)) # Initialize an array of learning rates to be applied to each source.
279281

280282
# Begin of otpimization
@@ -291,7 +293,7 @@ def optAlpha(self, initloss):
291293

292294
alpha -= (lrs*dloss)
293295

294-
# Make sure of un-wanted values
296+
# Make sure the initial alpha are inside reasonable values
295297
alpha = np.clip(alpha, a_min = 0.5, a_max = 2.)
296298

297299
# Check IS Loss by computing Xhat
@@ -301,16 +303,25 @@ def optAlpha(self, initloss):
301303

302304
isloss.append(self._IS(Xhat))
303305
if (iter > 2):
306+
# Apply RProp
307+
if (isloss[-2] - isloss[-1] > 0):
308+
lrs *= self._hetaplus
309+
304310
if (isloss[-2] - isloss[-1] < 0):
305-
print('Local Minimum Found')
306-
alpha += (lrs * dloss)
307-
break
311+
lrs *= self._hetaminus
312+
313+
if (iter > 4):
314+
if (np.abs(isloss[-2] - isloss[-1]) < 1e-4 and np.abs(isloss[-3] - isloss[-2]) < 1e-4):
315+
print('Local Minimum Found')
316+
print('Final Loss: ' + str(isloss[-1]) + ' with characteristic exponent(s): ' + str(alpha))
317+
break
308318

309319
print('Loss: ' + str(isloss[-1]) + ' with characteristic exponent(s): ' + str(alpha))
310320

311321
# Evaluate Xhat for the mask update
312322
self._mask = np.divide((slist[0, :, :] ** alpha[0] + self._eps), (self._mX ** self._alpha + self._eps))
313-
self._closs = isloss
323+
self._closs = isloss[-1]
324+
self._alpha = alpha
314325

315326
def MWF(self):
316327
""" Multi-channel Wiener filtering as appears in:

0 commit comments

Comments
 (0)