Skip to content

Commit 7548357

Browse files
committed
[src] Implement ssspy.linalg.solve for backward compatibility
1 parent 3918b2b commit 7548357

File tree

5 files changed

+46
-19
lines changed

5 files changed

+46
-19
lines changed

ssspy/bss/_update_spatial_model.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
import numpy as np
55

6-
from ..linalg import eigh2, inv2
6+
from ..linalg._solve import solve
7+
from ..linalg.eigh import eigh2
8+
from ..linalg.inv import inv2
79
from ..linalg.lqpqm import lqpqm2
810
from ..special.flooring import identity, max_flooring
911
from ..special.psd import to_psd
@@ -64,7 +66,7 @@ def update_by_ip1(
6466
e_n = E[:, src_idx, :] # (n_bins, n_n_channels)
6567

6668
WU = W @ U_n
67-
w_n = np.linalg.solve(WU, e_n) # (n_bins, n_channels)
69+
w_n = solve(WU, e_n) # (n_bins, n_channels)
6870
wUw = w_n[:, np.newaxis, :].conj() @ U_n @ w_n[:, :, np.newaxis]
6971
wUw = np.real(wUw[..., 0])
7072
wUw = np.maximum(wUw, 0)
@@ -358,8 +360,8 @@ def update_by_ip2_one_pair(
358360
WU_m = W @ U_m
359361
WU_n = W @ U_n
360362

361-
P_m = np.linalg.solve(WU_m, E_mn)
362-
P_n = np.linalg.solve(WU_n, E_mn)
363+
P_m = solve(WU_m, E_mn)
364+
P_n = solve(WU_n, E_mn)
363365

364366
PUP_m = P_m.transpose(0, 2, 1).conj() @ U_m @ P_m
365367
PUP_n = P_n.transpose(0, 2, 1).conj() @ U_n @ P_n
@@ -457,7 +459,7 @@ def update_by_ipa(
457459
C_n = d_n @ E_n
458460
d_n = d_n[:, :, source_idx]
459461

460-
Cd_n = np.linalg.solve(C_n, d_n)
462+
Cd_n = solve(C_n, d_n)
461463
dCd_n = np.sum(d_n.conj() * Cd_n, axis=-1)
462464
dCd_n = np.real(dCd_n)
463465
eUe_n = U_tilde_n_inverse[:, source_idx, source_idx]
@@ -490,7 +492,7 @@ def update_by_ipa(
490492
Eq_n = q_n.conj() @ E_n.transpose(1, 0)
491493
q_tilde_n = e_n.transpose(1, 0) - Eq_n
492494

493-
Uq_n = np.linalg.solve(U_tilde_n, q_tilde_n)
495+
Uq_n = solve(U_tilde_n, q_tilde_n)
494496
qUq_n = np.sum(q_tilde_n.conj() * Uq_n, axis=-1, keepdims=True)
495497

496498
qUq_n = np.real(qUq_n)
@@ -577,8 +579,8 @@ def _is_zero(x: np.ndarray) -> np.ndarray:
577579
gamma_in = np.sum(pad_mask_i[:, na] * RXY_in[..., 0], axis=1)
578580

579581
WU_in = W[:, neighbor_idx, :, :] @ U_in
580-
eta_in = np.linalg.solve(WU_in, e_n)
581-
eta_hat_in = np.linalg.solve(U_in, gamma_in)
582+
eta_in = solve(WU_in, e_n)
583+
eta_hat_in = solve(U_in, gamma_in)
582584
eta_U_in = eta_in[:, na, :].conj() @ U_in
583585

584586
xi_in = eta_U_in @ eta_in[:, :, na]

ssspy/bss/admmbss.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from ..linalg import prox
7+
from ..linalg._solve import solve
78
from .proxbss import ProxBSSBase
89

910
EPS = 1e-10
@@ -232,7 +233,7 @@ def update_once(self) -> None:
232233
VY_tilde = np.sum(V_tilde - Y_tilde, axis=0)
233234
XVY_tilde = X.transpose(1, 0, 2).conj() @ VY_tilde.transpose(1, 2, 0)
234235

235-
W = np.linalg.solve(n_penalties * XX + E, VY + XVY_tilde.transpose(0, 2, 1))
236+
W = solve(n_penalties * XX + E, VY + XVY_tilde.transpose(0, 2, 1))
236237
XW = self.separate(X, demix_filter=W)
237238

238239
U = alpha * W + (1 - alpha) * V
@@ -426,7 +427,7 @@ def update_once(self) -> None:
426427
VY_tilde = V_tilde - Y_tilde
427428
XVY_tilde = X.transpose(1, 0, 2).conj() @ VY_tilde.transpose(1, 2, 0)
428429

429-
W = np.linalg.solve(XX + E, VY + XVY_tilde.transpose(0, 2, 1))
430+
W = solve(XX + E, VY + XVY_tilde.transpose(0, 2, 1))
430431
XW = self.separate(X, demix_filter=W)
431432

432433
U = alpha * W + (1 - alpha) * V

ssspy/bss/mnmf.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55

6+
from ..linalg._solve import solve
67
from ..linalg.mean import gmeanmh
78
from ..special.flooring import identity, max_flooring
89
from ..special.psd import to_psd
@@ -753,7 +754,7 @@ def separate(self, input: np.ndarray) -> np.ndarray:
753754
R = np.sum(R_n, axis=0)
754755
R = to_psd(R, flooring_fn=self.flooring_fn)
755756
R = np.tile(R, reps=(n_sources, 1, 1, 1, 1))
756-
W_Hermite = np.linalg.solve(R, R_n)
757+
W_Hermite = solve(R, R_n)
757758
W = W_Hermite.transpose(0, 1, 2, 4, 3).conj()
758759
W_ref = W[:, :, :, reference_id, :]
759760
W_ref = W_ref.transpose(0, 3, 1, 2)
@@ -777,7 +778,7 @@ def compute_loss(self) -> float:
777778
R = self.reconstruct_mnmf(T, V, H)
778779

779780
R = to_psd(R, flooring_fn=self.flooring_fn)
780-
XXR_inv = np.linalg.solve(R, XX) # Hermitian transpose of XX @ np.linalg.inv(R)
781+
XXR_inv = solve(R, XX) # Hermitian transpose of XX @ np.linalg.inv(R)
781782
trace = np.trace(XXR_inv, axis1=-2, axis2=-1)
782783
trace = np.real(trace)
783784
logdet = self.compute_logdet(R)
@@ -857,10 +858,10 @@ def update_basis(
857858
def _compute_traces(
858859
target: np.ndarray, reconstructed: np.ndarray, spatial: np.ndarray
859860
) -> np.ndarray:
860-
RXX = np.linalg.solve(reconstructed, target)
861+
RXX = solve(reconstructed, target)
861862
R = np.tile(reconstructed, reps=(n_sources, 1, 1, 1, 1))
862863
H = np.tile(spatial[:, :, na, :, :], reps=(1, 1, n_frames, 1, 1))
863-
RH = np.linalg.solve(R, H)
864+
RH = solve(R, H)
864865

865866
trace_RXXRH = np.trace(RXX @ RH, axis1=-2, axis2=-1)
866867
trace_RXXRH = np.real(trace_RXXRH)
@@ -924,10 +925,10 @@ def update_activation(
924925
def _compute_traces(
925926
target: np.ndarray, reconstructed: np.ndarray, spatial: np.ndarray
926927
) -> np.ndarray:
927-
RXX = np.linalg.solve(reconstructed, target)
928+
RXX = solve(reconstructed, target)
928929
R = np.tile(reconstructed, reps=(n_sources, 1, 1, 1, 1))
929930
H = np.tile(spatial[:, :, na, :, :], reps=(1, 1, n_frames, 1, 1))
930-
RH = np.linalg.solve(R, H)
931+
RH = solve(R, H)
931932

932933
trace_RXXRH = np.trace(RXX @ RH, axis1=-2, axis2=-1)
933934
trace_RXXRH = np.real(trace_RXXRH)
@@ -1039,10 +1040,10 @@ def update_latent(
10391040
def _compute_traces(
10401041
target: np.ndarray, reconstructed: np.ndarray, spatial: np.ndarray
10411042
) -> np.ndarray:
1042-
RXX = np.linalg.solve(reconstructed, target)
1043+
RXX = solve(reconstructed, target)
10431044
R = np.tile(reconstructed, reps=(n_sources, 1, 1, 1, 1))
10441045
H = np.tile(spatial[:, :, na, :, :], reps=(1, 1, n_frames, 1, 1))
1045-
RH = np.linalg.solve(R, H)
1046+
RH = solve(R, H)
10461047

10471048
trace_RXXRH = np.trace(RXX @ RH, axis1=-2, axis2=-1)
10481049
trace_RXXRH = np.real(trace_RXXRH)
@@ -1207,7 +1208,7 @@ def separate(self, input: np.ndarray) -> np.ndarray:
12071208
R = np.sum(R_n, axis=0)
12081209
R = to_psd(R, flooring_fn=self.flooring_fn)
12091210
R = np.tile(R, reps=(n_sources, 1, 1, 1, 1))
1210-
W_Hermite = np.linalg.solve(R, R_n)
1211+
W_Hermite = solve(R, R_n)
12111212
W = W_Hermite.transpose(0, 1, 2, 4, 3).conj()
12121213
W_ref = W[:, :, :, reference_id, :]
12131214
W_ref = W_ref.transpose(0, 3, 1, 2)

ssspy/linalg/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from ._solve import solve
12
from .cubic import cbrt
23
from .eigh import eigh, eigh2
34
from .inv import inv2
@@ -18,4 +19,5 @@
1819
"gmeanmh",
1920
"solve_cubic",
2021
"lqpqm2",
22+
"solve",
2123
]

ssspy/linalg/_solve.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import numpy as np
2+
from packaging import version
3+
4+
np_version = np.__version__
5+
6+
IS_NUMPY_GE_2 = version.parse(np.__version__) >= version.parse("2")
7+
8+
9+
def solve(a: np.ndarray, b: np.ndarray) -> np.ndarray:
10+
requires_new_axis = IS_NUMPY_GE_2 and a.ndim == b.ndim + 1
11+
12+
if requires_new_axis:
13+
b = b[..., np.newaxis]
14+
15+
x = np.linalg.solve(a, b)
16+
17+
if requires_new_axis:
18+
x = x[..., 0]
19+
b = b[..., 0]
20+
21+
return x

0 commit comments

Comments
 (0)