Skip to content

Commit 98ba9fa

Browse files
early switch for kdtree neighbor building
1 parent 494a99b commit 98ba9fa

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
- Add transformation method between tensornetwork, quimb, tenpy and QuOperator in tc-ng including `qop2tenpy`, `qop2quimb`, `qop2tn`, `tenpy2qop`, support both MPS and MPO formats.
2626

27+
- Make the lattice module backend agnostic, now the lattice follows `tc.set_backend`.
28+
2729
### Fixed
2830

2931
- Fixed `one_hot` in numpy backend.

examples/lattice_neighbor_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def main(argv: Optional[Iterable[str]] = None) -> int:
6868
"--dims", type=int, default=2, help="Lattice dimensionality (default: 2)"
6969
)
7070
p.add_argument(
71-
"--max-k", type=int, default=6, help="Max neighbor shells k (default: 6)"
71+
"--max-k", type=int, default=10, help="Max neighbor shells k (default: 6)"
7272
)
7373
p.add_argument(
7474
"--repeats", type=int, default=5, help="Repeats per measurement (default: 5)"

tensorcircuit/templates/lattice.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def show(
377377
try:
378378
import matplotlib.pyplot as plt
379379
except ImportError:
380-
logger.error(
380+
logger.warning(
381381
"Matplotlib is required for visualization. "
382382
"Please install it using 'pip install matplotlib'."
383383
)
@@ -1479,12 +1479,6 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
14791479
Note: This method uses numpy arrays directly and may not be compatible
14801480
with all backend types (JAX, TensorFlow, etc.).
14811481
"""
1482-
# Convert coordinates to numpy for KDTree
1483-
coords_np = backend.numpy(self._coordinates)
1484-
1485-
# Build KDTree
1486-
logger.info("Building KDTree...")
1487-
tree = KDTree(coords_np)
14881482

14891483
# For small lattices or cases with potential duplicate coordinates,
14901484
# fall back to distance matrix method for robustness
@@ -1495,6 +1489,12 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
14951489
self._build_neighbors_by_distance_matrix(max_k, tol)
14961490
return
14971491

1492+
# Convert coordinates to numpy for KDTree
1493+
coords_np = backend.numpy(self._coordinates)
1494+
1495+
# Build KDTree
1496+
logger.info("Building KDTree...")
1497+
tree = KDTree(coords_np)
14981498
# Find all distances for shell identification - use comprehensive sampling
14991499
logger.info("Identifying distance shells...")
15001500
distances_for_shells: List[float] = []
@@ -1549,7 +1549,6 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
15491549
self._neighbor_maps = {k: {} for k in range(1, len(dist_shells) + 1)}
15501550

15511551
# Build neighbor lists for each site
1552-
logger.info("Building neighbor lists...")
15531552
for i in range(self.num_sites):
15541553
# Query enough neighbors to capture all shells
15551554
query_k = min(max_k * 20 + 50, self.num_sites - 1)
@@ -1595,7 +1594,6 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
15951594

15961595
# Set distance matrix to None - will compute on demand
15971596
self._distance_matrix = None
1598-
logger.info("KDTree neighbor building completed")
15991597

16001598
def _reset_computations(self) -> None:
16011599
"""Resets all cached data that depends on the lattice structure."""

0 commit comments

Comments
 (0)