-
Notifications
You must be signed in to change notification settings - Fork 10
feat(lattice): Make lattice geometries differentiable and backend-agn… #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
9e01be8
9d22384
d71d4a1
bb65592
0ad707c
92bc8e4
efaee05
7063c6f
0660abf
589763e
daa3ff2
9575be5
d372f72
0b38522
04aca93
283e1fd
494a99b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
""" | ||
Benchmark: Compare neighbor-building time between KDTree and distance-matrix | ||
methods in CustomizeLattice for varying lattice sizes. | ||
""" | ||
|
||
import argparse | ||
import csv | ||
import time | ||
from typing import Iterable, List, Tuple, Optional | ||
import logging | ||
|
||
import numpy as np | ||
|
||
# Silence verbose infos from the library during benchmarks | ||
|
||
logging.basicConfig(level=logging.WARNING) | ||
|
||
from tensorcircuit.templates.lattice import CustomizeLattice | ||
|
||
|
||
def _timeit(fn, repeats: int) -> float: | ||
"""Return average wall time (seconds) over repeats for calling fn().""" | ||
times: List[float] = [] | ||
for _ in range(repeats): | ||
t0 = time.perf_counter() | ||
fn() | ||
times.append(time.perf_counter() - t0) | ||
return float(np.mean(times)) | ||
|
||
|
||
def _gen_coords(n: int, d: int, seed: int) -> np.ndarray: | ||
rng = np.random.default_rng(seed) | ||
return rng.random((n, d), dtype=float) | ||
|
||
|
||
def run_once( | ||
n: int, d: int, max_k: int, repeats: int, seed: int | ||
) -> Tuple[float, float]: | ||
"""Run one size point and return (time_kdtree, time_matrix).""" | ||
coords = _gen_coords(n, d, seed) | ||
ids = list(range(n)) | ||
lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords) | ||
|
||
# KDTree path | ||
t_kdtree = _timeit( | ||
lambda: lat._build_neighbors(max_k=max_k, use_kdtree=True), repeats | ||
) | ||
|
||
# Distance-matrix path (fully differentiable) | ||
t_matrix = _timeit( | ||
lambda: lat._build_neighbors(max_k=max_k, use_kdtree=False), repeats | ||
) | ||
|
||
return t_kdtree, t_matrix | ||
|
||
|
||
def parse_sizes(s: str) -> List[int]: | ||
return [int(x) for x in s.split(",") if x.strip()] | ||
|
||
|
||
def format_row(n: int, t_kdtree: float, t_matrix: float) -> str: | ||
speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") | ||
return f"{n:>8} | {t_kdtree:>12.6f} | {t_matrix:>14.6f} | {speedup:>7.2f}x" | ||
|
||
|
||
def main(argv: Optional[Iterable[str]] = None) -> int: | ||
p = argparse.ArgumentParser(description="Neighbor-building time comparison") | ||
p.add_argument( | ||
"--sizes", | ||
type=parse_sizes, | ||
default=[128, 256, 512, 1024, 2048], | ||
help="Comma-separated site counts to benchmark (default: 128,256,512,1024,2048)", | ||
) | ||
p.add_argument( | ||
"--dims", type=int, default=2, help="Lattice dimensionality (default: 2)" | ||
) | ||
p.add_argument( | ||
"--max-k", type=int, default=6, help="Max neighbor shells k (default: 6)" | ||
) | ||
p.add_argument( | ||
"--repeats", type=int, default=5, help="Repeats per measurement (default: 5)" | ||
) | ||
p.add_argument("--seed", type=int, default=42, help="PRNG seed (default: 42)") | ||
p.add_argument("--csv", type=str, default="", help="Optional CSV output path") | ||
args = p.parse_args(list(argv) if argv is not None else None) | ||
|
||
print("=" * 74) | ||
print( | ||
f"Benchmark CustomizeLattice neighbor-building | dims={args.dims} max_k={args.max_k} repeats={args.repeats}" | ||
) | ||
print("=" * 74) | ||
print(f"{'N':>8} | {'KDTree(s)':>12} | {'DistMatrix(s)':>14} | {'Speedup':>7}") | ||
print("-" * 74) | ||
|
||
rows: List[Tuple[int, float, float]] = [] | ||
for n in args.sizes: | ||
t_kdtree, t_matrix = run_once(n, args.dims, args.max_k, args.repeats, args.seed) | ||
rows.append((n, t_kdtree, t_matrix)) | ||
print(format_row(n, t_kdtree, t_matrix)) | ||
|
||
if args.csv: | ||
with open(args.csv, "w", newline="", encoding="utf-8") as f: | ||
w = csv.writer(f) | ||
w.writerow(["N", "time_kdtree_s", "time_distance_matrix_s", "speedup"]) | ||
for n, t_kdtree, t_matrix in rows: | ||
speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") | ||
w.writerow([n, f"{t_kdtree:.6f}", f"{t_matrix:.6f}", f"{speedup:.2f}"]) | ||
|
||
print("-" * 74) | ||
print(f"Saved CSV to: {args.csv}") | ||
|
||
print("-" * 74) | ||
print("Done.") | ||
return 0 | ||
|
||
|
||
if __name__ == "__main__": | ||
raise SystemExit(main()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -631,8 +631,8 @@ def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any: | |
""" | ||
Return coordinate matrices from coordinate vectors. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. one more space for the docstring? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. still one more space on the above line There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. still one more space on the above line! |
||
|
||
:param args: coordinate vectors | ||
:type args: Any | ||
:param args: coordinate vectors | ||
:type args: Any | ||
:param kwargs: keyword arguments for meshgrid, typically includes 'indexing' | ||
refraction-ray marked this conversation as resolved.
Show resolved
Hide resolved
|
||
which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing). | ||
- 'ij': matrix indexing, first dimension corresponds to rows (default) | ||
|
@@ -647,9 +647,9 @@ def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any: | |
[0, 1]] | ||
y = [[0, 0], | ||
[2, 2]] | ||
:type kwargs: Any | ||
:return: list of coordinate matrices | ||
:rtype: Any | ||
:type kwargs: Any | ||
:return: list of coordinate matrices | ||
:rtype: Any | ||
""" | ||
raise NotImplementedError( | ||
"Backend '{}' has not implemented `meshgrid`.".format(self.name) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please delete this file, and ensure the previous lattice neighbor example
lattice_neighbor_benchmark.py
is doing correct, i.e. compare kdtree and the baseline