Skip to content

Commit 0ecbc99

Browse files
thebabushpatacca
authored andcommitted
hammer lapx into qbindiff
1 parent 13e156e commit 0ecbc99

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

src/qbindiff/matcher/matcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222

2323
# Third-party imports
2424
import numpy as np
25-
from lapjv import lapjv # type: ignore[import-not-found]
2625
from scipy.sparse import csr_matrix, coo_matrix # type: ignore[import-untyped]
2726

2827
# Local imports
2928
from qbindiff.matcher.squares import find_squares # type: ignore[import-untyped]
3029
from qbindiff.matcher.belief_propagation import BeliefMWM, BeliefQAP
30+
from qbindiff.utils import lap
3131

3232

3333
if TYPE_CHECKING:
@@ -50,7 +50,7 @@ def solve_linear_assignment(cost_matrix: Matrix) -> RawMapping:
5050
cost_matrix = cost_matrix.T
5151
full_cost_matrix = np.zeros((m, m), dtype=cost_matrix.dtype)
5252
full_cost_matrix[:n, :m] = cost_matrix
53-
col_indices = lapjv(full_cost_matrix)[0][:n]
53+
col_indices = lap(full_cost_matrix)[:n]
5454
if transposed:
5555
return col_indices, np.arange(n)
5656
return np.arange(n), col_indices

src/qbindiff/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
Collection of utilities used internally.
1818
"""
1919

20-
from .utils import is_debug, iter_csr_matrix, log_once, wrapper_iter
20+
from .utils import is_debug, iter_csr_matrix, lap, log_once, wrapper_iter

src/qbindiff/utils/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,17 @@
2323
import logging
2424
from typing import TYPE_CHECKING
2525

26+
try:
27+
from lapjv import lapjv as _lap # type: ignore[import-not-found]
28+
_is_lapjv = True
29+
except ImportError:
30+
from lap import lapjv as _lap
31+
_is_lapjv = False
32+
2633
if TYPE_CHECKING:
2734
from collections.abc import Generator
2835
from typing import Any
29-
from qbindiff.types import SparseMatrix
36+
from qbindiff.types import Matrix, RawMapping, SparseMatrix
3037

3138

3239
def is_debug() -> bool:
@@ -49,6 +56,15 @@ def iter_csr_matrix(matrix: SparseMatrix) -> Generator[tuple[int, int, Any], Non
4956
yield (x, y, v)
5057

5158

59+
def lap(cost_matrix: Matrix) -> RawMapping:
60+
if _is_lapjv:
61+
row_ind_lapjv, _col_ind_lapjv, _ = _lap(cost_matrix)
62+
return row_ind_lapjv
63+
else:
64+
_cost, x, _y = _lap(cost_matrix)
65+
return x
66+
67+
5268
@cache
5369
def log_once(level: int, message: str) -> None:
5470
"""

0 commit comments

Comments
 (0)