Skip to content

Commit 283e1fd

Browse files
committed
fix according to the review
1 parent 04aca93 commit 283e1fd

File tree

6 files changed

+106
-236
lines changed

6 files changed

+106
-236
lines changed
Lines changed: 95 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,111 @@
11
"""
2-
An example script to benchmark neighbor-finding algorithms in CustomizeLattice.
3-
4-
This script demonstrates the performance difference between the KDTree-based
5-
neighbor search and a baseline all-to-all distance matrix method.
6-
As shown by the results, the KDTree approach offers a significant speedup,
7-
especially when calculating for a large number of neighbor shells (large max_k).
2+
Benchmark: Compare neighbor-building time between KDTree and distance-matrix
3+
methods in CustomizeLattice for varying lattice sizes.
84
"""
95

10-
import timeit
11-
from typing import Any, Dict, List
6+
import argparse
7+
import csv
8+
import time
9+
from typing import Iterable, List, Tuple, Optional
10+
import logging
11+
12+
import numpy as np
13+
14+
# Silence the library's verbose INFO-level log messages during benchmarks
15+
16+
logging.basicConfig(level=logging.WARNING)
17+
18+
from tensorcircuit.templates.lattice import CustomizeLattice
19+
20+
21+
def run_once(
22+
n: int, d: int, max_k: int, repeats: int, seed: int
23+
) -> Tuple[float, float]:
24+
"""Run one size point and return (time_kdtree, time_matrix)."""
25+
rng = np.random.default_rng(seed)
26+
ids = list(range(n))
27+
28+
# Collect one timing per repeat; each repeat draws new random coordinates
29+
kdtree_times: List[float] = []
30+
matrix_times: List[float] = []
31+
32+
for _ in range(repeats):
33+
# Generate different coordinates for each repeat
34+
coords = rng.random((n, d), dtype=float)
35+
lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords)
36+
37+
# KDTree path - single measurement
38+
t0 = time.perf_counter()
39+
lat._build_neighbors(max_k=max_k, use_kdtree=True)
40+
kdtree_times.append(time.perf_counter() - t0)
1241

42+
# Distance-matrix path - single measurement
43+
t0 = time.perf_counter()
44+
lat._build_neighbors(max_k=max_k, use_kdtree=False)
45+
matrix_times.append(time.perf_counter() - t0)
1346

14-
def run_benchmark() -> None:
15-
"""
16-
Executes the benchmark test and prints the results in a formatted table.
17-
"""
18-
# --- Benchmark Parameters ---
19-
# A list of lattice sizes (N = number of sites) to test
20-
site_counts: List[int] = [10, 50, 100, 200, 500, 1000, 1500, 2000]
47+
return float(np.mean(kdtree_times)), float(np.mean(matrix_times))
2148

22-
# Use a large k to better showcase the performance of KDTree in
23-
# finding multiple neighbor shells, as suggested by the maintainer.
24-
max_k: int = 2000
2549

26-
# Reduce the number of runs to keep the total benchmark time reasonable,
27-
# especially with a large max_k.
28-
number_of_runs: int = 3
29-
# --------------------------
50+
def parse_sizes(s: str) -> List[int]:
51+
return [int(x) for x in s.split(",") if x.strip()]
3052

31-
results: List[Dict[str, Any]] = []
3253

33-
print("=" * 75)
34-
print("Starting neighbor finding benchmark for CustomizeLattice...")
35-
print(f"Parameters: max_k={max_k}, number_of_runs={number_of_runs}")
36-
print("=" * 75)
54+
def format_row(n: int, t_kdtree: float, t_matrix: float) -> str:
55+
speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf")
56+
return f"{n:>8} | {t_kdtree:>12.6f} | {t_matrix:>14.6f} | {speedup:>7.2f}x"
57+
58+
59+
def main(argv: Optional[Iterable[str]] = None) -> int:
60+
p = argparse.ArgumentParser(description="Neighbor-building time comparison")
61+
p.add_argument(
62+
"--sizes",
63+
type=parse_sizes,
64+
default=[128, 256, 512, 1024, 2048],
65+
help="Comma-separated site counts to benchmark (default: 128,256,512,1024,2048)",
66+
)
67+
p.add_argument(
68+
"--dims", type=int, default=2, help="Lattice dimensionality (default: 2)"
69+
)
70+
p.add_argument(
71+
"--max-k", type=int, default=6, help="Max neighbor shells k (default: 6)"
72+
)
73+
p.add_argument(
74+
"--repeats", type=int, default=5, help="Repeats per measurement (default: 5)"
75+
)
76+
p.add_argument("--seed", type=int, default=42, help="PRNG seed (default: 42)")
77+
p.add_argument("--csv", type=str, default="", help="Optional CSV output path")
78+
args = p.parse_args(list(argv) if argv is not None else None)
79+
80+
print("=" * 74)
3781
print(
38-
f"{'Sites (N)':>10} | {'KDTree Time (s)':>18} | {'Baseline Time (s)':>20} | {'Speedup':>10}"
82+
f"Benchmark CustomizeLattice neighbor-building | dims={args.dims} max_k={args.max_k} repeats={args.repeats}"
3983
)
40-
print("-" * 75)
84+
print("=" * 74)
85+
print(f"{'N':>8} | {'KDTree(s)':>12} | {'DistMatrix(s)':>14} | {'Speedup':>7}")
86+
print("-" * 74)
4187

42-
for n_sites in site_counts:
43-
# Prepare the setup code for timeit.
44-
# This code generates a random lattice and is executed before timing begins.
45-
# We use a fixed seed to ensure the coordinates are the same for both tests.
46-
setup_code = f"""
47-
import numpy as np
48-
from tensorcircuit.templates.lattice import CustomizeLattice
88+
rows: List[Tuple[int, float, float]] = []
89+
for n in args.sizes:
90+
t_kdtree, t_matrix = run_once(n, args.dims, args.max_k, args.repeats, args.seed)
91+
rows.append((n, t_kdtree, t_matrix))
92+
print(format_row(n, t_kdtree, t_matrix))
4993

50-
np.random.seed(42)
51-
coords = np.random.rand({n_sites}, 2)
52-
ids = list(range({n_sites}))
53-
lat = CustomizeLattice(dimensionality=2, identifiers=ids, coordinates=coords)
54-
"""
55-
# Define the Python statements to be timed.
56-
stmt_kdtree = f"lat._build_neighbors(max_k={max_k})"
57-
stmt_baseline = f"lat._build_neighbors_by_distance_matrix(max_k={max_k})"
58-
59-
try:
60-
# Execute the timing. timeit returns the total time for all runs.
61-
time_kdtree = (
62-
timeit.timeit(stmt=stmt_kdtree, setup=setup_code, number=number_of_runs)
63-
/ number_of_runs
64-
)
65-
time_baseline = (
66-
timeit.timeit(
67-
stmt=stmt_baseline, setup=setup_code, number=number_of_runs
68-
)
69-
/ number_of_runs
70-
)
71-
72-
# Calculate and store results, handling potential division by zero.
73-
speedup = time_baseline / time_kdtree if time_kdtree > 0 else float("inf")
74-
results.append(
75-
{
76-
"n_sites": n_sites,
77-
"time_kdtree": time_kdtree,
78-
"time_baseline": time_baseline,
79-
"speedup": speedup,
80-
}
81-
)
82-
print(
83-
f"{n_sites:>10} | {time_kdtree:>18.6f} | {time_baseline:>20.6f} | {speedup:>9.2f}x"
84-
)
85-
86-
except Exception as e:
87-
print(f"An error occurred at N={n_sites}: {e}")
88-
break
89-
90-
print("-" * 75)
91-
print("Benchmark complete.")
94+
if args.csv:
95+
with open(args.csv, "w", newline="", encoding="utf-8") as f:
96+
w = csv.writer(f)
97+
w.writerow(["N", "time_kdtree_s", "time_distance_matrix_s", "speedup"])
98+
for n, t_kdtree, t_matrix in rows:
99+
speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf")
100+
w.writerow([n, f"{t_kdtree:.6f}", f"{t_matrix:.6f}", f"{speedup:.2f}"])
101+
102+
print("-" * 74)
103+
print(f"Saved CSV to: {args.csv}")
104+
105+
print("-" * 74)
106+
print("Done.")
107+
return 0
92108

93109

94110
if __name__ == "__main__":
95-
run_benchmark()
111+
raise SystemExit(main())

examples/lattice_neighbor_time_compare.py

Lines changed: 0 additions & 111 deletions
This file was deleted.

examples/lennard_jones_optimization.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
to optimize crystal structure. It finds the equilibrium lattice constant that minimizes
66
the total Lennard-Jones potential energy of a 2D square lattice.
77
8-
The optimization showcases the key Task 3 capability: making lattice parameters
9-
differentiable for variational material design.
8+
This example showcases a key capability of the differentiable lattice system:
9+
making geometric parameters (like lattice constants) fully differentiable and
10+
optimizable using automatic differentiation. This enables variational material design
11+
where crystal structures can be optimized to minimize physical energy functions.
1012
"""
1113

1214
import optax
@@ -54,7 +56,7 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0):
5456

5557
optimizer = optax.adam(learning_rate=0.01)
5658

57-
log_a = K.convert_to_tensor(K.log(K.convert_to_tensor(1.1)))
59+
log_a = K.convert_to_tensor(K.log(K.convert_to_tensor(2.0)))
5860

5961
opt_state = optimizer.init(log_a)
6062

tensorcircuit/backends/abstract_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
629629

630630
def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
631631
"""
632-
Return coordinate matrices from coordinate vectors.
632+
Return coordinate matrices from coordinate vectors.
633633
634634
:param args: coordinate vectors
635635
:type args: Any

tensorcircuit/templates/lattice.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
Set,
1919
)
2020

21-
logger = logging.getLogger(__name__)
2221
import itertools
2322
import math
2423
import numpy as np
2524
from scipy.spatial import KDTree
25+
2626
from .. import backend
2727

2828

@@ -42,6 +42,8 @@
4242
import matplotlib.axes
4343
from mpl_toolkits.mplot3d import Axes3D
4444

45+
logger = logging.getLogger(__name__)
46+
4547
Tensor = Any
4648
SiteIndex = int
4749
SiteIdentifier = Hashable
@@ -835,7 +837,7 @@ def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
835837
:param max_k: The maximum order of neighbors to compute (e.g., k=1 for
836838
nearest neighbors, k=2 for next-nearest, etc.). Defaults to 2.
837839
:type max_k: int, optional
838-
:param \**kwargs: Additional keyword arguments. May include:
840+
:param kwargs: Additional keyword arguments. May include:
839841
- ``tol`` (float): The numerical tolerance used to determine if two
840842
distances are equal when identifying shells. Defaults to 1e-6.
841843
"""
@@ -1486,7 +1488,7 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
14861488

14871489
# For small lattices or cases with potential duplicate coordinates,
14881490
# fall back to distance matrix method for robustness
1489-
if self.num_sites < 1000:
1491+
if self.num_sites < 200:
14901492
logger.info(
14911493
"Small lattice detected, falling back to distance matrix method for robustness"
14921494
)

0 commit comments

Comments
 (0)