diff --git a/examples/lattice_neighbor_benchmark.py b/examples/lattice_neighbor_benchmark.py index 98cb7f8..dada10a 100644 --- a/examples/lattice_neighbor_benchmark.py +++ b/examples/lattice_neighbor_benchmark.py @@ -1,95 +1,111 @@ """ -An example script to benchmark neighbor-finding algorithms in CustomizeLattice. - -This script demonstrates the performance difference between the KDTree-based -neighbor search and a baseline all-to-all distance matrix method. -As shown by the results, the KDTree approach offers a significant speedup, -especially when calculating for a large number of neighbor shells (large max_k). +Benchmark: Compare neighbor-building time between KDTree and distance-matrix +methods in CustomizeLattice for varying lattice sizes. """ -import timeit -from typing import Any, Dict, List +import argparse +import csv +import time +from typing import Iterable, List, Tuple, Optional +import logging + +import numpy as np + +# Silence verbose infos from the library during benchmarks + +logging.basicConfig(level=logging.WARNING) + +from tensorcircuit.templates.lattice import CustomizeLattice + + +def run_once( + n: int, d: int, max_k: int, repeats: int, seed: int +) -> Tuple[float, float]: + """Run one size point and return (time_kdtree, time_matrix).""" + rng = np.random.default_rng(seed) + ids = list(range(n)) + + # Collect times for each repeat with different random coordinates + kdtree_times: List[float] = [] + matrix_times: List[float] = [] + + for _ in range(repeats): + # Generate different coordinates for each repeat + coords = rng.random((n, d), dtype=float) + lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords) + + # KDTree path - single measurement + t0 = time.perf_counter() + lat._build_neighbors(max_k=max_k, use_kdtree=True) + kdtree_times.append(time.perf_counter() - t0) + # Distance-matrix path - single measurement + t0 = time.perf_counter() + lat._build_neighbors(max_k=max_k, use_kdtree=False) + 
matrix_times.append(time.perf_counter() - t0) -def run_benchmark() -> None: - """ - Executes the benchmark test and prints the results in a formatted table. - """ - # --- Benchmark Parameters --- - # A list of lattice sizes (N = number of sites) to test - site_counts: List[int] = [10, 50, 100, 200, 500, 1000, 1500, 2000] + return float(np.mean(kdtree_times)), float(np.mean(matrix_times)) - # Use a large k to better showcase the performance of KDTree in - # finding multiple neighbor shells, as suggested by the maintainer. - max_k: int = 2000 - # Reduce the number of runs to keep the total benchmark time reasonable, - # especially with a large max_k. - number_of_runs: int = 3 - # -------------------------- +def parse_sizes(s: str) -> List[int]: + return [int(x) for x in s.split(",") if x.strip()] - results: List[Dict[str, Any]] = [] - print("=" * 75) - print("Starting neighbor finding benchmark for CustomizeLattice...") - print(f"Parameters: max_k={max_k}, number_of_runs={number_of_runs}") - print("=" * 75) +def format_row(n: int, t_kdtree: float, t_matrix: float) -> str: + speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") + return f"{n:>8} | {t_kdtree:>12.6f} | {t_matrix:>14.6f} | {speedup:>7.2f}x" + + +def main(argv: Optional[Iterable[str]] = None) -> int: + p = argparse.ArgumentParser(description="Neighbor-building time comparison") + p.add_argument( + "--sizes", + type=parse_sizes, + default=[128, 256, 512, 1024, 2048], + help="Comma-separated site counts to benchmark (default: 128,256,512,1024,2048)", + ) + p.add_argument( + "--dims", type=int, default=2, help="Lattice dimensionality (default: 2)" + ) + p.add_argument( + "--max-k", type=int, default=6, help="Max neighbor shells k (default: 6)" + ) + p.add_argument( + "--repeats", type=int, default=5, help="Repeats per measurement (default: 5)" + ) + p.add_argument("--seed", type=int, default=42, help="PRNG seed (default: 42)") + p.add_argument("--csv", type=str, default="", help="Optional CSV 
output path") + args = p.parse_args(list(argv) if argv is not None else None) + + print("=" * 74) print( - f"{'Sites (N)':>10} | {'KDTree Time (s)':>18} | {'Baseline Time (s)':>20} | {'Speedup':>10}" + f"Benchmark CustomizeLattice neighbor-building | dims={args.dims} max_k={args.max_k} repeats={args.repeats}" ) - print("-" * 75) + print("=" * 74) + print(f"{'N':>8} | {'KDTree(s)':>12} | {'DistMatrix(s)':>14} | {'Speedup':>7}") + print("-" * 74) - for n_sites in site_counts: - # Prepare the setup code for timeit. - # This code generates a random lattice and is executed before timing begins. - # We use a fixed seed to ensure the coordinates are the same for both tests. - setup_code = f""" -import numpy as np -from tensorcircuit.templates.lattice import CustomizeLattice + rows: List[Tuple[int, float, float]] = [] + for n in args.sizes: + t_kdtree, t_matrix = run_once(n, args.dims, args.max_k, args.repeats, args.seed) + rows.append((n, t_kdtree, t_matrix)) + print(format_row(n, t_kdtree, t_matrix)) -np.random.seed(42) -coords = np.random.rand({n_sites}, 2) -ids = list(range({n_sites})) -lat = CustomizeLattice(dimensionality=2, identifiers=ids, coordinates=coords) -""" - # Define the Python statements to be timed. - stmt_kdtree = f"lat._build_neighbors(max_k={max_k})" - stmt_baseline = f"lat._build_neighbors_by_distance_matrix(max_k={max_k})" - - try: - # Execute the timing. timeit returns the total time for all runs. - time_kdtree = ( - timeit.timeit(stmt=stmt_kdtree, setup=setup_code, number=number_of_runs) - / number_of_runs - ) - time_baseline = ( - timeit.timeit( - stmt=stmt_baseline, setup=setup_code, number=number_of_runs - ) - / number_of_runs - ) - - # Calculate and store results, handling potential division by zero. 
- speedup = time_baseline / time_kdtree if time_kdtree > 0 else float("inf") - results.append( - { - "n_sites": n_sites, - "time_kdtree": time_kdtree, - "time_baseline": time_baseline, - "speedup": speedup, - } - ) - print( - f"{n_sites:>10} | {time_kdtree:>18.6f} | {time_baseline:>20.6f} | {speedup:>9.2f}x" - ) - - except Exception as e: - print(f"An error occurred at N={n_sites}: {e}") - break - - print("-" * 75) - print("Benchmark complete.") + if args.csv: + with open(args.csv, "w", newline="", encoding="utf-8") as f: + w = csv.writer(f) + w.writerow(["N", "time_kdtree_s", "time_distance_matrix_s", "speedup"]) + for n, t_kdtree, t_matrix in rows: + speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") + w.writerow([n, f"{t_kdtree:.6f}", f"{t_matrix:.6f}", f"{speedup:.2f}"]) + + print("-" * 74) + print(f"Saved CSV to: {args.csv}") + + print("-" * 74) + print("Done.") + return 0 if __name__ == "__main__": - run_benchmark() + raise SystemExit(main()) diff --git a/examples/lennard_jones_optimization.py b/examples/lennard_jones_optimization.py new file mode 100644 index 0000000..6d8349a --- /dev/null +++ b/examples/lennard_jones_optimization.py @@ -0,0 +1,121 @@ +""" +Lennard-Jones Potential Optimization Example + +This script demonstrates how to use TensorCircuit's differentiable lattice geometries +to optimize crystal structure. It finds the equilibrium lattice constant that minimizes +the total Lennard-Jones potential energy of a 2D square lattice. + +This example showcases a key capability of the differentiable lattice system: +making geometric parameters (like lattice constants) fully differentiable and +optimizable using automatic differentiation. This enables variational material design +where crystal structures can be optimized to minimize physical energy functions. 
+""" + +import optax +import numpy as np +import matplotlib.pyplot as plt +import tensorcircuit as tc + + +tc.set_dtype("float64") # Use tc for universal control +K = tc.set_backend("jax") + + +def calculate_potential(log_a, epsilon=0.5, sigma=1.0): + """ + Calculate the total Lennard-Jones potential energy for a given logarithm of the lattice constant (log_a). + This version creates the lattice inside the function to demonstrate truly differentiable geometry. + """ + lattice_constant = K.exp(log_a) + + # Create lattice with the differentiable parameter + size = (4, 4) # Smaller size for demonstration + lattice = tc.templates.lattice.SquareLattice( + size, lattice_constant=lattice_constant, pbc=True + ) + d = lattice.distance_matrix + + d_safe = K.where(d > 1e-9, d, K.convert_to_tensor(1e-9)) + + term12 = K.power(sigma / d_safe, 12) + term6 = K.power(sigma / d_safe, 6) + potential_matrix = 4 * epsilon * (term12 - term6) + + num_sites = lattice.num_sites + # Zero out self-interactions (diagonal elements) + eye_mask = K.eye(num_sites, dtype=potential_matrix.dtype) + potential_matrix = potential_matrix * (1 - eye_mask) + + potential_energy = K.sum(potential_matrix) / 2.0 + + return potential_energy + + +# Create value and grad function for optimization +value_and_grad_fun = K.jit(K.value_and_grad(calculate_potential)) + +optimizer = optax.adam(learning_rate=0.01) + +log_a = K.convert_to_tensor(K.log(K.convert_to_tensor(2.0))) + +opt_state = optimizer.init(log_a) + +history = {"a": [], "energy": []} + +print("Starting optimization of lattice constant...") +for i in range(200): + energy, grad = value_and_grad_fun(log_a) + + history["a"].append(K.exp(log_a)) + history["energy"].append(energy) + + updates, opt_state = optimizer.update(grad, opt_state) + log_a = optax.apply_updates(log_a, updates) + + if (i + 1) % 20 == 0: + current_a = K.exp(log_a) + print( + f"Iteration {i+1}/200: Total Energy = {energy:.4f}, Lattice Constant = {current_a:.4f}" + ) + +final_a = 
K.exp(log_a) +final_energy = calculate_potential(log_a) + +print("\nOptimization finished!") +print(f"Final optimized lattice constant: {final_a:.6f}") +print(f"Corresponding minimum total energy: {final_energy:.6f}") + +# Vectorized calculation for the potential curve +a_vals = np.linspace(0.8, 1.5, 200) +log_a_vals = K.log(K.convert_to_tensor(a_vals)) + +# Use vmap to create a vectorized version of the potential function +vmap_potential = K.vmap(lambda la: calculate_potential(la)) +potential_curve = vmap_potential(log_a_vals) + +plt.figure(figsize=(10, 6)) +plt.plot(a_vals, potential_curve, label="Lennard-Jones Potential", color="blue") +plt.scatter( + history["a"], + history["energy"], + color="red", + s=20, + zorder=5, + label="Optimization Steps", +) +plt.scatter( + final_a, + final_energy, + color="green", + s=100, + zorder=6, + marker="*", + label="Final Optimized Point", +) + +plt.title("Lennard-Jones Potential Optimization") +plt.xlabel("Lattice Constant (a)") +plt.ylabel("Total Potential Energy") +plt.legend() +plt.grid(True) +plt.show() diff --git a/tensorcircuit/backends/abstract_backend.py b/tensorcircuit/backends/abstract_backend.py index adbad83..8372080 100644 --- a/tensorcircuit/backends/abstract_backend.py +++ b/tensorcircuit/backends/abstract_backend.py @@ -596,6 +596,82 @@ def argsort(self: Any, a: Tensor, axis: int = -1) -> Tensor: "Backend '{}' has not implemented `argsort`.".format(self.name) ) + def sort(self: Any, a: Tensor, axis: int = -1) -> Tensor: + """ + Sort a tensor along the given axis. + + :param a: Input tensor to be sorted + :type a: Tensor + :param axis: Axis along which to sort, defaults to -1 + :type axis: int, optional + :return: Tensor with its elements sorted along ``axis`` + :rtype: Tensor + """ + raise NotImplementedError( + "Backend '{}' has not implemented `sort`.".format(self.name) + ) + + def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: + """ + Test whether all array elements along a given axis evaluate to True. 
+ + :param a: Input tensor + :type a: Tensor + :param axis: Axis or axes along which a logical AND reduction is performed, + defaults to None + :type axis: Optional[Sequence[int]], optional + :return: A new boolean or tensor resulting from the AND reduction + :rtype: Tensor + """ + raise NotImplementedError( + "Backend '{}' has not implemented `all`.".format(self.name) + ) + + def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any: + """ + Return coordinate matrices from coordinate vectors. + + :param args: coordinate vectors + :type args: Any + :param kwargs: keyword arguments for meshgrid, typically includes 'indexing' + which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing). + - 'ij': matrix indexing, first dimension corresponds to rows (PyTorch's default) + - 'xy': Cartesian indexing, first dimension corresponds to columns (NumPy/JAX/TensorFlow default) + Example: + >>> x, y = backend.meshgrid([0, 1], [0, 2], indexing='xy') + Shapes: + - x.shape == (2, 2) # rows correspond to y vector length + - y.shape == (2, 2) + Values: + x = [[0, 1], + [0, 1]] + y = [[0, 0], + [2, 2]] + :type kwargs: Any + :return: list of coordinate matrices + :rtype: Any + """ + raise NotImplementedError( + "Backend '{}' has not implemented `meshgrid`.".format(self.name) + ) + + def expand_dims(self: Any, a: Tensor, axis: int) -> Tensor: + """ + Expand the shape of a tensor. + Insert a new axis that will appear at the `axis` position in the expanded + tensor shape. + + :param a: Input tensor + :type a: Tensor + :param axis: Position in the expanded axes where the new axis is placed + :type axis: int + :return: Output tensor with the number of dimensions increased by one. + :rtype: Tensor + """ + raise NotImplementedError( + "Backend '{}' has not implemented `expand_dims`.".format(self.name) + ) + def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: """ Find the unique elements and their corresponding counts of the given tensor ``a``. 
@@ -733,6 +809,21 @@ def cast(self: Any, a: Tensor, dtype: str) -> Tensor: "Backend '{}' has not implemented `cast`.".format(self.name) ) + def convert_to_tensor(self: Any, a: Tensor, dtype: Optional[str] = None) -> Tensor: + """ + Convert input to tensor. + + :param a: input data to be converted + :type a: Tensor + :param dtype: target dtype, optional + :type dtype: Optional[str] + :return: converted tensor + :rtype: Tensor + """ + raise NotImplementedError( + "Backend '{}' has not implemented `convert_to_tensor`.".format(self.name) + ) + def mod(self: Any, x: Tensor, y: Tensor) -> Tensor: """ Compute y-mod of x (negative number behavior is not guaranteed to be consistent) @@ -1404,6 +1495,28 @@ def cond( "Backend '{}' has not implemented `cond`.".format(self.name) ) + def where( + self: Any, + condition: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + ) -> Tensor: + """ + Return a tensor of elements selected from either x or y, depending on condition. + + :param condition: Where True, yield x, otherwise yield y. + :type condition: Tensor (bool) + :param x: Values from which to choose when condition is True. + :type x: Tensor + :param y: Values from which to choose when condition is False. + :type y: Tensor + :return: A tensor with elements from x where condition is True, and y otherwise. 
+ :rtype: Tensor + """ + raise NotImplementedError( + "Backend '{}' has not implemented `where`.".format(self.name) + ) + def switch( self: Any, index: Tensor, branches: Sequence[Callable[[], Tensor]] ) -> Tensor: diff --git a/tensorcircuit/backends/cupy_backend.py b/tensorcircuit/backends/cupy_backend.py index 4525cf0..7ee0892 100644 --- a/tensorcircuit/backends/cupy_backend.py +++ b/tensorcircuit/backends/cupy_backend.py @@ -56,10 +56,12 @@ def __init__(self) -> None: cpx = cupyx self.name = "cupy" - def convert_to_tensor(self, a: Tensor) -> Tensor: + def convert_to_tensor(self, a: Tensor, dtype: Optional[str] = None) -> Tensor: if not isinstance(a, cp.ndarray) and not cp.isscalar(a): a = cp.array(a) a = cp.asarray(a) + if dtype is not None: + a = self.cast(a, dtype) return a def sum( diff --git a/tensorcircuit/backends/jax_backend.py b/tensorcircuit/backends/jax_backend.py index a9d17b9..d678b42 100644 --- a/tensorcircuit/backends/jax_backend.py +++ b/tensorcircuit/backends/jax_backend.py @@ -50,12 +50,17 @@ def update(self, grads: pytree, params: pytree) -> pytree: return params -def _convert_to_tensor_jax(self: Any, tensor: Tensor) -> Tensor: +def _convert_to_tensor_jax( + self: Any, tensor: Tensor, dtype: Optional[str] = None +) -> Tensor: if not isinstance(tensor, (np.ndarray, jnp.ndarray)) and not jnp.isscalar(tensor): raise TypeError( ("Expected a `jnp.array`, `np.array` or scalar. 
" f"Got {type(tensor)}") ) result = jnp.asarray(tensor) + if dtype is not None: + # Use the backend's cast method to handle dtype conversion + result = self.cast(result, dtype) return result @@ -243,8 +248,10 @@ def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor: def copy(self, tensor: Tensor) -> Tensor: return jnp.array(tensor, copy=True) - def convert_to_tensor(self, tensor: Tensor) -> Tensor: + def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor: result = jnp.asarray(tensor) + if dtype is not None: + result = self.cast(result, dtype) return result def abs(self, a: Tensor) -> Tensor: @@ -390,6 +397,9 @@ def argmin(self, a: Tensor, axis: int = 0) -> Tensor: def argsort(self, a: Tensor, axis: int = -1) -> Tensor: return jnp.argsort(a, axis=axis) + def sort(self, a: Tensor, axis: int = -1) -> Tensor: + return jnp.sort(a, axis=axis) + def unique_with_counts( # type: ignore self, a: Tensor, *, size: Optional[int] = None, fill_value: Optional[int] = None ) -> Tuple[Tensor, Tensor]: @@ -410,6 +420,9 @@ def onehot(self, a: Tensor, num: int) -> Tensor: def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor: return jnp.cumsum(a, axis) + def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: + return jnp.all(a, axis=axis) + def is_tensor(self, a: Any) -> bool: if not isinstance(a, jnp.ndarray): return False @@ -812,4 +825,23 @@ def wrapper( vvag = vectorized_value_and_grad + def meshgrid(self, *args: Any, **kwargs: Any) -> Any: + """ + Backend-agnostic meshgrid function. 
+ """ + return jnp.meshgrid(*args, **kwargs) + optimizer = optax_optimizer + + def expand_dims(self, a: Tensor, axis: int) -> Tensor: + return jnp.expand_dims(a, axis) + + def where( + self, + condition: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + ) -> Tensor: + if x is None and y is None: + return jnp.where(condition) + return jnp.where(condition, x, y) diff --git a/tensorcircuit/backends/numpy_backend.py b/tensorcircuit/backends/numpy_backend.py index 633a846..8678f3d 100644 --- a/tensorcircuit/backends/numpy_backend.py +++ b/tensorcircuit/backends/numpy_backend.py @@ -35,10 +35,14 @@ def _sum_numpy( # see https://github.com/google/TensorNetwork/issues/952 -def _convert_to_tensor_numpy(self: Any, a: Tensor) -> Tensor: +def _convert_to_tensor_numpy( + self: Any, a: Tensor, dtype: Optional[str] = None +) -> Tensor: if not isinstance(a, np.ndarray) and not np.isscalar(a): a = np.array(a) a = np.asarray(a) + if dtype is not None: + a = a.astype(getattr(np, dtype)) return a @@ -132,6 +136,9 @@ def eigvalsh(self, a: Tensor) -> Tensor: def kron(self, a: Tensor, b: Tensor) -> Tensor: return np.kron(a, b) + def meshgrid(self, *args: Any, **kwargs: Any) -> Any: + return np.meshgrid(*args, **kwargs) + def dtype(self, a: Tensor) -> str: return a.dtype.__str__() # type: ignore @@ -151,6 +158,9 @@ def i(self, dtype: Any = None) -> Tensor: dtype = getattr(np, dtype) return np.array(1j, dtype=dtype) + def expand_dims(self, a: Tensor, axis: int) -> Tensor: + return np.expand_dims(a, axis) + def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return np.stack(a, axis=axis) @@ -173,6 +183,9 @@ def std( ) -> Tensor: return np.std(a, axis=axis, keepdims=keepdims) + def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: + return np.all(a, axis=axis) + def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: return np.unique(a, return_counts=True) # type: ignore @@ -188,6 +201,9 @@ def argmax(self, a: Tensor, 
axis: int = 0) -> Tensor: def argmin(self, a: Tensor, axis: int = 0) -> Tensor: return np.argmin(a, axis=axis) + def sort(self, a: Tensor, axis: int = -1) -> Tensor: + return np.sort(a, axis=axis) + def sigmoid(self, a: Tensor) -> Tensor: return expit(a) @@ -345,6 +361,17 @@ def to_dense(self, sp_a: Tensor) -> Tensor: def is_sparse(self, a: Tensor) -> bool: return issparse(a) # type: ignore + def where( + self, + condition: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + ) -> Tensor: + if x is None and y is None: + return np.where(condition) + assert x is not None and y is not None + return np.where(condition, x, y) + def cond( self, pred: bool, diff --git a/tensorcircuit/backends/pytorch_backend.py b/tensorcircuit/backends/pytorch_backend.py index 8317630..cd037b6 100644 --- a/tensorcircuit/backends/pytorch_backend.py +++ b/tensorcircuit/backends/pytorch_backend.py @@ -238,6 +238,15 @@ def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor: def copy(self, a: Tensor) -> Tensor: return a.clone() + def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor: + if self.is_tensor(tensor): + result = tensor + else: + result = torchlib.tensor(tensor) + if dtype is not None: + result = self.cast(result, dtype) + return result + def expm(self, a: Tensor) -> Tensor: raise NotImplementedError("pytorch backend doesn't support expm") # in 2020, torch has no expm, hmmm. but that's ok, @@ -369,6 +378,17 @@ def argmax(self, a: Tensor, axis: int = 0) -> Tensor: def argmin(self, a: Tensor, axis: int = 0) -> Tensor: return torchlib.argmin(a, dim=axis) + def sort(self, a: Tensor, axis: int = -1) -> Tensor: + return torchlib.sort(a, dim=axis).values + + def all(self, tensor: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: + """ + Corresponds to torch.all. 
+ """ + if axis is None: + return torchlib.all(tensor) + return torchlib.all(tensor, dim=axis) + def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: return torchlib.unique(a, return_counts=True) # type: ignore @@ -425,6 +445,16 @@ def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor: v = self.convert_to_tensor(v) return torchlib.searchsorted(a, v, side=side) + def where( + self, + condition: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + ) -> Tensor: + if x is None and y is None: + return torchlib.where(condition) + return torchlib.where(condition, x, y) + def reverse(self, a: Tensor) -> Tensor: return torchlib.flip(a, dims=(-1,)) @@ -706,6 +736,12 @@ def wrapper( return wrapper + def expand_dims(self, a: Tensor, axis: int) -> Tensor: + return torchlib.unsqueeze(a, dim=axis) + vvag = vectorized_value_and_grad + def meshgrid(self, *args: Any, **kws: Any) -> Tensor: + return torchlib.meshgrid(*args, **kws) + optimizer = torch_optimizer diff --git a/tensorcircuit/backends/tensorflow_backend.py b/tensorcircuit/backends/tensorflow_backend.py index ab79ae6..f7c7ae5 100644 --- a/tensorcircuit/backends/tensorflow_backend.py +++ b/tensorcircuit/backends/tensorflow_backend.py @@ -75,6 +75,12 @@ def update(self, grads: pytree, params: pytree) -> pytree: def _tensordot_tf( self: Any, a: Tensor, b: Tensor, axes: Union[int, Sequence[Sequence[int]]] ) -> Tensor: + # Use TensorFlow's dtype promotion rules by converting both to a common dtype + if a.dtype != b.dtype: + # Find the result dtype using TensorFlow's type promotion rules + common_dtype = tf.experimental.numpy.result_type(a.dtype, b.dtype) + a = tf.cast(a, common_dtype) + b = tf.cast(b, common_dtype) return tf.tensordot(a, b, axes) @@ -441,6 +447,12 @@ def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor: def copy(self, a: Tensor) -> Tensor: return tf.identity(a) + def convert_to_tensor(self, tensor: Tensor, dtype: 
Optional[str] = None) -> Tensor: + result = tf.convert_to_tensor(tensor) + if dtype is not None: + result = self.cast(result, dtype) + return result + def expm(self, a: Tensor) -> Tensor: return tf.linalg.expm(a) @@ -524,6 +536,20 @@ def min(self, a: Tensor, axis: Optional[int] = None) -> Tensor: def max(self, a: Tensor, axis: Optional[int] = None) -> Tensor: return tf.reduce_max(a, axis=axis) + def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: + return tf.reduce_all(tf.cast(a, tf.bool), axis=axis) + + def where( + self, + condition: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + ) -> Tensor: + if x is None and y is None: + # Return a tuple of tensors to be consistent with other backends + return tuple(tf.unstack(tf.where(condition), axis=1)) + return tf.where(condition, x, y) + def argmax(self, a: Tensor, axis: int = 0) -> Tensor: return tf.math.argmax(a, axis=axis) @@ -533,6 +559,9 @@ def argmin(self, a: Tensor, axis: int = 0) -> Tensor: def argsort(self, a: Tensor, axis: int = -1) -> Tensor: return tf.argsort(a, axis=axis) + def sort(self, a: Tensor, axis: int = -1) -> Tensor: + return tf.sort(a, axis=axis) + def shape_tuple(self, a: Tensor) -> Tuple[int, ...]: return tuple(a.shape) @@ -1061,4 +1090,13 @@ def wrapper( vvag = vectorized_value_and_grad + def meshgrid(self, *args: Any, **kwargs: Any) -> Any: + """ + Backend-agnostic meshgrid function. 
+ """ + return tf.meshgrid(*args, **kwargs) + optimizer = keras_optimizer + + def expand_dims(self, a: Tensor, axis: int) -> Tensor: + return tf.expand_dims(a, axis) diff --git a/tensorcircuit/templates/hamiltonians.py b/tensorcircuit/templates/hamiltonians.py index 0382b11..b7e0cc6 100644 --- a/tensorcircuit/templates/hamiltonians.py +++ b/tensorcircuit/templates/hamiltonians.py @@ -17,13 +17,14 @@ def _create_empty_sparse_matrix(shape: Tuple[int, int]) -> Any: def heisenberg_hamiltonian( lattice: AbstractLattice, j_coupling: Union[float, List[float], Tuple[float, ...]] = 1.0, + interaction_scope: str = "neighbors", ) -> Any: """ Generates the sparse matrix of the Heisenberg Hamiltonian for a given lattice. The Heisenberg Hamiltonian is defined as: - H = J * Σ_{} (X_i X_j + Y_i Y_j + Z_i Z_j) - where the sum is over all unique nearest-neighbor pairs . + H = J * Σ_{i,j} (X_i X_j + Y_i Y_j + Z_i Z_j) + where the sum is over a specified set of interacting pairs {i,j}. :param lattice: An instance of a class derived from AbstractLattice, which provides the geometric information of the system. @@ -32,11 +33,23 @@ def heisenberg_hamiltonian( isotropic model (Jx=Jy=Jz) or a list/tuple of 3 floats for an anisotropic model (Jx, Jy, Jz). Defaults to 1.0. :type j_coupling: Union[float, List[float], Tuple[float, ...]], optional + :param interaction_scope: Defines the range of interactions. + - "neighbors": Includes only nearest-neighbor pairs (default). + - "all": Includes all unique pairs of sites. + :type interaction_scope: str, optional :return: The Hamiltonian as a backend-agnostic sparse matrix. :rtype: Any """ num_sites = lattice.num_sites - neighbor_pairs = lattice.get_neighbor_pairs(k=1, unique=True) + if interaction_scope == "neighbors": + neighbor_pairs = lattice.get_neighbor_pairs(k=1, unique=True) + elif interaction_scope == "all": + neighbor_pairs = lattice.get_all_pairs() + else: + raise ValueError( + f"Invalid interaction_scope: '{interaction_scope}'. 
" + "Must be 'neighbors' or 'all'." + ) if isinstance(j_coupling, (float, int)): js = [float(j_coupling)] * 3 diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py index 52f152c..1fa8dc7 100644 --- a/tensorcircuit/templates/lattice.py +++ b/tensorcircuit/templates/lattice.py @@ -18,11 +18,12 @@ Set, ) -logger = logging.getLogger(__name__) +import itertools +import math import numpy as np - from scipy.spatial import KDTree -from scipy.spatial.distance import pdist, squareform + +from .. import backend # This block resolves a name resolution issue for the static type checker (mypy). @@ -41,9 +42,13 @@ import matplotlib.axes from mpl_toolkits.mplot3d import Axes3D +logger = logging.getLogger(__name__) + +Tensor = Any SiteIndex = int SiteIdentifier = Hashable -Coordinates = np.ndarray[Any, Any] +Coordinates = Tensor + NeighborMap = Dict[SiteIndex, List[SiteIndex]] @@ -64,13 +69,27 @@ def __init__(self, dimensionality: int): """Initializes the base lattice class.""" self._dimensionality = dimensionality - # --- Internal Data Structures (to be populated by subclasses) --- - self._indices: List[SiteIndex] = [] - self._identifiers: List[SiteIdentifier] = [] - self._coordinates: List[Coordinates] = [] - self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = {} - self._neighbor_maps: Dict[int, NeighborMap] = {} - self._distance_matrix: Optional[Coordinates] = None + # Core data structures for storing site information. + self._indices: List[SiteIndex] = [] # List of integer indices [0, 1, ..., N-1] + self._identifiers: List[SiteIdentifier] = ( + [] + ) # List of unique, hashable site identifiers + # Always initialize to an empty coordinate tensor with correct dimensionality + # so that type checkers know this is indexable and not Optional. + self._coordinates: Coordinates = backend.zeros((0, dimensionality)) + + # Mappings for efficient lookups. 
+ self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = ( + {} + ) # Maps identifiers to indices + + # Cached properties, computed on demand. + self._neighbor_maps: Dict[int, NeighborMap] = ( + {} + ) # Caches neighbor info for different k + self._distance_matrix: Optional[Coordinates] = ( + None # Caches the full N x N distance matrix + ) @property def num_sites(self) -> int: @@ -95,7 +114,6 @@ def distance_matrix(self) -> Coordinates: subsequent calls. This computation can be expensive for large lattices. """ if self._distance_matrix is None: - logger.info("Distance matrix not cached. Computing now...") self._distance_matrix = self._compute_distance_matrix() return self._distance_matrix @@ -116,7 +134,8 @@ def get_coordinates(self, index: SiteIndex) -> Coordinates: :rtype: Coordinates """ self._validate_index(index) - return self._coordinates[index] + coords = self._coordinates[index] + return coords def get_identifier(self, index: SiteIndex) -> SiteIdentifier: """Gets the abstract identifier of a site by its integer index. @@ -140,7 +159,8 @@ def get_index(self, identifier: SiteIdentifier) -> SiteIndex: :rtype: SiteIndex """ try: - return self._ident_to_idx[identifier] + index = self._ident_to_idx[identifier] + return index except KeyError as e: raise ValueError( f"Identifier {identifier} not found in the lattice." @@ -170,7 +190,7 @@ def get_site_info( idx = index_or_identifier self._validate_index(idx) return idx, self._identifiers[idx], self._coordinates[idx] - else: # Identifier + else: ident = index_or_identifier idx = self.get_index(ident) return idx, ident, self._coordinates[idx] @@ -237,7 +257,6 @@ def get_neighbor_pairs( ) self._build_neighbors(max_k=k) - # After attempting to build, check again. If still not found, return empty. 
if k not in self._neighbor_maps: return [] @@ -251,8 +270,28 @@ def get_neighbor_pairs( pairs.append((i, j)) return sorted(pairs) - # Sorting provides a deterministic output order - # --- Abstract Methods for Subclass Implementation --- + def get_all_pairs(self) -> List[Tuple[SiteIndex, SiteIndex]]: + """ + Returns a list of all unique pairs of site indices (i, j) where i < j. + + This method provides all-to-all connectivity, useful for Hamiltonians + where every site interacts with every other site. + + Note on Differentiability: + This method provides a static list of index pairs and is not differentiable + itself. However, it is designed to be used in combination with the fully + differentiable ``distance_matrix`` property. By using the pairs from this + method to index into the ``distance_matrix``, one can construct differentiable + objective functions based on all-pair interactions, effectively bypassing the + non-differentiable ``get_neighbor_pairs`` method for geometry optimization tasks. + + :return: A list of tuples, where each tuple is a unique pair of site indices. + :rtype: List[Tuple[SiteIndex, SiteIndex]] + """ + if self.num_sites < 2: + return [] + # Use itertools.combinations to efficiently generate all unique pairs (i, j) with i < j. + return sorted(list(itertools.combinations(range(self.num_sites), 2))) @abc.abstractmethod def _build_lattice(self, *args: Any, **kwargs: Any) -> None: @@ -281,14 +320,24 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: """ pass - @abc.abstractmethod def _compute_distance_matrix(self) -> Coordinates: """ - Abstract method for subclasses to implement the actual matrix calculation. - This method is called by the `distance_matrix` property when the matrix - needs to be computed for the first time. + Default generic distance matrix computation (no periodic images). + + Subclasses can override this when a specialized rule is required + (e.g., applying Minimum Image Convention for PBC in TILattice). 
""" - pass + # Handle empty lattices and trivial 1-site lattices + if self.num_sites == 0: + return backend.zeros((0, 0)) + + # Vectorized pairwise Euclidean distances + all_coords = self._coordinates + displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims( + all_coords, 0 + ) + dist_matrix_sq = backend.sum(displacements**2, axis=-1) + return backend.sqrt(dist_matrix_sq) def show( self, @@ -334,7 +383,8 @@ def show( ) return - # creat "fig_created_internally" as flag + # Flag to track if the Matplotlib figure was created by this method. + # This prevents calling plt.show() if the user provided their own Axes. fig_created_internally = False if self.num_sites == 0: @@ -347,7 +397,7 @@ def show( return if ax is None: - # when ax is none, make fig_created_internally true + # If no Axes object is provided, create a new figure and axes. fig_created_internally = True if self.dimensionality == 3: fig = plt.figure(figsize=(8, 8)) @@ -358,6 +408,7 @@ def show( fig = ax.figure # type: ignore coords = np.array(self._coordinates) + # Prepare arguments for the scatter plot, allowing user overrides. scatter_args = {"s": 100, "zorder": 2} scatter_args.update(kwargs) if self.dimensionality == 1: @@ -371,12 +422,11 @@ def show( if show_indices or show_identifiers: for i in range(self.num_sites): label = str(self._identifiers[i]) if show_identifiers else str(i) + # Calculate a small offset for placing text labels to avoid overlap with sites. offset = ( 0.02 * np.max(np.ptp(coords, axis=0)) if coords.size > 0 else 0.1 ) - # Robust Logic: Decide plotting strategy based on known dimensionality. - if self.dimensionality == 1: ax.text(coords[i, 0], offset, label, fontsize=9, ha="center") elif self.dimensionality == 2: @@ -398,8 +448,6 @@ def show( zorder=3, ) - # Note: No 'else' needed as we already check dimensionality at the start. 
- if show_bonds_k is not None: if show_bonds_k not in self._neighbor_maps: logger.warning( @@ -433,7 +481,7 @@ def show( if self.dimensionality == 1: # type: ignore ax.plot([p1[0], p2[0]], [0, 0], **plot_bond_kwargs) # type: ignore - else: # dimensionality == 2 + else: ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **plot_bond_kwargs) # type: ignore except ValueError as e: @@ -449,7 +497,7 @@ def show( ax.set_zlabel("z") ax.grid(True) - # 3. whether plt.show() + # Display the plot only if the figure was created within this function. if fig_created_internally: plt.show() @@ -475,26 +523,28 @@ def _identify_distance_shells( :return: A sorted list of squared distances representing the shells. :rtype: List[float] """ + # A small threshold to filter out zero distances (site to itself). ZERO_THRESHOLD_SQ = 1e-12 - all_distances_sq = np.asarray(all_distances_sq) + all_distances_sq = backend.convert_to_tensor(all_distances_sq) # Now, the .size call below is guaranteed to be safe. - if all_distances_sq.size == 0: + if backend.sizen(all_distances_sq) == 0: return [] - sorted_dist = np.sort(all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ]) + # Filter out self-distances and sort the remaining squared distances. + sorted_dist = backend.sort( + all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ] + ) - if sorted_dist.size == 0: + if backend.sizen(sorted_dist) == 0: return [] - # Identify shells using the user-provided tolerance. dist_shells = [sorted_dist[0]] for d_sq in sorted_dist[1:]: if len(dist_shells) >= max_k: break - # If the current distance is notably larger than the last shell's distance - if d_sq > dist_shells[-1] + tol**2: + if backend.sqrt(d_sq) - backend.sqrt(dist_shells[-1]) > tol: dist_shells.append(d_sq) return dist_shells @@ -503,11 +553,9 @@ def _build_neighbors_by_distance_matrix( self, max_k: int = 2, tol: float = 1e-6 ) -> None: """A generic, distance-based neighbor finding method. 
- This method calculates the full N x N distance matrix to find neighbor shells. It is computationally expensive for large N (O(N^2)) and is best suited for non-periodic or custom-defined lattices. - :param max_k: The maximum number of neighbor shells to calculate. Defaults to 2. :type max_k: int, optional @@ -518,26 +566,55 @@ def _build_neighbors_by_distance_matrix( if self.num_sites < 2: return - all_coords = np.array(self._coordinates) - dist_matrix_sq = np.sum( - (all_coords[:, np.newaxis, :] - all_coords[np.newaxis, :, :]) ** 2, axis=-1 + all_coords = self._coordinates + # Vectorized computation of the squared distance matrix: + # (N, 1, D) - (1, N, D) -> (N, N, D) -> (N, N) + displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims( + all_coords, 0 ) + dist_matrix_sq = backend.sum(displacements**2, axis=-1) - all_distances_sq = dist_matrix_sq.flatten() + # Flatten the matrix to a list of all squared distances to identify shells. + all_distances_sq = backend.reshape(dist_matrix_sq, [-1]) dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol) - self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)} + self._neighbor_maps = self._build_neighbor_map_from_distances( + dist_matrix_sq, dist_shells_sq, tol + ) + self._distance_matrix = backend.sqrt(dist_matrix_sq) + + def _build_neighbor_map_from_distances( + self, + dist_matrix_sq: Coordinates, + dist_shells_sq: List[float], + tol: float = 1e-6, + ) -> Dict[int, NeighborMap]: + """ + Builds a neighbor map from a squared distance matrix and identified shells. + This is a generic helper function to reduce code duplication. 
+ """ + neighbor_maps: Dict[int, NeighborMap] = { + k: {} for k in range(1, len(dist_shells_sq) + 1) + } for k_idx, target_d_sq in enumerate(dist_shells_sq): k = k_idx + 1 current_k_map: Dict[int, List[int]] = {} - for i in range(self.num_sites): - neighbor_indices = np.where( - np.isclose(dist_matrix_sq[i], target_d_sq, rtol=0, atol=tol**2) - )[0] - if len(neighbor_indices) > 0: - current_k_map[i] = sorted(neighbor_indices.tolist()) - self._neighbor_maps[k] = current_k_map - self._distance_matrix = np.sqrt(dist_matrix_sq) + # For each shell, find all pairs of sites (i, j) with that distance. + is_close_matrix = backend.abs(dist_matrix_sq - target_d_sq) < tol + rows, cols = backend.where(is_close_matrix) + + for i, j in zip(backend.numpy(rows), backend.numpy(cols)): + if i == j: + continue + if i not in current_k_map: + current_k_map[i] = [] + current_k_map[i].append(j) + + for i in current_k_map: + current_k_map[i].sort() + + neighbor_maps[k] = current_k_map + return neighbor_maps class TILattice(AbstractLattice): @@ -588,150 +665,197 @@ def __init__( ): """Initializes the Translationally Invariant Lattice.""" super().__init__(dimensionality) - assert lattice_vectors.shape == ( - dimensionality, - dimensionality, - ), "Lattice vectors shape mismatch" - assert ( - basis_coords.shape[1] == dimensionality - ), "Basis coordinates dimension mismatch" - assert len(size) == dimensionality, "Size tuple length mismatch" - - self.lattice_vectors = lattice_vectors - self.basis_coords = basis_coords - self.num_basis = basis_coords.shape[0] + + self.lattice_vectors = backend.convert_to_tensor(lattice_vectors) + self.basis_coords = backend.convert_to_tensor(basis_coords) + + if self.lattice_vectors.shape != (dimensionality, dimensionality): + raise ValueError( + f"Lattice vectors shape {self.lattice_vectors.shape} does not match " + f"expected ({dimensionality}, {dimensionality})" + ) + if self.basis_coords.shape[1] != dimensionality: + raise ValueError( + f"Basis coordinates 
dimension {self.basis_coords.shape[1]} does not " + f"match lattice dimensionality {dimensionality}" + ) + if len(size) != dimensionality: + raise ValueError( + f"Size tuple length {len(size)} does not match dimensionality {dimensionality}" + ) + + self.num_basis = self.basis_coords.shape[0] self.size = size if isinstance(pbc, bool): self.pbc = tuple([pbc] * dimensionality) else: - assert len(pbc) == dimensionality, "PBC tuple length mismatch" + if len(pbc) != dimensionality: + raise ValueError( + f"PBC tuple length {len(pbc)} does not match dimensionality {dimensionality}" + ) self.pbc = tuple(pbc) - # Build the lattice sites and their neighbor relationships self._build_lattice() if precompute_neighbors is not None and precompute_neighbors > 0: logger.info(f"Pre-computing neighbors up to k={precompute_neighbors}...") self._build_neighbors(max_k=precompute_neighbors) def _build_lattice(self) -> None: - """Generates all site information for the periodic lattice. - - This method iterates through each unit cell defined by `self.size`, - and for each unit cell, it iterates through all basis sites. It then - calculates the real-space coordinates and creates a unique identifier - for each site, populating the internal lattice data structures. """ - current_index = 0 + Generates all site information for the periodic lattice in a vectorized manner. + """ + ranges = [backend.arange(s) for s in self.size] - # Iterate over all unit cell coordinates elegantly using np.ndindex - for cell_coord in np.ndindex(self.size): - cell_coord_arr = np.array(cell_coord) - # R = n1*a1 + n2*a2 + ... - cell_vector = np.dot(cell_coord_arr, self.lattice_vectors) + # Generate a grid of all integer unit cell coordinates. 
+ grid = backend.meshgrid(*ranges, indexing="ij") + all_cell_coords = backend.reshape( + backend.stack(grid, axis=-1), (-1, self.dimensionality) + ) - # Iterate over the basis sites within the unit cell - for basis_index in range(self.num_basis): - basis_vec = self.basis_coords[basis_index] + all_cell_coords = backend.cast(all_cell_coords, self.lattice_vectors.dtype) + + cell_vectors = backend.tensordot( + all_cell_coords, self.lattice_vectors, axes=[[1], [0]] + ) + + cell_vectors = backend.cast(cell_vectors, self.basis_coords.dtype) + + # Combine cell vectors with basis coordinates to get all site positions + # via broadcasting: (num_cells, 1, D) + (1, num_basis, D) -> (num_cells, num_basis, D) + all_coords = backend.expand_dims(cell_vectors, 1) + backend.expand_dims( + self.basis_coords, 0 + ) - # Calculate the real-space coordinate - coord = cell_vector + basis_vec - # Create a structured identifier - identifier = cell_coord + (basis_index,) + self._coordinates = backend.reshape(all_coords, (-1, self.dimensionality)) - # Store site information + self._indices = [] + self._identifiers = [] + self._ident_to_idx = {} + current_index = 0 + + # Generate integer indices and tuple-based identifiers for all sites. + # e.g., identifier = (uc_x, uc_y, basis_idx) + size_ranges = [range(s) for s in self.size] + for cell_coord_tuple in itertools.product(*size_ranges): + for basis_index in range(self.num_basis): + identifier = cell_coord_tuple + (basis_index,) self._indices.append(current_index) self._identifiers.append(identifier) - self._coordinates.append(coord) self._ident_to_idx[identifier] = current_index current_index += 1 - def _get_distance_matrix_with_mic(self) -> Coordinates: + def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates: """ - Computes the full N x N distance matrix, correctly applying the - Minimum Image Convention (MIC) for all periodic dimensions. 
+ Computes the full N x N distance matrix using a fully vectorized approach + that correctly applies the Minimum Image Convention (MIC) for periodic + boundary conditions. + + This method uses full vectorization for optimal performance and compatibility + with JIT compilation frameworks like JAX. The implementation processes all + site pairs simultaneously rather than iterating row-by-row, which provides: + + - Better performance through vectorized operations + - Full compatibility with automatic differentiation + - JIT compilation support (e.g., JAX, TensorFlow) + - Consistent tensor operations throughout + + The trade-off is higher memory usage compared to iterative approaches, + as it computes all pairwise distances simultaneously. For very large + lattices (N > 10^4 sites), memory usage scales as O(N^2). + + :return: Distance matrix with shape (N, N) where entry (i,j) is the + minimum distance between sites i and j under periodic boundary conditions. + :rtype: Coordinates """ - all_coords = np.array(self._coordinates) - size_arr = np.array(self.size) - system_vectors = self.lattice_vectors * size_arr[:, np.newaxis] - - # Generate translation vectors ONLY for periodic dimensions - pbc_dims = [d for d in range(self.dimensionality) if self.pbc[d]] - translations = [np.zeros(self.dimensionality)] - if pbc_dims: - num_pbc_dims = len(pbc_dims) - pbc_system_vectors = system_vectors[pbc_dims, :] - - # Create all 3^k - 1 non-zero shifts for k periodic dimensions - shift_options = [np.array([-1, 0, 1])] * num_pbc_dims - shifts_grid = np.meshgrid(*shift_options, indexing="ij") - all_shifts = np.stack(shifts_grid, axis=-1).reshape(-1, num_pbc_dims) - all_shifts = all_shifts[np.any(all_shifts != 0, axis=1)] - - pbc_translations = all_shifts @ pbc_system_vectors - translations.extend(pbc_translations) - - translations_arr = np.array(translations, dtype=float) - - # Calculate the distance matrix applying MIC - dist_matrix_sq = np.full((self.num_sites, self.num_sites), np.inf, 
dtype=float) - for i in range(self.num_sites): - displacements = all_coords - all_coords[i] - image_displacements = ( - displacements[:, np.newaxis, :] - translations_arr[np.newaxis, :, :] - ) - image_d_sq = np.sum(image_displacements**2, axis=2) - dist_matrix_sq[i, :] = np.min(image_d_sq, axis=1) + # Ensure dtype consistency across backends (especially torch) by explicitly + # casting size and lattice_vectors to the same floating dtype used internally. + # Strategy: prefer existing lattice_vectors dtype; if it's an unusual dtype, + # fall back to float32 to avoid mixed-precision issues in vectorized ops. + # Note: `self.lattice_vectors` is always created via `backend.convert_to_tensor` + # in __init__, so `backend.dtype(...)` is reliable here and doesn't need try/except. + target_dt = str(backend.dtype(self.lattice_vectors)) # type: ignore + if target_dt not in ("float32", "float64"): + # fallback for unusual dtypes + target_dt = "float32" + + size_arr = backend.cast(backend.convert_to_tensor(self.size), target_dt) + lattice_vecs = backend.cast( + backend.convert_to_tensor(self.lattice_vectors), target_dt + ) + system_vectors = lattice_vecs * backend.expand_dims(size_arr, axis=1) + + pbc_mask = backend.convert_to_tensor(self.pbc) - return cast(Coordinates, np.sqrt(dist_matrix_sq)) + # Generate all 3^d possible image shifts (-1, 0, 1) for all dimensions + shift_options = [ + backend.convert_to_tensor([-1.0, 0.0, 1.0]) + ] * self.dimensionality + shifts_grid = backend.meshgrid(*shift_options, indexing="ij") + all_shifts = backend.reshape( + backend.stack(shifts_grid, axis=-1), (-1, self.dimensionality) + ) + + # Only apply shifts to periodic dimensions + masked_shifts = all_shifts * backend.cast(pbc_mask, all_shifts.dtype) + + # Calculate all translation vectors due to PBC + translations_arr = backend.tensordot( + masked_shifts, system_vectors, axes=[[1], [0]] + ) + + # Vectorized computation of all displacements between any two sites + # Shape: (N, 1, D) - (1, N, 
D) -> (N, N, D) + displacements = backend.expand_dims(self._coordinates, 1) - backend.expand_dims( + self._coordinates, 0 + ) + + # Consider all periodic images for each displacement + # Shape: (N, N, 1, D) - (1, 1, num_translations, D) -> (N, N, num_translations, D) + image_displacements = backend.expand_dims( + displacements, 2 + ) - backend.expand_dims(backend.expand_dims(translations_arr, 0), 0) + + # Sum of squares for distances + image_d_sq = backend.sum(image_displacements**2, axis=3) + + # Find the minimum distance among all images (Minimum Image Convention) + min_dist_sq = backend.min(image_d_sq, axis=2) + + safe_dist_matrix_sq = backend.where(min_dist_sq > 0, min_dist_sq, 0.0) + return backend.sqrt(safe_dist_matrix_sq) def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None: """Calculates neighbor relationships for the periodic lattice. - This method calculates neighbor relationships by computing the full N x N - distance matrix. It robustly handles all boundary conditions (fully - periodic, open, or mixed) by applying the Minimum Image Convention - (MIC) only to the periodic dimensions. - - From this distance matrix, it identifies unique neighbor shells up to - the specified `max_k` and populates the neighbor maps. The computed - distance matrix is then cached for future use. + This method computes neighbor information by first calculating the full + distance matrix using the Minimum Image Convention (MIC) to correctly + handle periodic boundary conditions. It then identifies unique distance + shells (e.g., nearest, next-nearest) and populates the neighbor maps + accordingly. This approach is general and works for any periodic lattice + geometry defined by the TILattice class. - :param max_k: The maximum number of neighbor shells to - calculate. Defaults to 2. + :param max_k: The maximum order of neighbors to compute (e.g., k=1 for + nearest neighbors, k=2 for next-nearest, etc.). Defaults to 2. 
:type max_k: int, optional - :param tol: The numerical tolerance for distance - comparisons. Defaults to 1e-6. - :type tol: float, optional + :param kwargs: Additional keyword arguments. May include: + - ``tol`` (float): The numerical tolerance used to determine if two + distances are equal when identifying shells. Defaults to 1e-6. """ tol = kwargs.get("tol", 1e-6) - dist_matrix = self._get_distance_matrix_with_mic() + dist_matrix = self._get_distance_matrix_with_mic_vectorized() dist_matrix_sq = dist_matrix**2 self._distance_matrix = dist_matrix - all_distances_sq = dist_matrix_sq.flatten() + all_distances_sq = backend.reshape(dist_matrix_sq, [-1]) dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol) - - self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)} - for k_idx, target_d_sq in enumerate(dist_shells_sq): - k = k_idx + 1 - current_k_map: Dict[int, List[int]] = {} - match_indices = np.where( - np.isclose(dist_matrix_sq, target_d_sq, rtol=0, atol=tol**2) - ) - for i, j in zip(*match_indices): - if i == j: - continue - if i not in current_k_map: - current_k_map[i] = [] - current_k_map[i].append(j) - - for i in current_k_map: - current_k_map[i].sort() - - self._neighbor_maps[k] = current_k_map + self._neighbor_maps = self._build_neighbor_map_from_distances( + dist_matrix_sq, dist_shells_sq, tol + ) def _compute_distance_matrix(self) -> Coordinates: """Computes the distance matrix using the Minimum Image Convention.""" - return self._get_distance_matrix_with_mic() + if self.num_sites == 0: + return backend.zeros((0, 0)) + return self._get_distance_matrix_with_mic_vectorized() class SquareLattice(TILattice): @@ -759,20 +883,24 @@ class SquareLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): """Initializes the SquareLattice.""" dimensionality 
= 2 + # Define orthogonal lattice vectors for a square. + # Avoid mixing Python floats with backend Tensors (TF would error), + # so first convert inputs to tensors of a unified dtype, then stack. + lc = backend.convert_to_tensor(lattice_constant) + dt = backend.dtype(lc) + z = backend.cast(backend.convert_to_tensor(0.0), dt) + row1 = backend.stack([lc, z]) + row2 = backend.stack([z, lc]) + lattice_vectors = backend.stack([row1, row2]) + # A square lattice is a Bravais lattice, so it has a single-site basis. + basis_coords = backend.stack([backend.stack([z, z])]) - # Define lattice vectors for a square lattice - lattice_vectors = np.array([[lattice_constant, 0.0], [0.0, lattice_constant]]) - - # A square lattice has a single site in its basis - basis_coords = np.array([[0.0, 0.0]]) - - # Call the parent TILattice constructor with these parameters super().__init__( dimensionality=dimensionality, lattice_vectors=lattice_vectors, @@ -808,19 +936,28 @@ class HoneycombLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): """Initializes the HoneycombLattice.""" dimensionality = 2 a = lattice_constant + a_t = backend.convert_to_tensor(a) + zero = a_t * 0.0 - # Define the primitive lattice vectors for the underlying triangular lattice - lattice_vectors = a * np.array([[1.5, np.sqrt(3) / 2], [1.5, -np.sqrt(3) / 2]]) - - # Define the coordinates of the two basis sites (A and B) - basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0]]) # Site A # Site B + # Define the two primitive lattice vectors for the underlying triangular Bravais lattice. + rt3_over_2 = math.sqrt(3.0) / 2.0 + lattice_vectors = backend.stack( + [ + backend.stack([a_t * 1.5, a_t * rt3_over_2]), + backend.stack([a_t * 1.5, -a_t * rt3_over_2]), + ] + ) + # Define the two basis sites (A and B) within the unit cell. 
+ basis_coords = backend.stack( + [backend.stack([zero, zero]), backend.stack([a_t * 1.0, zero])] + ) super().__init__( dimensionality=dimensionality, @@ -855,19 +992,30 @@ class TriangularLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): """Initializes the TriangularLattice.""" dimensionality = 2 a = lattice_constant + a_t = backend.convert_to_tensor(a) + zero = a_t * 0.0 - # Define the primitive lattice vectors for a triangular lattice - lattice_vectors = a * np.array([[1.0, 0.0], [0.5, np.sqrt(3) / 2]]) - - # A triangular lattice is a Bravais lattice, with a single site in its basis - basis_coords = np.array([[0.0, 0.0]]) + # Define the primitive lattice vectors for a triangular lattice. + lattice_vectors = backend.stack( + [ + backend.stack([a_t * 1.0, zero]), + backend.stack( + [ + a_t * 0.5, + a_t * backend.sqrt(backend.convert_to_tensor(3.0)) / 2.0, + ] + ), + ] + ) + # A triangular lattice is a Bravais lattice with a single-site basis. + basis_coords = backend.stack([backend.stack([zero, zero])]) super().__init__( dimensionality=dimensionality, @@ -896,13 +1044,18 @@ class ChainLattice(TILattice): def __init__( self, size: Tuple[int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: bool = True, precompute_neighbors: Optional[int] = None, ): dimensionality = 1 - lattice_vectors = np.array([[lattice_constant]]) - basis_coords = np.array([[0.0]]) + # The lattice vector is just the lattice constant along one dimension. + lc = backend.convert_to_tensor(lattice_constant) + lattice_vectors = backend.stack([backend.stack([lc])]) + # A simple chain is a Bravais lattice with a single-site basis. 
+ zero = lc * 0.0 + basis_coords = backend.stack([backend.stack([zero])]) + super().__init__( dimensionality=dimensionality, lattice_vectors=lattice_vectors, @@ -934,15 +1087,17 @@ class DimerizedChainLattice(TILattice): def __init__( self, size: Tuple[int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: bool = True, precompute_neighbors: Optional[int] = None, ): dimensionality = 1 - # The unit cell vector connects two A sites, spanning length 2*a - lattice_vectors = np.array([[2 * lattice_constant]]) - # Basis has site A at origin, site B at distance 'a' - basis_coords = np.array([[0.0], [lattice_constant]]) + # The unit cell is twice the bond length, as it contains two sites. + lc = backend.convert_to_tensor(lattice_constant) + lattice_vectors = backend.stack([backend.stack([2 * lc])]) + # Two basis sites (A and B) separated by the bond length. + zero = lc * 0.0 + basis_coords = backend.stack([backend.stack([zero]), backend.stack([lc])]) super().__init__( dimensionality=dimensionality, @@ -975,14 +1130,22 @@ class RectangularLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constants: Tuple[float, float] = (1.0, 1.0), + lattice_constants: Union[Tuple[float, float], Any] = (1.0, 1.0), pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): dimensionality = 2 ax, ay = lattice_constants - lattice_vectors = np.array([[ax, 0.0], [0.0, ay]]) - basis_coords = np.array([[0.0, 0.0]]) + ax_t = backend.convert_to_tensor(ax) + dt = backend.dtype(ax_t) + ay_t = backend.cast(backend.convert_to_tensor(ay), dt) + z = backend.cast(backend.convert_to_tensor(0.0), dt) + # Orthogonal lattice vectors with potentially different lengths. + row1 = backend.stack([ax_t, z]) + row2 = backend.stack([z, ay_t]) + lattice_vectors = backend.stack([row1, row2]) + # A rectangular lattice is a Bravais lattice with a single-site basis. 
+ basis_coords = backend.stack([backend.stack([z, z])]) super().__init__( dimensionality=dimensionality, @@ -1013,16 +1176,26 @@ class CheckerboardLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): dimensionality = 2 a = lattice_constant - # Primitive vectors for a square lattice rotated by 45 degrees. - lattice_vectors = a * np.array([[1.0, 1.0], [1.0, -1.0]]) - # Two-site basis - basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0]]) + a_t = backend.convert_to_tensor(a) + # The unit cell is a square rotated by 45 degrees. + lattice_vectors = backend.stack( + [ + backend.stack([a_t * 1.0, a_t * 1.0]), + backend.stack([a_t * 1.0, a_t * -1.0]), + ] + ) + # Two basis sites (A and B) within the unit cell. + zero = a_t * 0.0 + basis_coords = backend.stack( + [backend.stack([zero, zero]), backend.stack([a_t * 1.0, zero])] + ) + super().__init__( dimensionality=dimensionality, lattice_vectors=lattice_vectors, @@ -1052,16 +1225,30 @@ class KagomeLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): dimensionality = 2 a = lattice_constant - # Using a rectangular unit cell definition for simplicity - lattice_vectors = a * np.array([[2.0, 0.0], [1.0, np.sqrt(3)]]) - # Three-site basis - basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0], [0.5, np.sqrt(3) / 2.0]]) + a_t = backend.convert_to_tensor(a) + # The Kagome lattice is based on a triangular Bravais lattice. + lattice_vectors = backend.stack( + [ + backend.stack([a_t * 2.0, a_t * 0.0]), + backend.stack([a_t * 1.0, a_t * backend.sqrt(3.0)]), + ] + ) + # It has a three-site basis, forming the corners of the triangles. 
+ zero = a_t * 0.0 + basis_coords = backend.stack( + [ + backend.stack([zero, zero]), + backend.stack([a_t * 1.0, zero]), + backend.stack([a_t * 0.5, a_t * backend.sqrt(3.0) / 2.0]), + ] + ) + super().__init__( dimensionality=dimensionality, lattice_vectors=lattice_vectors, @@ -1092,29 +1279,26 @@ class LiebLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): """Initializes the LiebLattice.""" dimensionality = 2 - # Use a more descriptive name for clarity. In a Lieb lattice, - # the lattice_constant is the bond length between nearest neighbors. bond_length = lattice_constant - - # The unit cell of a Lieb lattice is a square with side length - # equal to twice the bond length. - unit_cell_side = 2 * bond_length - lattice_vectors = np.array([[unit_cell_side, 0.0], [0.0, unit_cell_side]]) - - # The three-site basis consists of a corner site, a site on the - # center of the horizontal edge, and a site on the center of the vertical edge. - # Their coordinates are defined directly in terms of the physical bond length. - basis_coords = np.array( + bl_t = backend.convert_to_tensor(bond_length) + unit_cell_side_t = 2 * bl_t + # The Lieb lattice is based on a square Bravais lattice. + z = bl_t * 0.0 + lattice_vectors = backend.stack( + [backend.stack([unit_cell_side_t, z]), backend.stack([z, unit_cell_side_t])] + ) + # It has a three-site basis: one corner and two edge-centers. 
+ basis_coords = backend.stack( [ - [0.0, 0.0], # Corner site - [bond_length, 0.0], # Horizontal edge center - [0.0, bond_length], # Vertical edge center + backend.stack([z, z]), # Corner site + backend.stack([bl_t, z]), # x-edge center + backend.stack([z, bl_t]), # y-edge center ] ) @@ -1147,14 +1331,24 @@ class CubicLattice(TILattice): def __init__( self, size: Tuple[int, int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): dimensionality = 3 a = lattice_constant - lattice_vectors = np.array([[a, 0, 0], [0, a, 0], [0, 0, a]]) - basis_coords = np.array([[0.0, 0.0, 0.0]]) + a_t = backend.convert_to_tensor(a) + # Orthogonal lattice vectors of equal length in 3D. + z = a_t * 0.0 + lattice_vectors = backend.stack( + [ + backend.stack([a_t, z, z]), + backend.stack([z, a_t, z]), + backend.stack([z, z, a_t]), + ] + ) + # A simple cubic lattice is a Bravais lattice with a single-site basis. + basis_coords = backend.stack([backend.stack([z, z, z])]) super().__init__( dimensionality=dimensionality, lattice_vectors=lattice_vectors, @@ -1194,29 +1388,37 @@ def __init__( self, dimensionality: int, identifiers: List[SiteIdentifier], - coordinates: List[Union[List[float], Coordinates]], + coordinates: Any, precompute_neighbors: Optional[int] = None, ): """Initializes the CustomizeLattice.""" super().__init__(dimensionality) - if len(identifiers) != len(coordinates): + + self._coordinates = backend.convert_to_tensor(coordinates) + if len(identifiers) == 0: + self._coordinates = backend.reshape( + self._coordinates, (0, self.dimensionality) + ) + + if len(identifiers) != backend.shape_tuple(self._coordinates)[0]: raise ValueError( - "Identifiers and coordinates lists must have the same length." + "The number of identifiers must match the number of coordinates. 
" + f"Got {len(identifiers)} identifiers and " + f"{backend.shape_tuple(self._coordinates)[0]} coordinates." ) - # The _build_lattice logic is simple enough to be in __init__ self._identifiers = list(identifiers) - self._coordinates = [np.array(c) for c in coordinates] self._indices = list(range(len(identifiers))) self._ident_to_idx = {ident: idx for idx, ident in enumerate(identifiers)} - # Validate coordinate dimensions - for i, coord in enumerate(self._coordinates): - if coord.shape != (dimensionality,): - raise ValueError( - f"Coordinate at index {i} has shape {coord.shape}, " - f"expected ({dimensionality},)" - ) + if ( + self.num_sites > 0 + and backend.shape_tuple(self._coordinates)[1] != dimensionality + ): + raise ValueError( + f"Coordinates tensor has dimension {backend.shape_tuple(self._coordinates)[1]}, " + f"but expected dimensionality is {dimensionality}." + ) logger.info(f"CustomizeLattice with {self.num_sites} sites created.") @@ -1228,95 +1430,172 @@ def _build_lattice(self, *args: Any, **kwargs: Any) -> None: pass def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: - """Calculates neighbors using a KDTree for efficiency. - - This method uses a memory-efficient approach to identify neighbors without - initially computing the full N x N distance matrix. It leverages - `scipy.spatial.distance.pdist` to find unique distance shells and then - a `scipy.spatial.KDTree` for fast radius queries. This approach is - significantly more memory-efficient during the neighbor identification phase. + """ + Calculates neighbor relationships using either KDTree or distance matrix methods. - After the neighbors are identified, the full distance matrix is computed - from the pairwise distances and cached for potential future use. + This method supports two modes: + 1. KDTree mode (use_kdtree=True): Fast, O(N log N) performance for large lattices + but breaks differentiability due to scipy dependency + 2. 
Distance matrix mode (use_kdtree=False): Slower O(N²) but fully differentiable + and backend-agnostic - :param max_k: The maximum number of neighbor shells to - calculate. Defaults to 1. - :type max_k: int, optional - :param tol: The numerical tolerance for distance - comparisons. Defaults to 1e-6. - :type tol: float, optional + :param max_k: Maximum number of neighbor shells to compute + :type max_k: int + :param kwargs: Additional arguments including: + - use_kdtree (bool): Whether to use KDTree optimization. Defaults to False. + - tol (float): Distance tolerance for neighbor identification. Defaults to 1e-6. """ tol = kwargs.get("tol", 1e-6) - logger.info(f"Building neighbors for CustomizeLattice up to k={max_k}...") + # Reviewer suggestion: prefer differentiable method by default + use_kdtree = kwargs.get("use_kdtree", False) + if self.num_sites < 2: return - all_coords = np.array(self._coordinates) + # Choose algorithm based on user preference + if use_kdtree: + logger.info( + f"Using KDTree method for {self.num_sites} sites up to k={max_k}" + ) + self._build_neighbors_kdtree(max_k, tol) + else: + logger.info( + f"Using differentiable distance matrix method for {self.num_sites} sites up to k={max_k}" + ) - # 1. Use pdist for memory-efficient calculation of pairwise distances - # to robustly identify the distance shells. - all_distances_sq = pdist(all_coords, metric="sqeuclidean") - dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol) + # Use the existing distance matrix method + self._build_neighbors_by_distance_matrix(max_k, tol) - if not dist_shells_sq: - logger.info("No distinct neighbor shells found.") - return + def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: + """ + Build neighbors using KDTree for optimal performance. - # 2. Build the KDTree for efficient querying. 
- tree = KDTree(all_coords) - self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)} + This method provides O(N log N) performance for neighbor finding but breaks + differentiability due to scipy dependency. Use this method when: + - Performance is critical + - Differentiability is not required + - Large lattices (N > 1000) - # 3. Find neighbors by isolating shells using inclusion-exclusion. - # `found_indices` will store all neighbors within a given radius. - found_indices: List[set[int]] = [] - for k_idx, target_d_sq in enumerate(dist_shells_sq): - radius = np.sqrt(target_d_sq) + tol - # Query for all points within the new, larger radius. - current_shell_indices = tree.query_ball_point( - all_coords, r=radius, return_sorted=True - ) + Note: This method uses numpy arrays directly and may not be compatible + with all backend types (JAX, TensorFlow, etc.). + """ + # Convert coordinates to numpy for KDTree + coords_np = backend.numpy(self._coordinates) - # Now, isolate the neighbors for the current shell k - k = k_idx + 1 - current_k_map: Dict[int, List[int]] = {} - for i in range(self.num_sites): + # Build KDTree + logger.info("Building KDTree...") + tree = KDTree(coords_np) - if k_idx == 0: - co_located_indices = tree.query_ball_point(all_coords[i], r=1e-12) - prev_found = set(co_located_indices) - else: - prev_found = found_indices[i] + # For small lattices or cases with potential duplicate coordinates, + # fall back to distance matrix method for robustness + if self.num_sites < 200: + logger.info( + "Small lattice detected, falling back to distance matrix method for robustness" + ) + self._build_neighbors_by_distance_matrix(max_k, tol) + return - # The new neighbors are those in the current radius shell, - # excluding those already found in smaller shells. 
- new_neighbors = set(current_shell_indices[i]) - prev_found + # Find all distances for shell identification - use comprehensive sampling + logger.info("Identifying distance shells...") + distances_for_shells: List[float] = [] - if new_neighbors: - current_k_map[i] = sorted(list(new_neighbors)) + # For robust shell identification, query all pairwise distances for smaller lattices + # or use dense sampling for larger ones + if self.num_sites <= 100: + # For small lattices, compute all pairwise distances for accuracy + for i in range(self.num_sites): + query_k = min(self.num_sites - 1, max_k * 20) + if query_k > 0: + dists, _ = tree.query( + coords_np[i], k=query_k + 1 + ) # +1 to exclude self + if isinstance(dists, np.ndarray): + distances_for_shells.extend(dists[1:]) # Skip distance to self + else: + distances_for_shells.append(dists) # Single distance + else: + # For larger lattices, use adaptive sampling but ensure we capture all shells + sample_size = min(1000, self.num_sites // 2) # More conservative sampling + for i in range(0, self.num_sites, max(1, self.num_sites // sample_size)): + query_k = min(max_k * 20 + 50, self.num_sites - 1) + if query_k > 0: + dists, _ = tree.query( + coords_np[i], k=query_k + 1 + ) # +1 to exclude self + if isinstance(dists, np.ndarray): + distances_for_shells.extend(dists[1:]) # Skip distance to self + else: + distances_for_shells.append(dists) # Single distance - self._neighbor_maps[k] = current_k_map - found_indices = [ - set(l) for l in current_shell_indices - ] # Update for next iteration - self._distance_matrix = np.sqrt(squareform(all_distances_sq)) + # Filter out zero distances (duplicate coordinates) before shell identification + ZERO_THRESHOLD = 1e-12 + distances_for_shells = [d for d in distances_for_shells if d > ZERO_THRESHOLD] - logger.info("Neighbor building complete using KDTree.") + if not distances_for_shells: + logger.warning("No valid distances found for shell identification") + self._neighbor_maps = {} + return 
- def _compute_distance_matrix(self) -> Coordinates: - """Computes the distance matrix from the stored coordinates. + # Use the same shell identification logic as distance matrix method + distances_for_shells_sq = [d * d for d in distances_for_shells] + dist_shells_sq = self._identify_distance_shells( + distances_for_shells_sq, max_k, tol + ) + dist_shells = [np.sqrt(d_sq) for d_sq in dist_shells_sq] - This implementation uses scipy.pdist for a memory-efficient - calculation of pairwise distances, which is then converted to a - full square matrix. - """ - if self.num_sites < 2: - return cast(Coordinates, np.empty((self.num_sites, self.num_sites))) + logger.info(f"Found {len(dist_shells)} distance shells: {dist_shells[:5]}...") - all_coords = np.array(self._coordinates) - # Use pdist for memory-efficiency, then build the full matrix. - all_distances_sq = pdist(all_coords, metric="sqeuclidean") - dist_matrix_sq = squareform(all_distances_sq) - return cast(Coordinates, np.sqrt(dist_matrix_sq)) + # Initialize neighbor maps + self._neighbor_maps = {k: {} for k in range(1, len(dist_shells) + 1)} + + # Build neighbor lists for each site + logger.info("Building neighbor lists...") + for i in range(self.num_sites): + # Query enough neighbors to capture all shells + query_k = min(max_k * 20 + 50, self.num_sites - 1) + if query_k > 0: + distances, indices = tree.query( + coords_np[i], k=query_k + 1 + ) # +1 for self + + # Skip the first entry (distance to self) + # Handle both single value and array cases + if isinstance(distances, np.ndarray) and len(distances) > 1: + distances_slice = distances[1:] + indices_slice = ( + indices[1:] + if isinstance(indices, np.ndarray) + else np.array([], dtype=int) + ) + else: + # Single value or empty case - no neighbors to process + distances_slice = np.array([]) + indices_slice = np.array([], dtype=int) + + # Filter out zero distances (duplicate coordinates) + valid_pairs = [ + (d, idx) + for d, idx in zip(distances_slice, indices_slice) 
+ if d > ZERO_THRESHOLD + ] + + # Assign neighbors to shells + for shell_idx, shell_dist in enumerate(dist_shells): + k = shell_idx + 1 + shell_neighbors = [] + + for dist, neighbor_idx in valid_pairs: + if abs(dist - shell_dist) <= tol: + shell_neighbors.append(int(neighbor_idx)) + elif dist > shell_dist + tol: + break # Distances are sorted, no more matches + + if shell_neighbors: + self._neighbor_maps[k][i] = sorted(shell_neighbors) + + # Set distance matrix to None - will compute on demand + self._distance_matrix = None + logger.info("KDTree neighbor building completed") def _reset_computations(self) -> None: """Resets all cached data that depends on the lattice structure.""" @@ -1344,18 +1623,28 @@ def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice": ) # Unzip the list of tuples into separate lists of identifiers and coordinates - _, identifiers, coordinates = zip(*all_sites_info) + _, identifiers, _ = zip(*all_sites_info) + + # Detach-and-copy coordinates while remaining in tensor form to avoid + # host roundtrips and device/dtype changes; this keeps CustomizeLattice + # decoupled from the original graph but backend-friendly. + # Some backends (e.g., NumPy) don't implement stop_gradient; fall back. + try: + coords_detached = backend.stop_gradient(lattice._coordinates) + except NotImplementedError: + coords_detached = lattice._coordinates + coords_tensor = backend.copy(coords_detached) return cls( dimensionality=lattice.dimensionality, identifiers=list(identifiers), - coordinates=list(coordinates), + coordinates=coords_tensor, ) def add_sites( self, identifiers: List[SiteIdentifier], - coordinates: List[Union[List[float], Coordinates]], + coordinates: Any, ) -> None: """Adds new sites to the lattice. @@ -1363,21 +1652,29 @@ def add_sites( previously computed neighbor information is cleared and must be recalculated. - :param identifiers: A list of unique, hashable identifiers for the new sites. 
+ :param identifiers: A list of unique identifiers for the new sites. :type identifiers: List[SiteIdentifier] - :param coordinates: A list of coordinates for the new sites. - :type coordinates: List[Union[List[float], np.ndarray]] - :raises ValueError: If input lists have mismatched lengths, or if any new - identifier already exists in the lattice. + :param coordinates: The coordinates for the new sites. Can be a list of lists, + a NumPy array, or a backend-compatible tensor (e.g., jax.numpy.ndarray). + :type coordinates: Any """ - if len(identifiers) != len(coordinates): + if not identifiers: + return + + new_coords_tensor = backend.convert_to_tensor(coordinates) + + if len(identifiers) != backend.shape_tuple(new_coords_tensor)[0]: raise ValueError( "Identifiers and coordinates lists must have the same length." ) - if not identifiers: - return # Nothing to add - # Check for duplicate identifiers before making any changes + if backend.shape_tuple(new_coords_tensor)[1] != self.dimensionality: + raise ValueError( + f"New coordinate tensor has dimension {backend.shape_tuple(new_coords_tensor)[1]}, " + f"but expected dimensionality is {self.dimensionality}." + ) + + # Ensure that the new identifiers are unique and do not already exist. 
existing_ids = set(self._identifiers) new_ids = set(identifiers) if not new_ids.isdisjoint(existing_ids): @@ -1385,21 +1682,14 @@ def add_sites( f"Duplicate identifiers found: {new_ids.intersection(existing_ids)}" ) - for i, coord in enumerate(coordinates): - coord_arr = np.asarray(coord) - if coord_arr.shape != (self.dimensionality,): - raise ValueError( - f"New coordinate at index {i} has shape {coord_arr.shape}, " - f"expected ({self.dimensionality},)" - ) - self._coordinates.append(coord_arr) - self._identifiers.append(identifiers[i]) + self._coordinates = backend.concat( + [self._coordinates, new_coords_tensor], axis=0 + ) + self._identifiers.extend(identifiers) - # Rebuild index mappings from scratch self._indices = list(range(len(self._identifiers))) self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)} - # Invalidate any previously computed neighbors or distance matrices self._reset_computations() logger.info( f"{len(identifiers)} sites added. Lattice now has {self.num_sites} sites." @@ -1414,10 +1704,9 @@ def remove_sites(self, identifiers: List[SiteIdentifier]) -> None: :param identifiers: A list of identifiers for the sites to be removed. :type identifiers: List[SiteIdentifier] - :raises ValueError: If any of the specified identifiers do not exist. """ if not identifiers: - return # Nothing to remove + return ids_to_remove = set(identifiers) current_ids = set(self._identifiers) @@ -1426,23 +1715,25 @@ def remove_sites(self, identifiers: List[SiteIdentifier]) -> None: f"Non-existent identifiers provided for removal: {ids_to_remove - current_ids}" ) - # Create new lists containing only the sites to keep - new_identifiers: List[SiteIdentifier] = [] - new_coordinates: List[Coordinates] = [] - for ident, coord in zip(self._identifiers, self._coordinates): - if ident not in ids_to_remove: - new_identifiers.append(ident) - new_coordinates.append(coord) + # Find the indices of the sites that we want to keep. 
+ indices_to_keep = [ + idx + for idx, ident in enumerate(self._identifiers) + if ident not in ids_to_remove + ] + + new_identifiers = [self._identifiers[i] for i in indices_to_keep] + + self._coordinates = backend.gather1d( + self._coordinates, + backend.cast(backend.convert_to_tensor(indices_to_keep), "int32"), + ) - # Replace old data with the new, filtered data self._identifiers = new_identifiers - self._coordinates = new_coordinates - # Rebuild index mappings self._indices = list(range(len(self._identifiers))) self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)} - # Invalidate caches self._reset_computations() logger.info( f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites." diff --git a/tests/test_backends.py b/tests/test_backends.py index a681a18..b589163 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1174,3 +1174,62 @@ def fb2(*args): np.testing.assert_allclose(ya, tc.backend.transpose(yb, [1, 0, 2]), atol=1e-5) np.testing.assert_allclose(ya, yajit, atol=1e-5) np.testing.assert_allclose(yajit, tc.backend.transpose(ybjit, [1, 0, 2]), atol=1e-5) + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_sort(backend): + """Test sort method.""" + a = tc.backend.convert_to_tensor([3, 1, 4, 1, 5]) + result = tc.backend.sort(a) + expected = tc.backend.convert_to_tensor([1, 1, 3, 4, 5]) + np.testing.assert_allclose(result, expected) + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_all(backend): + """Test all method.""" + # Test all True + a = tc.backend.convert_to_tensor([True, True, True]) + result = tc.backend.all(a) + assert tc.backend.numpy(result).item() is True + + # Test with False + b = tc.backend.convert_to_tensor([True, False, True]) + result = tc.backend.all(b) + assert tc.backend.numpy(result).item() is False + + +@pytest.mark.parametrize("backend", [lf("npb"), 
lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_meshgrid(backend): + """Test meshgrid method.""" + x = tc.backend.convert_to_tensor([1, 2]) + y = tc.backend.convert_to_tensor([3, 4]) + xx, yy = tc.backend.meshgrid(x, y, indexing="ij") + + expected_xx = tc.backend.convert_to_tensor([[1, 1], [2, 2]]) + expected_yy = tc.backend.convert_to_tensor([[3, 4], [3, 4]]) + + np.testing.assert_allclose(xx, expected_xx) + np.testing.assert_allclose(yy, expected_yy) + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_expand_dims(backend): + """Test expand_dims method.""" + a = tc.backend.convert_to_tensor([[1, 2], [3, 4]]) + result = tc.backend.expand_dims(a, axis=0) + assert result.shape == (1, 2, 2) + + result = tc.backend.expand_dims(a, axis=1) + assert result.shape == (2, 1, 2) + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_where(backend): + """Test where method.""" + condition = tc.backend.convert_to_tensor([True, False, True]) + x = tc.backend.convert_to_tensor([1, 2, 3]) + y = tc.backend.convert_to_tensor([4, 5, 6]) + result = tc.backend.where(condition, x, y) + expected = tc.backend.convert_to_tensor([1, 5, 3]) + np.testing.assert_allclose(result, expected) diff --git a/tests/test_hamiltonians.py b/tests/test_hamiltonians.py index 7ecd011..8be5771 100644 --- a/tests/test_hamiltonians.py +++ b/tests/test_hamiltonians.py @@ -157,3 +157,35 @@ def test_anisotropic_heisenberg(self, backend): h_generated_dense = tc.backend.to_dense(h_generated) assert h_generated_dense.shape == (4, 4) assert np.allclose(h_generated_dense, h_expected) + + def test_heisenberg_hamiltonian_all_interactions(self): + """ + Test the Heisenberg Hamiltonian with 'all' interaction scope. + For a 3-site chain, this should include interactions (0,1), (0,2), and (1,2). 
+ """ + lattice = ChainLattice(size=(3,), pbc=False) + j_coupling = 1.0 + h_generated = heisenberg_hamiltonian( + lattice, j_coupling=j_coupling, interaction_scope="all" + ) + + # Manually construct the expected Hamiltonian for all-to-all interaction + # H = J * [ (X0X1 + Y0Y1 + Z0Z1) + (X0X2 + Y0Y2 + Z0Z2) + (X1X2 + Y1Y2 + Z1Z2) ] + xx_01 = np.kron(PAULI_X, np.kron(PAULI_X, PAULI_I)) + yy_01 = np.kron(PAULI_Y, np.kron(PAULI_Y, PAULI_I)) + zz_01 = np.kron(PAULI_Z, np.kron(PAULI_Z, PAULI_I)) + + xx_02 = np.kron(PAULI_X, np.kron(PAULI_I, PAULI_X)) + yy_02 = np.kron(PAULI_Y, np.kron(PAULI_I, PAULI_Y)) + zz_02 = np.kron(PAULI_Z, np.kron(PAULI_I, PAULI_Z)) + + xx_12 = np.kron(PAULI_I, np.kron(PAULI_X, PAULI_X)) + yy_12 = np.kron(PAULI_I, np.kron(PAULI_Y, PAULI_Y)) + zz_12 = np.kron(PAULI_I, np.kron(PAULI_Z, PAULI_Z)) + + h_expected = j_coupling * ( + (xx_01 + yy_01 + zz_01) + (xx_02 + yy_02 + zz_02) + (xx_12 + yy_12 + zz_12) + ) + + assert h_generated.shape == (8, 8) + assert np.allclose(tc.backend.to_dense(h_generated), h_expected) diff --git a/tests/test_lattice.py b/tests/test_lattice.py index 12354f1..612c9f8 100644 --- a/tests/test_lattice.py +++ b/tests/test_lattice.py @@ -1,15 +1,13 @@ from unittest.mock import patch import logging -# import time - import matplotlib matplotlib.use("Agg") - import pytest import numpy as np +from pytest_lazyfixture import lazy_fixture as lf from tensorcircuit.templates.lattice import ( ChainLattice, @@ -23,13 +21,13 @@ RectangularLattice, SquareLattice, TriangularLattice, - AbstractLattice, get_compatible_layers, ) +import tensorcircuit as tc @pytest.fixture -def simple_square_lattice() -> CustomizeLattice: +def simple_square_lattice(): """ Provides a simple 2x2 square CustomizeLattice instance for neighbor tests. 
The sites are indexed as follows: @@ -46,7 +44,7 @@ def simple_square_lattice() -> CustomizeLattice: @pytest.fixture -def kagome_lattice_fragment() -> CustomizeLattice: +def kagome_lattice_fragment(): """ Pytest fixture to provide a standard CustomizeLattice instance. This represents the Kagome fragment from the project requirements, @@ -72,7 +70,10 @@ class TestCustomizeLattice: This helps in organizing the test suite. """ - def test_initialization_and_properties(self, kagome_lattice_fragment): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_initialization_and_properties(self, backend, kagome_lattice_fragment): """ Test case for successful initialization and verification of basic properties. This test function receives the 'kagome_lattice_fragment' fixture as an argument. @@ -85,10 +86,15 @@ def test_initialization_and_properties(self, kagome_lattice_fragment): assert lattice.num_sites == 6 assert len(lattice) == 6 # This also tests the __len__ dunder method - # Verify that coordinates are correctly stored as numpy arrays. - # It's important to use np.testing.assert_array_equal for numpy array comparison. + # Verify that coordinates are correctly stored as backend tensors. + # It's important to use np.testing.assert_array_equal for coordinate comparison. expected_coord = np.array([0.5, np.sqrt(3) / 2]) - np.testing.assert_array_equal(lattice.get_coordinates(2), expected_coord) + np.testing.assert_allclose( + tc.backend.numpy(lattice.get_coordinates(2)), + expected_coord, + rtol=1e-6, + atol=1e-6, + ) # Verify that the mapping between identifiers and indices is correct. assert lattice.get_identifier(4) == 4 @@ -107,7 +113,7 @@ def test_input_validation_mismatched_lengths(self): # the specified exception is raised within the 'with' block. 
with pytest.raises( ValueError, - match="Identifiers and coordinates lists must have the same length.", + match="The number of identifiers must match the number of coordinates.", ): CustomizeLattice(dimensionality=2, identifiers=ids, coordinates=coords) @@ -122,14 +128,16 @@ def test_input_validation_wrong_dimension(self): # Act & Assert: Check for the specific error message. The 'r' before the string # indicates a raw string, which is good practice for regex patterns. - with pytest.raises( - ValueError, match=r"Coordinate at index 1 has shape \(3,\), expected \(2,\)" - ): + with pytest.raises(ValueError): + CustomizeLattice( dimensionality=2, identifiers=ids_ok, coordinates=coords_wrong_dim ) - def test_neighbor_finding(self, simple_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_neighbor_finding(self, backend, simple_square_lattice): """ Tests the k-th nearest neighbor finding functionality (_build_neighbors and get_neighbors). @@ -139,17 +147,56 @@ def test_neighbor_finding(self, simple_square_lattice): # --- Assertions for k=1 (Nearest Neighbors) --- # We use set() for comparison to ignore the order of neighbors. - assert set(lattice.get_neighbors(0, k=1)) == {1, 2} - assert set(lattice.get_neighbors(1, k=1)) == {0, 3} - assert set(lattice.get_neighbors(2, k=1)) == {0, 3} - assert set(lattice.get_neighbors(3, k=1)) == {1, 2} + neighbors = set(lattice.get_neighbors(0, k=1)) + expected = {1, 2} + assert neighbors == expected, ( + f"Failed neighbor check for site 0 (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) + neighbors = set(lattice.get_neighbors(1, k=1)) + expected = {0, 3} + assert neighbors == expected, ( + f"Failed neighbor check for site 1 (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." 
+ ) + neighbors = set(lattice.get_neighbors(2, k=1)) + expected = {0, 3} + assert neighbors == expected, ( + f"Failed neighbor check for site 2 (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) + neighbors = set(lattice.get_neighbors(3, k=1)) + expected = {1, 2} + assert neighbors == expected, ( + f"Failed neighbor check for site 3 (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) # --- Assertions for k=2 (Next-Nearest Neighbors) --- # These should be the diagonal sites. - assert set(lattice.get_neighbors(0, k=2)) == {3} - assert set(lattice.get_neighbors(1, k=2)) == {2} + neighbors = set(lattice.get_neighbors(0, k=2)) + expected = {3} + assert neighbors == expected, ( + f"Failed neighbor check for site 0 (k=2). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) + neighbors = set(lattice.get_neighbors(1, k=2)) + expected = {2} + assert neighbors == expected, ( + f"Failed neighbor check for site 1 (k=2). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) - def test_neighbor_pairs(self, simple_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_neighbor_pairs(self, backend, simple_square_lattice): """ Tests the retrieval of unique neighbor pairs (bonds) using get_neighbor_pairs. 
@@ -174,7 +221,10 @@ def test_neighbor_pairs(self, simple_square_lattice): expected_nnn_pairs = {(0, 3), (1, 2)} assert set(map(tuple, nnn_pairs)) == expected_nnn_pairs - def test_neighbor_pairs_non_unique(self, simple_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_neighbor_pairs_non_unique(self, backend, simple_square_lattice): """ Tests get_neighbor_pairs with unique=False to ensure all directed pairs (bonds) are returned. @@ -477,15 +527,42 @@ def test_customizelattice_max_k_precomputation_and_ondemand(self): f"but found {computed_shells_after}." ) + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_precompute_neighbors_on_init_custom(self, backend): + """ + Tests that the `precompute_neighbors` argument correctly populates + the neighbor map upon initialization for CustomizeLattice. + """ + coords = [[float(i), float(i)] for i in range(10)] + ids = list(range(10)) + k_to_precompute = 3 + + # Initialize the lattice with the precompute_neighbors argument + lattice = CustomizeLattice( + dimensionality=2, + identifiers=ids, + coordinates=coords, + precompute_neighbors=k_to_precompute, + ) + + # Assert that the internal neighbor map is populated correctly + assert lattice._neighbor_maps is not None + # Check that all shells up to k_to_precompute are present + computed_shells = sorted(list(lattice._neighbor_maps.keys())) + expected_shells = list(range(1, k_to_precompute + 1)) + assert computed_shells == expected_shells + @pytest.fixture -def obc_square_lattice() -> SquareLattice: +def obc_square_lattice(): """Provides a 3x3 SquareLattice with Open Boundary Conditions.""" return SquareLattice(size=(3, 3), pbc=False) @pytest.fixture -def pbc_square_lattice() -> SquareLattice: +def pbc_square_lattice(): """Provides a 3x3 SquareLattice with Periodic Boundary Conditions.""" return SquareLattice(size=(3, 3), pbc=True) @@ -496,7 +573,10 @@ class 
TestSquareLattice: the core functionality of its parent, TILattice. """ - def test_initialization_and_properties(self, obc_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_initialization_and_properties(self, backend, obc_square_lattice): """ Tests the basic properties of a SquareLattice instance. """ @@ -505,7 +585,10 @@ def test_initialization_and_properties(self, obc_square_lattice): assert lattice.num_sites == 9 # A 3x3 lattice should have 9 sites. assert len(lattice) == 9 - def test_site_info_and_identifiers(self, obc_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_site_info_and_identifiers(self, backend, obc_square_lattice): """ Tests that site information (coordinates, identifiers) is correct. """ @@ -515,14 +598,17 @@ def test_site_info_and_identifiers(self, obc_square_lattice): _, ident, coords = lattice.get_site_info(center_idx) assert ident == (1, 1, 0) - np.testing.assert_array_equal(coords, np.array([1.0, 1.0])) + np.testing.assert_array_equal(tc.backend.numpy(coords), np.array([1.0, 1.0])) corner_idx = 0 _, ident, coords = lattice.get_site_info(corner_idx) assert ident == (0, 0, 0) - np.testing.assert_array_equal(coords, np.array([0.0, 0.0])) + np.testing.assert_array_equal(tc.backend.numpy(coords), np.array([0.0, 0.0])) - def test_neighbors_with_open_boundaries(self, obc_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_neighbors_with_open_boundaries(self, backend, obc_square_lattice): """ Tests neighbor finding with Open Boundary Conditions (OBC) using specific neighbor identities. 
@@ -537,11 +623,29 @@ def test_neighbors_with_open_boundaries(self, obc_square_lattice): edge_idx = 3 # (1, 0, 0) # Assert center site (4) has neighbors 1, 3, 5, 7 - assert set(lattice.get_neighbors(center_idx, k=1)) == {1, 3, 5, 7} + neighbors = set(lattice.get_neighbors(center_idx, k=1)) + expected = {1, 3, 5, 7} + assert neighbors == expected, ( + f"Failed neighbor check for site {center_idx} (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) # Assert corner site (0) has neighbors 1, 3 - assert set(lattice.get_neighbors(corner_idx, k=1)) == {1, 3} + neighbors = set(lattice.get_neighbors(corner_idx, k=1)) + expected = {1, 3} + assert neighbors == expected, ( + f"Failed neighbor check for site {corner_idx} (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) # Assert edge site (3) has neighbors 0, 4, 6 - assert set(lattice.get_neighbors(edge_idx, k=1)) == {0, 4, 6} + neighbors = set(lattice.get_neighbors(edge_idx, k=1)) + expected = {0, 4, 6} + assert neighbors == expected, ( + f"Failed neighbor check for site {edge_idx} (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) def test_neighbors_with_periodic_boundaries(self, pbc_square_lattice): """ @@ -565,9 +669,9 @@ def test_neighbors_with_periodic_boundaries(self, pbc_square_lattice): @pytest.fixture -def pbc_honeycomb_lattice() -> HoneycombLattice: - """Provides a 2x2 HoneycombLattice with Periodic Boundary Conditions.""" - return HoneycombLattice(size=(2, 2), pbc=True) +def pbc_honeycomb_lattice(): + """Provides a 3x3 HoneycombLattice with Periodic Boundary Conditions.""" + return HoneycombLattice(size=(3, 3), pbc=True) class TestHoneycombLattice: @@ -575,15 +679,21 @@ class TestHoneycombLattice: Tests the HoneycombLattice class, focusing on its two-site basis. 
""" - def test_initialization_and_properties(self, pbc_honeycomb_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_initialization_and_properties(self, backend, pbc_honeycomb_lattice): """ Tests that the total number of sites is correct for a composite lattice. """ lattice = pbc_honeycomb_lattice - assert lattice.num_sites == 8 + assert lattice.num_sites == 18 assert lattice.num_basis == 2 - def test_honeycomb_neighbors(self, pbc_honeycomb_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_honeycomb_neighbors(self, backend, pbc_honeycomb_lattice): """ Tests that every site in a honeycomb lattice has 3 nearest neighbors. """ @@ -594,12 +704,29 @@ def test_honeycomb_neighbors(self, pbc_honeycomb_lattice): site_b_idx = lattice.get_index((0, 0, 1)) assert len(lattice.get_neighbors(site_b_idx, k=1)) == 3 + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_honeycomb_next_nearest_neighbors(self, backend, pbc_honeycomb_lattice): + """ + Tests that every site in a honeycomb lattice has 6 next-nearest neighbors + under periodic boundary conditions. + """ + lattice = pbc_honeycomb_lattice + # In a PBC honeycomb lattice, every site is equivalent. + # We can just test one site from each basis. + site_a_idx = lattice.get_index((0, 0, 0)) + assert len(lattice.get_neighbors(site_a_idx, k=2)) == 6 + + site_b_idx = lattice.get_index((0, 0, 1)) + assert len(lattice.get_neighbors(site_b_idx, k=2)) == 6 + # --- Tests for TriangularLattice --- @pytest.fixture -def pbc_triangular_lattice() -> TriangularLattice: +def pbc_triangular_lattice(): """ Provides a 3x3 TriangularLattice with Periodic Boundary Conditions. A 3x3 size is used to ensure all 6 nearest neighbors are unique sites. @@ -612,14 +739,20 @@ class TestTriangularLattice: Tests the TriangularLattice class, focusing on its coordination number. 
""" - def test_initialization_and_properties(self, pbc_triangular_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_initialization_and_properties(self, backend, pbc_triangular_lattice): """ Tests the basic properties of the triangular lattice. """ lattice = pbc_triangular_lattice assert lattice.num_sites == 9 # 3 * 3 = 9 sites for a 3x3 grid - def test_triangular_neighbors(self, pbc_triangular_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_triangular_neighbors(self, backend, pbc_triangular_lattice): """ Tests that every site in a triangular lattice has 6 nearest neighbors. """ @@ -657,7 +790,7 @@ class TestTILatticeEdgeCases: """ @pytest.fixture - def obc_1d_chain(self) -> ChainLattice: + def obc_1d_chain(self): """ Provides a 5-site 1D chain with Open Boundary Conditions. """ @@ -683,7 +816,7 @@ def test_1d_chain_properties_and_neighbors(self, obc_1d_chain): assert set(lattice.get_neighbors(middle_idx, k=1)) == {1, 3} @pytest.fixture - def nonsquare_lattice(self) -> SquareLattice: + def nonsquare_lattice(self): """Provides a non-square 2x3 lattice to test indexing.""" return SquareLattice(size=(2, 3), pbc=False) @@ -1018,10 +1151,24 @@ def test_init_with_mismatched_shapes_raises_error(self): """ # Act & Assert: # Pass a 'size' tuple with 3 elements to a 2D SquareLattice. - # This should trigger the AssertionError from the parent TILattice class. - with pytest.raises(AssertionError, match="Size tuple length mismatch"): + # This should trigger the ValueError from the parent TILattice class. + with pytest.raises( + ValueError, match="Size tuple length .* does not match dimensionality" + ): SquareLattice(size=(2, 2, 2)) + def test_init_with_mismatched_pbc_raises_error(self): + """ + Tests that TILattice raises ValueError if the 'pbc' tuple's + length does not match the dimensionality. + This addresses a gap identified in the code review. 
+ """ + with pytest.raises( + ValueError, match="PBC tuple length .* does not match dimensionality" + ): + # A 2D lattice requires a pbc tuple of length 2, but we provide one of length 1. + SquareLattice(size=(2, 2), pbc=(True,)) + def test_init_with_tuple_pbc(self): """ Tests that TILattice correctly handles a tuple input for the 'pbc' @@ -1078,6 +1225,24 @@ def test_tilattice_max_k_precomputation_and_ondemand( f"but found {computed_shells_after}." ) + def test_precompute_neighbors_on_init(self): + """ + Tests that the `precompute_neighbors` argument correctly populates + the neighbor map upon initialization for a TILattice subclass. + """ + k_to_precompute = 2 + # Initialize a SquareLattice with the precompute_neighbors argument + lattice = SquareLattice( + size=(4, 4), pbc=True, precompute_neighbors=k_to_precompute + ) + + # Assert that the internal neighbor map is populated correctly + assert lattice._neighbor_maps is not None + # Check that all shells up to k_to_precompute are present + computed_shells = sorted(list(lattice._neighbor_maps.keys())) + expected_shells = list(range(1, k_to_precompute + 1)) + assert computed_shells == expected_shells + class TestLongRangeNeighborFinding: """ @@ -1086,7 +1251,7 @@ class TestLongRangeNeighborFinding: """ @pytest.fixture(scope="class") - def large_pbc_square_lattice(self) -> SquareLattice: + def large_pbc_square_lattice(self): """ Provides a single 6x8 SquareLattice with PBC for all tests in this class. Using scope="class" makes it more efficient as it's created only once. @@ -1241,6 +1406,34 @@ def test_mixed_boundary_conditions(self): edge_neighbors == expected_edge_indices ), "Failed for edge site with mixed BC." + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_mixed_boundary_conditions_on_honeycomb(self, backend): + """ + Tests neighbor finding on a HoneycombLattice with mixed PBC. 
+ This ensures the logic correctly handles composite lattices with + anisotropic boundary conditions. + """ + lattice = HoneycombLattice(size=(3, 3), pbc=(True, False)) + + test_site_idx = lattice.get_index((1, 0, 0)) + + neighbors = lattice.get_neighbors(test_site_idx, k=1) + + assert ( + len(neighbors) == 2 + ), "Site on the open boundary has incorrect number of neighbors." + + expected_neighbor_idents = { + (1, 0, 1), + (0, 0, 1), + } + + actual_neighbor_idents = {lattice.get_identifier(i) for i in neighbors} + + assert actual_neighbor_idents == expected_neighbor_idents + class TestAllTILattices: """ @@ -1330,16 +1523,20 @@ class TestCustomizeLatticeDynamic: """Tests the dynamic modification capabilities of CustomizeLattice.""" @pytest.fixture - def initial_lattice(self) -> CustomizeLattice: + def initial_lattice(self): """Provides a basic 3-site lattice for modification tests.""" return CustomizeLattice( dimensionality=2, identifiers=["A", "B", "C"], - coordinates=[[0, 0], [1, 0], [0, 1]], + coordinates=[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], ) - def test_from_lattice_conversion(self): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_from_lattice_conversion(self, backend): """Tests creating a CustomizeLattice from a TILattice.""" + # Arrange sq_lattice = SquareLattice(size=(2, 2), pbc=False) @@ -1351,28 +1548,40 @@ def test_from_lattice_conversion(self): assert custom_lattice.num_sites == sq_lattice.num_sites assert custom_lattice.dimensionality == sq_lattice.dimensionality # Verify a site to be sure + # custom_lattice coordinates are normalized to python lists; source lattice returns backend tensor np.testing.assert_array_equal( - custom_lattice.get_coordinates(3), sq_lattice.get_coordinates(3) + np.array(custom_lattice.get_coordinates(3)), + tc.backend.numpy(sq_lattice.get_coordinates(3)), ) assert custom_lattice.get_identifier(3) == sq_lattice.get_identifier(3) - def test_add_sites_successfully(self, 
initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_add_sites_successfully(self, backend, initial_lattice): """Tests adding new, valid sites to the lattice.""" + # Arrange lat = initial_lattice assert lat.num_sites == 3 # Act - lat.add_sites(identifiers=["D", "E"], coordinates=[[1, 1], [2, 2]]) + lat.add_sites(identifiers=["D", "E"], coordinates=[[1.0, 1.0], [2.0, 2.0]]) # Assert assert lat.num_sites == 5 assert lat.get_identifier(4) == "E" - np.testing.assert_array_equal(lat.get_coordinates(3), np.array([1, 1])) + np.testing.assert_array_equal( + tc.backend.numpy(lat.get_coordinates(3)), np.array([1.0, 1.0]) + ) assert "E" in lat._ident_to_idx - def test_remove_sites_successfully(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_remove_sites_successfully(self, backend, initial_lattice): """Tests removing existing sites from the lattice.""" + # Arrange lat = initial_lattice assert lat.num_sites == 3 @@ -1384,29 +1593,43 @@ def test_remove_sites_successfully(self, initial_lattice): assert lat.num_sites == 1 assert lat.get_identifier(0) == "B" # Site 'B' is now at index 0 assert "A" not in lat._ident_to_idx - np.testing.assert_array_equal(lat.get_coordinates(0), np.array([1, 0])) + np.testing.assert_array_equal( + tc.backend.numpy(lat.get_coordinates(0)), np.array([1.0, 0.0]) + ) - def test_add_duplicate_identifier_raises_error(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_add_duplicate_identifier_raises_error(self, backend, initial_lattice): """Tests that adding a site with an existing identifier fails.""" + with pytest.raises(ValueError, match="Duplicate identifiers found"): - initial_lattice.add_sites(identifiers=["A"], coordinates=[[9, 9]]) + initial_lattice.add_sites(identifiers=["A"], coordinates=[[9.0, 9.0]]) - def 
test_remove_nonexistent_identifier_raises_error(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_remove_nonexistent_identifier_raises_error(self, backend, initial_lattice): """Tests that removing a non-existent site fails.""" + with pytest.raises(ValueError, match="Non-existent identifiers provided"): initial_lattice.remove_sites(identifiers=["Z"]) - def test_modification_clears_neighbor_cache(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_modification_clears_neighbor_cache(self, backend, initial_lattice): """ Tests that add_sites and remove_sites correctly invalidate the pre-computed neighbor map. """ + # Arrange: Pre-compute neighbors on the initial lattice initial_lattice._build_neighbors(max_k=1) assert 0 in initial_lattice._neighbor_maps[1] # Check that neighbors exist # Act 1: Add a site - initial_lattice.add_sites(identifiers=["D"], coordinates=[[5, 5]]) + initial_lattice.add_sites(identifiers=["D"], coordinates=[[5.0, 5.0]]) # Assert 1: The neighbor map should now be empty assert not initial_lattice._neighbor_maps @@ -1421,21 +1644,25 @@ def test_modification_clears_neighbor_cache(self, initial_lattice): # Assert 2: The neighbor map should be empty again assert not initial_lattice._neighbor_maps - def test_modification_clears_distance_matrix_cache(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_modification_clears_distance_matrix_cache(self, backend, initial_lattice): """ Tests that add_sites and remove_sites correctly invalidate the cached distance matrix and that the recomputed matrix is correct. """ + # Arrange 1: Compute, cache, and perform a meaningful check on the original matrix. 
lat = initial_lattice original_matrix = lat.distance_matrix assert lat._distance_matrix is not None assert original_matrix.shape == (3, 3) # Meaningful check: distance from 'A'(idx 0) to 'B'(idx 1) should be 1.0 - np.testing.assert_allclose(original_matrix[0, 1], 1.0) + np.testing.assert_allclose(tc.backend.numpy(original_matrix)[0, 1], 1.0) # Act 1: Add a site. This should invalidate the cache. - lat.add_sites(identifiers=["D"], coordinates=[[1, 1]]) + lat.add_sites(identifiers=["D"], coordinates=[[1.0, 1.0]]) # Assert 1: Check cache is cleared and the new matrix is correct. assert lat._distance_matrix is None # Verify cache invalidation @@ -1443,7 +1670,7 @@ def test_modification_clears_distance_matrix_cache(self, initial_lattice): assert new_matrix_added.shape == (4, 4) # Meaningful check: distance from 'B'(idx 1) to new site 'D'(idx 3) should be 1.0 # Coords: B=[1,0], D=[1,1] - np.testing.assert_allclose(new_matrix_added[1, 3], 1.0) + np.testing.assert_allclose(tc.backend.numpy(new_matrix_added)[1, 3], 1.0) # Act 2: Remove a site. This should also invalidate the cache. lat.remove_sites(identifiers=["A"]) @@ -1457,13 +1684,17 @@ def test_modification_clears_distance_matrix_cache(self, initial_lattice): # 'C' is now at index 1 (coords [0,1]) # 'D' is now at index 2 (coords [1,1]) # Distance from new 'B' (idx 0) to new 'D' (idx 2) should be 1.0 - np.testing.assert_allclose(final_matrix[0, 2], 1.0) + np.testing.assert_allclose(tc.backend.numpy(final_matrix)[0, 2], 1.0) - def test_neighbor_finding_returns_sorted_list(self, simple_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_neighbor_finding_returns_sorted_list(self, backend, simple_square_lattice): """ Ensures that the list of neighbors returned by get_neighbors is always sorted. This provides a stricter check than set-based comparisons. 
""" + # Arrange lattice = simple_square_lattice @@ -1480,10 +1711,100 @@ def test_neighbor_finding_returns_sorted_list(self, simple_square_lattice): 3, ], "The neighbor list should be sorted in ascending order." + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_from_lattice_from_empty_lattice(self, backend): + """Tests creating a CustomizeLattice from an empty TILattice.""" + + # Arrange: Create an empty TILattice instance. + empty_sq = SquareLattice(size=(0, 0)) + + # Act: Convert it to a CustomizeLattice. + custom_from_empty = CustomizeLattice.from_lattice(empty_sq) + + # Assert: The resulting lattice should also be empty and have the correct properties. + assert isinstance(custom_from_empty, CustomizeLattice) + assert custom_from_empty.num_sites == 0 + assert custom_from_empty.dimensionality == 2 + + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_add_sites_to_empty_lattice(self, backend): + """Tests adding sites to a previously empty CustomizeLattice.""" + + # Arrange: Create an empty CustomizeLattice. + empty_lat = CustomizeLattice(dimensionality=2, identifiers=[], coordinates=[]) + assert empty_lat.num_sites == 0 + + # Act: Add new sites to it. + empty_lat.add_sites( + identifiers=["X", "Y"], coordinates=[[1.0, 1.0], [2.0, 2.0]] + ) + + # Assert: The lattice should now contain the new sites. + assert empty_lat.num_sites == 2 + assert empty_lat.get_identifier(0) == "X" + np.testing.assert_array_equal( + tc.backend.numpy(empty_lat.get_coordinates(1)), np.array([2.0, 2.0]) + ) + + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_add_and_remove_empty_list_of_sites(self, backend, initial_lattice): + """ + Tests that calling add_sites and remove_sites with empty lists + is a no-op and doesn't change the lattice state. 
+ """ + + # Arrange + lat = initial_lattice + original_num_sites = lat.num_sites + # In backends like JAX, tensors are immutable. We can check object identity. + original_coords_id = id(lat._coordinates) + + # Act 1: Add an empty list of sites. + lat.add_sites(identifiers=[], coordinates=[]) + + # Assert 1: Nothing should have changed. + assert lat.num_sites == original_num_sites + assert id(lat._coordinates) == original_coords_id + + # Act 2: Remove an empty list of sites. + lat.remove_sites(identifiers=[]) + + # Assert 2: Still no changes. + assert lat.num_sites == original_num_sites + assert id(lat._coordinates) == original_coords_id + + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_remove_all_sites(self, backend, initial_lattice): + """Tests removing all sites from a lattice, resulting in an empty lattice.""" + + # Arrange + lat = initial_lattice + # Get all identifiers before removal. + all_ids = list(lat._identifiers) + + # Act + lat.remove_sites(identifiers=all_ids) + + # Assert: The lattice should now be empty. + assert lat.num_sites == 0 + assert len(lat._ident_to_idx) == 0 + assert lat._coordinates.shape[0] == 0 + class TestDistanceMatrix: # This is the upgraded, parameterized test. 
+ @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) @pytest.mark.parametrize( # We define test scenarios as tuples: # (build_k, check_site_identifier, expected_dist_sq) @@ -1504,7 +1825,7 @@ class TestDistanceMatrix: ], ) def test_tilattice_full_pbc_distance_matrix_is_correct_regardless_of_build_k( - self, build_k, check_site_identifier, expected_dist_sq + self, backend, build_k, check_site_identifier, expected_dist_sq ): """ Tests that the distance matrix for a fully periodic TILattice is @@ -1543,11 +1864,15 @@ def test_tilattice_full_pbc_distance_matrix_is_correct_regardless_of_build_k( actual_dist_sq, expected_dist_sq, err_msg=error_message ) - def test_tilattice_mixed_bc_distance_matrix_is_correct(self): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_tilattice_mixed_bc_distance_matrix_is_correct(self, backend): """ Tests that the distance matrix is correctly calculated for a TILattice with mixed boundary conditions (e.g., periodic in x, open in y). """ + # Arrange # pbc=(True, False) means periodic along x-axis, open along y-axis. 
lat = SquareLattice(size=(5, 5), pbc=(True, False)) @@ -1578,25 +1903,44 @@ def test_tilattice_mixed_bc_distance_matrix_is_correct(self): ) # --- This list and the following test are now at the correct indentation level --- - lattice_instances_for_invariant_test = [ - SquareLattice(size=(4, 4), pbc=True), - SquareLattice(size=(4, 3), pbc=(True, False)), # Mixed BC, non-square - HoneycombLattice(size=(3, 3), pbc=True), - TriangularLattice(size=(4, 4), pbc=False), - CustomizeLattice( - dimensionality=2, - identifiers=list(range(4)), - coordinates=[[0, 0], [1, 1], [0, 1], [1, 0]], + # Use factory functions instead of pre-instantiated objects + # to avoid sharing a cached _distance_matrix (NumPy array) across backends + lattice_factories_for_invariant_test = [ + pytest.param(lambda: SquareLattice(size=(4, 4), pbc=True), id="Square_4x4_pbc"), + pytest.param( + lambda: SquareLattice(size=(4, 3), pbc=(True, False)), id="Square_4x3_mixed" + ), + pytest.param( + lambda: HoneycombLattice(size=(3, 3), pbc=True), id="Honeycomb_3x3_pbc" + ), + pytest.param( + lambda: TriangularLattice(size=(4, 4), pbc=False), id="Triangular_4x4_obc" + ), + pytest.param( + lambda: CustomizeLattice( + dimensionality=2, + identifiers=list(range(4)), + coordinates=[[0.0, 0.0], [1.0, 1.0], [0.0, 1.0], [1.0, 0.0]], + ), + id="Customize_4sites", ), ] - @pytest.mark.parametrize("lattice", lattice_instances_for_invariant_test) - def test_distance_matrix_invariants_for_all_lattice_types(self, lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + @pytest.mark.parametrize("lattice_factory", lattice_factories_for_invariant_test) + def test_distance_matrix_invariants_for_all_lattice_types( + self, backend, lattice_factory + ): """ Tests that the distance matrix for any lattice type adheres to fundamental mathematical properties (invariants): symmetry, zero diagonal, and positive off-diagonal elements. 
""" + # Re-instantiate the lattice to ensure no cache side effects across backends + lattice = lattice_factory() + # Arrange n = lattice.num_sites if n < 2: @@ -1609,15 +1953,16 @@ def test_distance_matrix_invariants_for_all_lattice_types(self, lattice): # Assert # 1. Symmetry: The matrix must be equal to its transpose. + matrix_numpy = tc.backend.numpy(matrix) np.testing.assert_allclose( - matrix, - matrix.T, + matrix_numpy, + matrix_numpy.T, err_msg=f"Distance matrix for {type(lattice).__name__} is not symmetric.", ) # 2. Zero Diagonal: All diagonal elements must be zero. np.testing.assert_allclose( - np.diag(matrix), + np.diag(matrix_numpy), np.zeros(n), err_msg=f"Diagonal of distance matrix for {type(lattice).__name__} is not zero.", ) @@ -1626,49 +1971,38 @@ def test_distance_matrix_invariants_for_all_lattice_types(self, lattice): # We create a boolean mask for the off-diagonal elements. off_diagonal_mask = ~np.eye(n, dtype=bool) assert np.all( - matrix[off_diagonal_mask] > 1e-9 + matrix_numpy[off_diagonal_mask] > 1e-9 ), f"Found non-positive off-diagonal elements in distance matrix for {type(lattice).__name__}." + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_distance_matrix_caching_is_effective(self, backend): + """ + Tests that the distance_matrix property is cached after the first access. + """ + + # Arrange: Create a lattice instance. + lattice = CustomizeLattice( + dimensionality=2, + identifiers=["A", "B", "C"], + coordinates=[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], + ) + + # Act & Assert + # We patch the internal _compute_distance_matrix method to spy on it. + with patch.object( + lattice, "_compute_distance_matrix", wraps=lattice._compute_distance_matrix + ) as spy_compute: + # Access the property twice. + _ = lattice.distance_matrix + _ = lattice.distance_matrix + + # The computation method should have been called only on the first access. 
+ spy_compute.assert_called_once() + -# @pytest.mark.slow -# class TestPerformance: -# def test_pbc_implementation_is_not_significantly_slower_than_obc(self): -# """ -# A performance regression test. -# It ensures that the specialized implementation for fully periodic -# lattices (pbc=True) is not substantially slower than the general -# implementation used for open boundaries (pbc=False). -# This test will FAIL with the current code, exposing the performance bug. -# """ -# # Arrange: Use a large-enough lattice to make performance differences apparent -# size = (30, 30) -# k = 1 - -# # Act 1: Measure the execution time of the general (OBC) implementation -# start_time_obc = time.time() -# _ = SquareLattice(size=size, pbc=False, precompute_neighbors=k) -# duration_obc = time.time() - start_time_obc - -# # Act 2: Measure the execution time of the specialized (PBC) implementation -# start_time_pbc = time.time() -# _ = SquareLattice(size=size, pbc=True, precompute_neighbors=k) -# duration_pbc = time.time() - start_time_pbc - -# print( -# f"\n[Performance] OBC ({size}): {duration_obc:.4f}s | PBC ({size}): {duration_pbc:.4f}s" -# ) - -# # Assert: The PBC implementation should not be drastically slower. -# # We allow it to be up to 3 times slower to account for minor overheads, -# # but this will catch the current 10x+ regression. -# # THIS ASSERTION WILL FAIL with the current buggy code. -# assert duration_pbc < duration_obc * 5, ( -# "The specialized PBC implementation is significantly slower " -# "than the general-purpose implementation." -# ) - - -def _validate_layers(bonds, layers) -> None: +def _validate_layers(bonds, layers): """ A helper function to scientifically validate the output of get_compatible_layers. 
""" @@ -1708,7 +2042,7 @@ def _validate_layers(bonds, layers) -> None: "HoneycombLattice_2x2_OBC", ], ) -def test_layering_on_various_lattices(lattice_instance: AbstractLattice): +def test_layering_on_various_lattices(lattice_instance): """Tests gate layering for various standard lattice types.""" bonds = lattice_instance.get_neighbor_pairs(k=1, unique=True) layers = get_compatible_layers(bonds) @@ -1717,6 +2051,51 @@ def test_layering_on_various_lattices(lattice_instance: AbstractLattice): _validate_layers(bonds, layers) +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_square_lattice_accepts_backend_scalar_lattice_constant(backend): + """ + Ensure SquareLattice can be constructed when lattice_constant is a backend scalar tensor + (e.g., tf.constant, jnp.array, torch.tensor), without mixed-type errors. + """ + + lc = tc.backend.convert_to_tensor(0.5) + lat = SquareLattice(size=(2, 2), lattice_constant=lc, pbc=False) + + # basic shape sanity + assert lat.num_sites == 4 + assert tc.backend.shape_tuple(lat._coordinates)[1] == 2 # type: ignore[attr-defined] + + # distance check along x and y + dm = lat.distance_matrix + o = lat.get_index((0, 0, 0)) + x1 = lat.get_index((1, 0, 0)) + y1 = lat.get_index((0, 1, 0)) + np.testing.assert_allclose(tc.backend.numpy(dm[o, x1]), 0.5) + np.testing.assert_allclose(tc.backend.numpy(dm[o, y1]), 0.5) + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_rectangular_lattice_mixed_type_constants(backend): + """ + RectangularLattice should accept a tuple where one constant is a backend scalar tensor + and the other is a Python float. 
+ """ + + ax = tc.backend.convert_to_tensor(0.5) # tensor scalar + ay = 2.0 # python float + lat = RectangularLattice(size=(2, 2), lattice_constants=(ax, ay), pbc=False) + + assert lat.num_sites == 4 + assert tc.backend.shape_tuple(lat._coordinates)[1] == 2 # type: ignore[attr-defined] + + dm = lat.distance_matrix + o = lat.get_index((0, 0, 0)) + x1 = lat.get_index((1, 0, 0)) + y1 = lat.get_index((0, 1, 0)) + np.testing.assert_allclose(tc.backend.numpy(dm[o, x1]), 0.5) + np.testing.assert_allclose(tc.backend.numpy(dm[o, y1]), 2.0) + + def test_layering_on_1d_chain_pbc(): """Test layering on a 1D chain with periodic boundaries (a cycle graph).""" lattice_even = ChainLattice(size=(6,), pbc=True) @@ -1748,3 +2127,511 @@ def test_layering_on_edge_cases(): layers_single = get_compatible_layers(single_edge) assert layers_single == [[(0, 1)]] _validate_layers(single_edge, layers_single) + + +def test_layering_on_disconnected_graph(): + """ + Tests that the layering algorithm correctly handles a graph consisting + of multiple disconnected components. + """ + disconnected_bonds = [ + (0, 1), + (1, 2), + (2, 0), + (3, 4), + (4, 5), + (5, 3), + ] + + layers = get_compatible_layers(disconnected_bonds) + + # Assert + _validate_layers(disconnected_bonds, layers) + + assert len(layers) == 3, "Two disconnected triangles should be 3-edge-colorable." + + layer_with_01 = next( + layer for layer in layers if (0, 1) in layer or (1, 0) in layer + ) + assert (3, 4) in layer_with_01 or (4, 3) in layer_with_01 + + +class TestBackendIntegration: + """ + Tests to ensure lattice functionalities are consistent and correct + across different computation backends. 
+ """ + + # Define the test cases for parameterization + lattice_test_cases = [ + ( + SquareLattice, + {"size": (3, 3), "pbc": False}, + (0, 4, np.sqrt(2.0)), + "SquareLattice", + ), + ( + CustomizeLattice, + { + "dimensionality": 2, + "identifiers": [0, 1], + "coordinates": [[0.0, 0.0], [1.0, 1.0]], + }, + (0, 1, np.sqrt(2.0)), + "CustomizeLattice", + ), + ( + HoneycombLattice, + {"size": (2, 2), "pbc": True}, + (0, 1, 1.0), + "HoneycombLattice", + ), + ( + RectangularLattice, + {"size": (2, 3), "lattice_constants": (1.0, 2.0), "pbc": False}, + (0, 1, 2.0), + "RectangularLattice", + ), + ] + + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + @pytest.mark.parametrize( + "LatticeClass, init_args, expected_distance_check, name", + lattice_test_cases, + ids=[case[3] for case in lattice_test_cases], + ) + def test_lattice_creation_and_properties_across_backends( + self, + backend, + LatticeClass, + init_args, + expected_distance_check, + name, + ): + """ + Tests that various lattices can be created with each backend and that their + core properties (_coordinates, distance_matrix) have the correct + tensor types and values. + """ + # Create the lattice instance inside the test function + lat = LatticeClass(**init_args) + + # Verify that the lattice can be created and has correct properties + assert lat.num_sites > 0, f"Failed for {type(lat).__name__} on current backend" + assert hasattr( + lat, "_coordinates" + ), f"Failed for {type(lat).__name__} on current backend" + assert hasattr( + lat, "distance_matrix" + ), f"Failed for {type(lat).__name__} on current backend" + + # Unpack the distance check information + idx1, idx2, expected_distance = expected_distance_check + + # Extract the actual distance from the matrix and convert to a numpy float + # for a consistent comparison across all backends. 
+ actual_distance = tc.backend.numpy(lat.distance_matrix)[idx1, idx2] + + # Assert that the computed distance matches the expected value. + np.testing.assert_allclose( + actual_distance, + expected_distance, + err_msg=f"Distance check failed for {type(lat).__name__} on current backend", + ) + + @pytest.mark.parametrize( + "lattice_class, init_params, differentiable_arg_name, test_value", + [ + ( + SquareLattice, + {"size": (2, 2)}, + "lattice_constant", + 1.0, + ), + ( + RectangularLattice, + {"size": (2, 2)}, + "lattice_constants", + (1.0, 1.5), + ), + ( + DimerizedChainLattice, + {"size": (3,)}, + "lattice_constant", + 0.8, + ), + ], + ids=["SquareLattice", "RectangularLattice", "DimerizedChainLattice"], + ) + def test_tilattice_differentiability( + self, + jaxb, + lattice_class, + init_params, + differentiable_arg_name, + test_value, + ): + """ + Tests that the distance_matrix of various TILattices is differentiable + with respect to their geometric parameters. This test has been expanded + based on code review feedback to cover more lattice types. + """ + + def get_total_distance(param): + """A scalar-in, scalar-out function for jax.grad.""" + # Dynamically create the lattice with the parameter being differentiated + lat = lattice_class(**init_params, **{differentiable_arg_name: param}) + return tc.backend.sum(lat.distance_matrix) + + # Compute the gradient. The `argnums` parameter is used for tuple inputs. + grad_fn = ( + tc.backend.grad(get_total_distance, argnums=0) + if isinstance(test_value, tuple) + else tc.backend.grad(get_total_distance) + ) + grad_val = grad_fn(test_value) + + assert grad_val is not None + + # For tuple gradients, check that at least one element is non-zero. + if isinstance(grad_val, tuple): + assert any( + not np.isclose(float(g), 0.0) for g in grad_val + ), f"Gradient for {lattice_class.__name__} was all zeros." + else: + assert not np.isclose( + float(grad_val), 0.0 + ), f"Gradient for {lattice_class.__name__} was zero." 
+ + def test_customizelattice_differentiability(self, jaxb): + """ + Tests that the distance_matrix of a CustomizeLattice is differentiable + with respect to its input coordinates. + """ + initial_coords = tc.backend.convert_to_tensor( + [[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]] + ) + + def get_total_distance_custom(coords): + """ + A helper function that takes coordinates, creates a CustomizeLattice, + and returns a scalar value (the sum of its distance matrix). + """ + lat = CustomizeLattice( + dimensionality=2, identifiers=[0, 1, 2], coordinates=coords + ) + return tc.backend.sum(lat.distance_matrix) + + # Compute the gradient of the total distance with respect to the initial coordinates. + grad_tensor = tc.backend.grad(get_total_distance_custom)(initial_coords) + + # Assert that the gradient tensor is not None and is not all zeros. + # A non-zero gradient confirms that the output is indeed differentiable + # with respect to the input coordinates. + assert grad_tensor is not None + assert not np.all(np.isclose(grad_tensor, 0.0)) + + def test_tilattice_gradient_value_correctness(self, jaxb): + """ + Tests that the AD gradient for a TILattice parameter matches the + analytically calculated, correct gradient value. This is a stronger + test than just checking for non-zero gradients. + """ + + # 1. Define a simple objective function + def get_energy(a): + """ + A simple energy function for a 2-site chain. + Energy = (distance between site 0 and 1)^2 = a^2 + """ + # Using a 2-site chain, the simplest possible TILattice + lat = ChainLattice(size=(2,), pbc=False, lattice_constant=a) + # The distance matrix for a 2-site chain is [[0, a], [a, 0]] + dist_matrix = lat.distance_matrix + # Sum of squared distances for all unique pairs. There is only one pair (0,1). + return tc.backend.sum(dist_matrix**2) / 2.0 + + # 2. Define the analytical (manually calculated) gradient + def analytical_gradient(a): + """ + The analytical derivative of the energy function E(a) = a^2. 
+ dE/da = 2a + """ + return 2 * a + + # 3. Set up the test + test_lattice_constant = 1.5 + # Compute the gradient using automatic differentiation + ad_grad = tc.backend.grad(get_energy)(test_lattice_constant) + # Compute the expected gradient using our analytical formula + expected_grad = analytical_gradient(test_lattice_constant) + + # 4. Assert that the two gradients are numerically very close + np.testing.assert_allclose( + ad_grad, + expected_grad, + rtol=1e-6, + err_msg="The automatically differentiated gradient does not match the analytical gradient.", + ) + + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_dynamic_modification_across_backends(self, backend): + """ + Tests that the dynamic modification methods (add_sites, remove_sites) + of CustomizeLattice work correctly across all supported backends, + specifically checking tensor shapes. + """ + # --- Initial State --- + # Create a simple lattice with 3 sites + initial_coords = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]] + initial_ids = ["A", "B", "C"] + lattice = CustomizeLattice( + dimensionality=2, identifiers=initial_ids, coordinates=initial_coords + ) + assert lattice.num_sites == 3 + assert lattice._coordinates.shape[0] == 3 + + # --- Test add_sites --- + # Act: Add 2 new sites + lattice.add_sites(identifiers=["D", "E"], coordinates=[[1.0, 1.0], [2.0, 0.0]]) + + # Assert: Check new size and tensor shape + assert lattice.num_sites == 5 + assert ( + lattice._coordinates.shape[0] == 5 + ), "Tensor shape incorrect after add_sites on current backend." + + # --- Test remove_sites --- + # Act: Remove 1 site from the modified lattice + lattice.remove_sites(identifiers=["A"]) + + # Assert: Check final size and tensor shape + assert lattice.num_sites == 4 + assert ( + lattice._coordinates.shape[0] == 4 + ), "Tensor shape incorrect after remove_sites on current backend." 
+ + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_dtype_consistency_across_backends(backend): + """ + Tests that the dtype of user-provided coordinate data is preserved + in internal calculations across all backends. + """ + # Prepare input data with a specific, non-default dtype + coords_float32 = np.array([[0.0, 0.0], [1.0, 2.0]], dtype=np.float32) + + # Act: Create a lattice with this data + lattice = CustomizeLattice( + dimensionality=2, identifiers=[0, 1], coordinates=coords_float32 + ) + + # Assert: Check that the internal tensors have the correct dtype + assert tc.backend.dtype(lattice._coordinates) == "float32", ( + f"Mismatch in coordinate dtype for current backend. " + f"Expected 'float32', got {tc.backend.dtype(lattice._coordinates)}." + ) + assert tc.backend.dtype(lattice.distance_matrix) == "float32", ( + f"Mismatch in distance matrix dtype for current backend. " + f"Expected 'float32', got {tc.backend.dtype(lattice.distance_matrix)}." + ) + + +class TestPrivateHelpers: + """ + A dedicated test class for private helper methods of AbstractLattice. + This allows for focused testing of internal logic that is critical for + the public-facing API but not directly exposed. + """ + + @pytest.fixture + def simple_lattice_for_helpers(self): + """ + Provides a very simple lattice instance, primarily to gain access to + the private helper methods for testing. The geometry itself is trivial. + """ + # A simple 3-site lattice is sufficient to call the helper methods. + return CustomizeLattice( + dimensionality=1, identifiers=[0, 1, 2], coordinates=[[0.0], [1.0], [2.0]] + ) + + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_identify_distance_shells_basic(self, backend, simple_lattice_for_helpers): + """ + Tests the basic functionality of _identify_distance_shells with a + clear separation between distance shells. 
+ """ + + # Arrange + lattice = simple_lattice_for_helpers + # A set of squared distances with clear gaps between them. + all_distances_sq = np.array([0, 1.0, 1.0, 4.0, 4.0, 9.0]) + + # Act + # Call the private helper method directly. + shells = lattice._identify_distance_shells(all_distances_sq, max_k=10) + + # Assert + # The method should identify the unique, non-zero distances. + np.testing.assert_allclose(shells, [1.0, 4.0, 9.0]) + + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_identify_distance_shells_with_max_k_limit( + self, backend, simple_lattice_for_helpers + ): + """ + Tests that _identify_distance_shells respects the max_k parameter, + limiting the number of returned shells. + """ + + # Arrange + lattice = simple_lattice_for_helpers + all_distances_sq = np.array([0, 1.0, 4.0, 9.0, 16.0, 25.0]) + max_k = 3 # We only want the first 3 shells. + + # Act + shells = lattice._identify_distance_shells(all_distances_sq, max_k=max_k) + + # Assert + # The number of shells should be limited by max_k. + assert len(shells) == max_k + # The returned shells should be the first `max_k` smallest distances. + np.testing.assert_allclose(shells, [1.0, 4.0, 9.0]) + + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_identify_distance_shells_with_tolerance_merging( + self, backend, simple_lattice_for_helpers + ): + """ + Tests that distances that are very close together are merged into a + single shell when the tolerance `tol` is large enough. + """ + + # Arrange + lattice = simple_lattice_for_helpers + # Two distances are very close: 1.0 and 1.000001 + all_distances_sq = np.array([0, 1.0, 1.000001, 4.0]) + # A tolerance larger than the difference between the two close distances. 
+ tol = 1e-5 + + # Act + shells = lattice._identify_distance_shells(all_distances_sq, tol=tol, max_k=10) + + # Assert + # Because of the tolerance, 1.0 and 1.000001 should be considered the same shell. + # The method should return the first distance of the merged group (1.0). + np.testing.assert_allclose(shells, [1.0, 4.0]) + + def test_identify_distance_shells_with_tolerance_separation( + self, simple_lattice_for_helpers + ): + """ + Tests that very close distances are correctly identified as separate + shells when the tolerance `tol` is small enough. + """ + # Arrange + lattice = simple_lattice_for_helpers + all_distances_sq = np.array([0, 1.0, 1.000001, 4.0]) + # A tolerance smaller than the difference. + tol = 1e-7 + + # Act + shells = lattice._identify_distance_shells(all_distances_sq, tol=tol, max_k=10) + + # Assert + # With a small tolerance, the two close distances should be treated as distinct shells. + np.testing.assert_allclose(shells, [1.0, 1.000001, 4.0]) + + def test_identify_distance_shells_with_empty_and_zero_input( + self, simple_lattice_for_helpers + ): + """ + Tests that the method handles edge cases like empty arrays or arrays + containing only zeros, returning an empty list. + """ + # Arrange + lattice = simple_lattice_for_helpers + + # Act & Assert for empty input + shells_empty = lattice._identify_distance_shells(np.array([]), max_k=10) + assert ( + shells_empty == [] + ), "Should return an empty list for an empty distance array." + + # Act & Assert for zero-only input + shells_zero = lattice._identify_distance_shells(np.array([0, 0, 0]), max_k=10) + assert ( + shells_zero == [] + ), "Should return an empty list for a distance array with only zeros." + + def test_get_distance_matrix_with_mic_vectorized(self): + """ + Tests the internal _get_distance_matrix_with_mic_vectorized method for TILattice + to ensure it correctly applies the Minimum Image Convention. 
+ """ + # --- Test Case 1: Fully Periodic Boundary Conditions (PBC) --- + lattice_pbc = SquareLattice(size=(3, 3), pbc=True, lattice_constant=1.0) + # We need to use the numpy backend for direct comparison + dist_matrix_pbc = tc.backend.numpy( + lattice_pbc._get_distance_matrix_with_mic_vectorized() + ) + + # For a 3x3 PBC lattice, the distance between opposite edges should be 1. + # Example: site (0,0) and site (2,0) + idx1 = lattice_pbc.get_index((0, 0, 0)) + idx2 = lattice_pbc.get_index((2, 0, 0)) + np.testing.assert_allclose( + dist_matrix_pbc[idx1, idx2], + 1.0, + err_msg="Distance check failed for x-direction in PBC case.", + ) + + # Example: site (0,0) and site (0,2) + idx3 = lattice_pbc.get_index((0, 2, 0)) + np.testing.assert_allclose( + dist_matrix_pbc[idx1, idx3], + 1.0, + err_msg="Distance check failed for y-direction in PBC case.", + ) + + # Example: corner-to-corner distance, e.g., (0,0) to (2,2) + # dx = min(|0-2|, 3-|0-2|) = 1; dy = min(|0-2|, 3-|0-2|) = 1 + # distance = sqrt(1^2 + 1^2) = sqrt(2) + idx4 = lattice_pbc.get_index((2, 2, 0)) + np.testing.assert_allclose( + dist_matrix_pbc[idx1, idx4], + np.sqrt(2.0), + err_msg="Diagonal distance check failed in PBC case.", + ) + + # --- Test Case 2: Mixed Boundary Conditions --- + lattice_mixed = SquareLattice( + size=(3, 3), pbc=(True, False), lattice_constant=1.0 + ) + dist_matrix_mixed = tc.backend.numpy( + lattice_mixed._get_distance_matrix_with_mic_vectorized() + ) + + # In the periodic x-direction, distance should be 1. + np.testing.assert_allclose( + dist_matrix_mixed[idx1, idx2], + 1.0, + err_msg="Distance check failed for periodic x-direction in mixed BC case.", + ) + + # In the open y-direction, distance should be 2. + np.testing.assert_allclose( + dist_matrix_mixed[idx1, idx3], + 2.0, + err_msg="Distance check failed for open y-direction in mixed BC case.", + )