Skip to content

Commit efaee05

Browse files
committed
fix mypy errors
1 parent 92bc8e4 commit efaee05

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

tensorcircuit/templates/lattice.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,7 +1420,7 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
14201420

14211421
# Find all distances for shell identification - use comprehensive sampling
14221422
logger.info("Identifying distance shells...")
1423-
distances_for_shells = []
1423+
distances_for_shells: List[float] = []
14241424

14251425
# For robust shell identification, query all pairwise distances for smaller lattices
14261426
# or use dense sampling for larger ones
@@ -1432,7 +1432,10 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
14321432
dists, _ = tree.query(
14331433
coords_np[i], k=query_k + 1
14341434
) # +1 to exclude self
1435-
distances_for_shells.extend(dists[1:]) # Skip distance to self
1435+
if isinstance(dists, np.ndarray):
1436+
distances_for_shells.extend(dists[1:]) # Skip distance to self
1437+
else:
1438+
distances_for_shells.append(dists) # Single distance
14361439
else:
14371440
# For larger lattices, use adaptive sampling but ensure we capture all shells
14381441
sample_size = min(1000, self.num_sites // 2) # More conservative sampling
@@ -1442,7 +1445,10 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
14421445
dists, _ = tree.query(
14431446
coords_np[i], k=query_k + 1
14441447
) # +1 to exclude self
1445-
distances_for_shells.extend(dists[1:]) # Skip distance to self
1448+
if isinstance(dists, np.ndarray):
1449+
distances_for_shells.extend(dists[1:]) # Skip distance to self
1450+
else:
1451+
distances_for_shells.append(dists) # Single distance
14461452

14471453
# Filter out zero distances (duplicate coordinates) before shell identification
14481454
ZERO_THRESHOLD = 1e-12
@@ -1476,12 +1482,18 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
14761482
) # +1 for self
14771483

14781484
# Skip the first entry (distance to self)
1479-
distances = distances[1:]
1480-
indices = indices[1:]
1485+
# Handle both single value and array cases
1486+
if isinstance(distances, np.ndarray) and len(distances) > 1:
1487+
distances_slice = distances[1:]
1488+
indices_slice = indices[1:] if isinstance(indices, np.ndarray) else np.array([], dtype=int)
1489+
else:
1490+
# Single value or empty case - no neighbors to process
1491+
distances_slice = np.array([])
1492+
indices_slice = np.array([], dtype=int)
14811493

14821494
# Filter out zero distances (duplicate coordinates)
14831495
valid_pairs = [
1484-
(d, idx) for d, idx in zip(distances, indices) if d > ZERO_THRESHOLD
1496+
(d, idx) for d, idx in zip(distances_slice, indices_slice) if d > ZERO_THRESHOLD
14851497
]
14861498

14871499
# Assign neighbors to shells

0 commit comments

Comments
 (0)