diff --git a/.gitignore b/.gitignore index 9cdf3d46..fb713641 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,4 @@ docs/source/locale/zh/LC_MESSAGES/textbook.po docs/source/locale/zh/LC_MESSAGES/whitepapertoc_cn.po docs/source/locale/zh/LC_MESSAGES/textbooktoc.po test.qasm +venv/ diff --git a/CHANGELOG.md b/CHANGELOG.md index f6c5c01a..4ef22ad8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Change Log ## Unreleased +- Add `Lattice` module (`tensorcircuit.templates.lattice`) for creating and manipulating various lattice geometries, including `SquareLattice`, `HoneycombLattice`, and `CustomizeLattice`. ## v1.2.1 diff --git a/examples/lattice_neighbor_benchmark.py b/examples/lattice_neighbor_benchmark.py new file mode 100644 index 00000000..c380231c --- /dev/null +++ b/examples/lattice_neighbor_benchmark.py @@ -0,0 +1,98 @@ +""" +An example script to benchmark neighbor-finding algorithms in CustomizeLattice. + +This script demonstrates the performance difference between the KDTree-based +neighbor search and a baseline all-to-all distance matrix method. +As shown by the results, the KDTree approach offers a significant speedup, +especially when calculating for a large number of neighbor shells (large max_k). + +To run this script from the project's root directory: + python examples/templates/lattice_neighbor_benchmark.py +""" + +import timeit +from typing import Any, Dict, List + + +def run_benchmark() -> None: + """ + Executes the benchmark test and prints the results in a formatted table. + """ + # --- Benchmark Parameters --- + # A list of lattice sizes (N = number of sites) to test + site_counts: List[int] = [10, 50, 100, 200, 500, 1000, 1500, 2000] + + # Use a large k to better showcase the performance of KDTree in + # finding multiple neighbor shells, as suggested by the maintainer. + max_k: int = 2000 + + # Reduce the number of runs to keep the total benchmark time reasonable, + # especially with a large max_k. + number_of_runs: int = 3 + # -------------------------- + + results: List[Dict[str, Any]] = [] + + print("=" * 75) + print("Starting neighbor finding benchmark for CustomizeLattice...") + print(f"Parameters: max_k={max_k}, number_of_runs={number_of_runs}") + print("=" * 75) + print( + f"{'Sites (N)':>10} | {'KDTree Time (s)':>18} | {'Baseline Time (s)':>20} | {'Speedup':>10}" + ) + print("-" * 75) + + for n_sites in site_counts: + # Prepare the setup code for timeit. + # This code generates a random lattice and is executed before timing begins. + # We use a fixed seed to ensure the coordinates are the same for both tests. + setup_code = f""" +import numpy as np +from tensorcircuit.templates.lattice import CustomizeLattice + +np.random.seed(42) +coords = np.random.rand({n_sites}, 2) +ids = list(range({n_sites})) +lat = CustomizeLattice(dimensionality=2, identifiers=ids, coordinates=coords) +""" + # Define the Python statements to be timed. + stmt_kdtree = f"lat._build_neighbors(max_k={max_k})" + stmt_baseline = f"lat._build_neighbors_by_distance_matrix(max_k={max_k})" + + try: + # Execute the timing. timeit returns the total time for all runs. + time_kdtree = ( + timeit.timeit(stmt=stmt_kdtree, setup=setup_code, number=number_of_runs) + / number_of_runs + ) + time_baseline = ( + timeit.timeit( + stmt=stmt_baseline, setup=setup_code, number=number_of_runs + ) + / number_of_runs + ) + + # Calculate and store results, handling potential division by zero. + speedup = time_baseline / time_kdtree if time_kdtree > 0 else float("inf") + results.append( + { + "n_sites": n_sites, + "time_kdtree": time_kdtree, + "time_baseline": time_baseline, + "speedup": speedup, + } + ) + print( + f"{n_sites:>10} | {time_kdtree:>18.6f} | {time_baseline:>20.6f} | {speedup:>9.2f}x" + ) + + except Exception as e: + print(f"An error occurred at N={n_sites}: {e}") + break + + print("-" * 75) + print("Benchmark complete.") + + +if __name__ == "__main__": + run_benchmark() diff --git a/tensorcircuit/templates/__init__.py b/tensorcircuit/templates/__init__.py index be2236ab..b250d435 100644 --- a/tensorcircuit/templates/__init__.py +++ b/tensorcircuit/templates/__init__.py @@ -5,5 +5,6 @@ from . import graphs from . import measurements from . import conversions +from . import lattice costfunctions = measurements diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py new file mode 100644 index 00000000..d4e3e541 --- /dev/null +++ b/tensorcircuit/templates/lattice.py @@ -0,0 +1,1448 @@ +# -*- coding: utf-8 -*- +""" +The lattice module for defining and manipulating lattice geometries. +""" +import logging +import abc +from typing import ( + Any, + Dict, + Hashable, + Iterator, + List, + Optional, + Tuple, + Union, + TYPE_CHECKING, + cast, +) + +logger = logging.getLogger(__name__) +import numpy as np + +from scipy.spatial import KDTree +from scipy.spatial.distance import pdist, squareform + + +# This block resolves a name resolution issue for the static type checker (mypy). +# GOAL: +# Keep `matplotlib` as an optional dependency, so it is only imported +# inside the `show()` method, not at the module level. +# PROBLEM: +# The type hint for the `ax` parameter in `show()`'s signature +# (`ax: Optional["matplotlib.axes.Axes"]`) needs to know what `matplotlib` is. +# Without this block, mypy would raise a "Name 'matplotlib' is not defined" error. +# SOLUTION: +# The `if TYPE_CHECKING:` block is ignored at runtime but processed by mypy. +# This makes the name `matplotlib` available to the type checker without +# creating a hard dependency for the user. +if TYPE_CHECKING: + import matplotlib.axes + from mpl_toolkits.mplot3d import Axes3D + +SiteIndex = int +SiteIdentifier = Hashable +Coordinates = np.ndarray[Any, Any] +NeighborMap = Dict[SiteIndex, List[SiteIndex]] + + +class AbstractLattice(abc.ABC): + """Abstract base class for describing lattice systems. + + This class defines the common interface for all lattice structures, + providing access to fundamental properties like site information + (count, coordinates, identifiers) and neighbor relationships. + Subclasses are responsible for implementing the specific logic for + generating the lattice points and calculating neighbor connections. + + :param dimensionality: The spatial dimension of the lattice (e.g., 1, 2, 3). + :type dimensionality: int + """ + + def __init__(self, dimensionality: int): + """Initializes the base lattice class.""" + self._dimensionality = dimensionality + + # --- Internal Data Structures (to be populated by subclasses) --- + self._indices: List[SiteIndex] = [] + self._identifiers: List[SiteIdentifier] = [] + self._coordinates: List[Coordinates] = [] + self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = {} + self._neighbor_maps: Dict[int, NeighborMap] = {} + self._distance_matrix: Optional[Coordinates] = None + + @property + def num_sites(self) -> int: + """Returns the total number of sites (N) in the lattice.""" + return len(self._indices) + + @property + def dimensionality(self) -> int: + """Returns the spatial dimension of the lattice.""" + return self._dimensionality + + def __len__(self) -> int: + """Returns the total number of sites, enabling `len(lattice)`.""" + return self.num_sites + + # --- Public API for Accessing Lattice Information --- + @property + def distance_matrix(self) -> Coordinates: + """ + Returns the full N x N distance matrix. + The matrix is computed on the first access and then cached for + subsequent calls. This computation can be expensive for large lattices. + """ + if self._distance_matrix is None: + logger.info("Distance matrix not cached. Computing now...") + self._distance_matrix = self._compute_distance_matrix() + return self._distance_matrix + + def _validate_index(self, index: SiteIndex) -> None: + """A private helper to check if a site index is within the valid range.""" + if not (0 <= index < self.num_sites): + raise IndexError( + f"Site index {index} out of range (0-{self.num_sites - 1})" + ) + + def get_coordinates(self, index: SiteIndex) -> Coordinates: + """Gets the spatial coordinates of a site by its integer index. + + :param index: The integer index of the site. + :type index: SiteIndex + :raises IndexError: If the site index is out of range. + :return: The spatial coordinates as a NumPy array. + :rtype: Coordinates + """ + self._validate_index(index) + return self._coordinates[index] + + def get_identifier(self, index: SiteIndex) -> SiteIdentifier: + """Gets the abstract identifier of a site by its integer index. + + :param index: The integer index of the site. + :type index: SiteIndex + :raises IndexError: If the site index is out of range. + :return: The unique, hashable identifier of the site. + :rtype: SiteIdentifier + """ + self._validate_index(index) + return self._identifiers[index] + + def get_index(self, identifier: SiteIdentifier) -> SiteIndex: + """Gets the integer index of a site by its unique identifier. + + :param identifier: The unique identifier of the site. + :type identifier: SiteIdentifier + :raises ValueError: If the identifier is not found in the lattice. + :return: The corresponding integer index of the site. + :rtype: SiteIndex + """ + try: + return self._ident_to_idx[identifier] + except KeyError as e: + raise ValueError( + f"Identifier {identifier} not found in the lattice." + ) from e + + def get_site_info( + self, index_or_identifier: Union[SiteIndex, SiteIdentifier] + ) -> Tuple[SiteIndex, SiteIdentifier, Coordinates]: + """Gets all information for a single site. + + This method provides a convenient way to retrieve all relevant data for a + site (its index, identifier, and coordinates) by using either its + integer index or its unique identifier. + + :param index_or_identifier: The integer + index or the unique identifier of the site to look up. + :type index_or_identifier: Union[SiteIndex, SiteIdentifier] + :raises IndexError: If the given index is out of bounds. + :raises ValueError: If the given identifier is not found in the lattice. + :return: A tuple containing: + - The site's integer index. + - The site's unique identifier. + - The site's coordinates as a NumPy array. + :rtype: Tuple[SiteIndex, SiteIdentifier, Coordinates] + """ + if isinstance(index_or_identifier, int): # SiteIndex is an int + idx = index_or_identifier + self._validate_index(idx) + return idx, self._identifiers[idx], self._coordinates[idx] + else: # Identifier + ident = index_or_identifier + idx = self.get_index(ident) + return idx, ident, self._coordinates[idx] + + def sites(self) -> Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]: + """Returns an iterator over all sites in the lattice. + + This provides a convenient way to loop through all sites, for example: + `for idx, ident, coords in my_lattice.sites(): ...` + + :return: An iterator where each item is a tuple containing the site's + index, identifier, and coordinates. + :rtype: Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]] + """ + for i in range(self.num_sites): + yield i, self._identifiers[i], self._coordinates[i] + + def get_neighbors(self, index: SiteIndex, k: int = 1) -> List[SiteIndex]: + """Gets the list of k-th nearest neighbor indices for a given site. + + :param index: The integer index of the center site. + :type index: SiteIndex + :param k: The order of the neighbors, where k=1 corresponds + to nearest neighbors (NN), k=2 to next-nearest neighbors (NNN), + and so on. Defaults to 1. + :type k: int, optional + :return: A list of integer indices for the neighboring sites. + Returns an empty list if neighbors for the given `k` have not been + pre-calculated or if the site has no such neighbors. + :rtype: List[SiteIndex] + """ + if k not in self._neighbor_maps: + logger.info( + f"Neighbors for k={k} not pre-computed. Building now up to max_k={k}." + ) + self._build_neighbors(max_k=k) + + if k not in self._neighbor_maps: + return [] + + return self._neighbor_maps[k].get(index, []) + + def get_neighbor_pairs( + self, k: int = 1, unique: bool = True + ) -> List[Tuple[SiteIndex, SiteIndex]]: + """Gets all pairs of k-th nearest neighbors, representing bonds. + + :param k: The order of the neighbors to consider. + Defaults to 1. + :type k: int, optional + :param unique: If True, returns only one representation + for each pair (i, j) such that i < j, avoiding duplicates + like (j, i). If False, returns all directed pairs. + Defaults to True. + :type unique: bool, optional + :return: A list of tuples, where each + tuple is a pair of neighbor indices. + :rtype: List[Tuple[SiteIndex, SiteIndex]] + """ + + if k not in self._neighbor_maps: + logger.info( + f"Neighbor pairs for k={k} not pre-computed. Building now up to max_k={k}." + ) + self._build_neighbors(max_k=k) + + # After attempting to build, check again. If still not found, return empty. + if k not in self._neighbor_maps: + return [] + + pairs = [] + for i, neighbors in self._neighbor_maps[k].items(): + for j in neighbors: + if unique: + if i < j: + pairs.append((i, j)) + else: + pairs.append((i, j)) + return sorted(pairs) + + # Sorting provides a deterministic output order + # --- Abstract Methods for Subclass Implementation --- + + @abc.abstractmethod + def _build_lattice(self, *args: Any, **kwargs: Any) -> None: + """ + Abstract method for subclasses to generate the lattice data. + + A concrete implementation of this method in a subclass is responsible + for populating the following internal attributes: + - self._indices + - self._identifiers + - self._coordinates + - self._ident_to_idx + """ + pass + + @abc.abstractmethod + def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: + """ + Abstract method for subclasses to calculate neighbor relationships. + + A concrete implementation of this method should calculate the neighbor + relationships up to `max_k` and populate the `self._neighbor_maps` + dictionary. The keys of the dictionary should be the neighbor order (k), + and the values should be a dictionary mapping site indices to their + list of k-th neighbors. + """ + pass + + @abc.abstractmethod + def _compute_distance_matrix(self) -> Coordinates: + """ + Abstract method for subclasses to implement the actual matrix calculation. + This method is called by the `distance_matrix` property when the matrix + needs to be computed for the first time. + """ + pass + + def show( + self, + show_indices: bool = False, + show_identifiers: bool = False, + show_bonds_k: Optional[int] = None, + ax: Optional["matplotlib.axes.Axes"] = None, + bond_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + """Visualizes the lattice structure using Matplotlib. + + This method supports 1D, 2D, and 3D plotting. For 1D lattices, sites + are plotted along the x-axis. + + :param show_indices: If True, displays the integer index + next to each site. Defaults to False. + :type show_indices: bool, optional + :param show_identifiers: If True, displays the unique + identifier next to each site. Defaults to False. + :type show_identifiers: bool, optional + :param show_bonds_k: Specifies which order of + neighbor bonds to draw (e.g., 1 for NN, 2 for NNN). If None, + no bonds are drawn. If the specified neighbors have not been + calculated, a warning is printed. Defaults to None. + :type show_bonds_k: Optional[int], optional + :param ax: An existing Matplotlib Axes object to plot on. + If None, a new Figure and Axes are created automatically. Defaults to None. + :type ax: Optional["matplotlib.axes.Axes"], optional + :param bond_kwargs: A dictionary of keyword arguments for customizing bond appearance, + passed directly to the Matplotlib plot function. Defaults to None. + :type bond_kwargs: Optional[Dict[str, Any]], optional + + :param kwargs: Additional keyword arguments to be passed directly to the + `matplotlib.pyplot.scatter` function for customizing site appearance. + """ + try: + import matplotlib.pyplot as plt + except ImportError: + logger.error( + "Matplotlib is required for visualization. " + "Please install it using 'pip install matplotlib'." + ) + return + + # creat "fig_created_internally" as flag + fig_created_internally = False + + if self.num_sites == 0: + logger.info("Lattice is empty, nothing to show.") + return + if self.dimensionality not in [1, 2, 3]: + logger.warning( + f"show() is not implemented for {self.dimensionality}D lattices." + ) + return + + if ax is None: + # when ax is none, make fig_created_internally true + fig_created_internally = True + if self.dimensionality == 3: + fig = plt.figure(figsize=(8, 8)) + ax = fig.add_subplot(111, projection="3d") + else: + fig, ax = plt.subplots(figsize=(8, 8)) + else: + fig = ax.figure # type: ignore + + coords = np.array(self._coordinates) + scatter_args = {"s": 100, "zorder": 2} + scatter_args.update(kwargs) + if self.dimensionality == 1: + ax.scatter(coords[:, 0], np.zeros_like(coords[:, 0]), **scatter_args) # type: ignore + elif self.dimensionality == 2: + ax.scatter(coords[:, 0], coords[:, 1], **scatter_args) # type: ignore + elif self.dimensionality > 2: # Safely handle 3D and future higher dimensions + scatter_args["s"] = scatter_args.get("s", 100) // 2 + ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], **scatter_args) # type: ignore + + if show_indices or show_identifiers: + for i in range(self.num_sites): + label = str(self._identifiers[i]) if show_identifiers else str(i) + offset = ( + 0.02 * np.max(np.ptp(coords, axis=0)) if coords.size > 0 else 0.1 + ) + + # Robust Logic: Decide plotting strategy based on known dimensionality. + + if self.dimensionality == 1: + ax.text(coords[i, 0], offset, label, fontsize=9, ha="center") + elif self.dimensionality == 2: + ax.text( + coords[i, 0] + offset, + coords[i, 1] + offset, + label, + fontsize=9, + zorder=3, + ) + elif self.dimensionality == 3: + ax_3d = cast("Axes3D", ax) + ax_3d.text( + coords[i, 0], + coords[i, 1], + coords[i, 2] + offset, + label, + fontsize=9, + zorder=3, + ) + + # Note: No 'else' needed as we already check dimensionality at the start. + + if show_bonds_k is not None: + if show_bonds_k not in self._neighbor_maps: + logger.warning( + f"Cannot draw bonds. k={show_bonds_k} neighbors have not been calculated." + ) + else: + try: + bonds = self.get_neighbor_pairs(k=show_bonds_k, unique=True) + plot_bond_kwargs = { + "color": "k", + "linestyle": "-", + "alpha": 0.6, + "zorder": 1, + } + if bond_kwargs: + plot_bond_kwargs.update(bond_kwargs) + + if self.dimensionality > 2: + ax_3d = cast("Axes3D", ax) + for i, j in bonds: + p1, p2 = self._coordinates[i], self._coordinates[j] + ax_3d.plot( + [p1[0], p2[0]], + [p1[1], p2[1]], + [p1[2], p2[2]], + **plot_bond_kwargs, + ) + else: + for i, j in bonds: + p1, p2 = self._coordinates[i], self._coordinates[j] + if self.dimensionality == 1: # type: ignore + + ax.plot([p1[0], p2[0]], [0, 0], **plot_bond_kwargs) # type: ignore + else: # dimensionality == 2 + ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **plot_bond_kwargs) # type: ignore + + except ValueError as e: + logger.info(f"Could not draw bonds: {e}") + + ax.set_title(f"{self.__class__.__name__} ({self.num_sites} sites)") + if self.dimensionality == 2: + ax.set_aspect("equal", adjustable="box") + ax.set_xlabel("x") + if self.dimensionality > 1: + ax.set_ylabel("y") + if self.dimensionality > 2 and hasattr(ax, "set_zlabel"): + ax.set_zlabel("z") + ax.grid(True) + + # 3. whether plt.show() + if fig_created_internally: + plt.show() + + def _identify_distance_shells( + self, + all_distances_sq: Union[Coordinates, List[float]], + max_k: int, + tol: float = 1e-6, + ) -> List[float]: + """Identifies unique distance shells from a list of squared distances. + + This helper function takes a flat list of squared distances, sorts them, + and identifies the first `max_k` unique distance shells based on a + numerical tolerance. + + :param all_distances_sq: A list or array + of all squared distances between pairs of sites. + :type all_distances_sq: Union[np.ndarray, List[float]] + :param max_k: The maximum number of neighbor shells to identify. + :type max_k: int + :param tol: The numerical tolerance to consider two distances equal. + :type tol: float + :return: A sorted list of squared distances representing the shells. + :rtype: List[float] + """ + ZERO_THRESHOLD_SQ = 1e-12 + + all_distances_sq = np.asarray(all_distances_sq) + # Now, the .size call below is guaranteed to be safe. + if all_distances_sq.size == 0: + return [] + + sorted_dist = np.sort(all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ]) + + if sorted_dist.size == 0: + return [] + + # Identify shells using the user-provided tolerance. + dist_shells = [sorted_dist[0]] + + for d_sq in sorted_dist[1:]: + if len(dist_shells) >= max_k: + break + # If the current distance is notably larger than the last shell's distance + if d_sq > dist_shells[-1] + tol**2: + dist_shells.append(d_sq) + + return dist_shells + + def _build_neighbors_by_distance_matrix( + self, max_k: int = 2, tol: float = 1e-6 + ) -> None: + """A generic, distance-based neighbor finding method. + + This method calculates the full N x N distance matrix to find neighbor + shells. It is computationally expensive for large N (O(N^2)) and is + best suited for non-periodic or custom-defined lattices. + + :param max_k: The maximum number of neighbor shells to + calculate. Defaults to 2. + :type max_k: int, optional + :param tol: The numerical tolerance for distance + comparisons. Defaults to 1e-6. + :type tol: float, optional + """ + if self.num_sites < 2: + return + + all_coords = np.array(self._coordinates) + dist_matrix_sq = np.sum( + (all_coords[:, np.newaxis, :] - all_coords[np.newaxis, :, :]) ** 2, axis=-1 + ) + + all_distances_sq = dist_matrix_sq.flatten() + dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol) + + self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)} + for k_idx, target_d_sq in enumerate(dist_shells_sq): + k = k_idx + 1 + current_k_map: Dict[int, List[int]] = {} + for i in range(self.num_sites): + neighbor_indices = np.where( + np.isclose(dist_matrix_sq[i], target_d_sq, rtol=0, atol=tol**2) + )[0] + if len(neighbor_indices) > 0: + current_k_map[i] = sorted(neighbor_indices.tolist()) + self._neighbor_maps[k] = current_k_map + self._distance_matrix = np.sqrt(dist_matrix_sq) + + +class TILattice(AbstractLattice): + """Describes a periodic lattice with translational invariance. + + This class serves as a base for any lattice defined by a repeating unit + cell. The geometry is specified by lattice vectors, the coordinates of + basis sites within a unit cell, and the total size of the lattice in + terms of unit cells. + + The site identifier for this class is a tuple in the format of + `(uc_coord_1, ..., uc_coord_d, basis_index)`, where `uc_coord` represents + the integer coordinate of the unit cell and `basis_index` is the index + of the site within that unit cell's basis. + + :param dimensionality: The spatial dimension of the lattice. + :type dimensionality: int + :param lattice_vectors: The lattice vectors defining the unit + cell, given as row vectors. Shape: (dimensionality, dimensionality). + For example, in 2D: `np.array([[ax, ay], [bx, by]])`. + :type lattice_vectors: np.ndarray + :param basis_coords: The Cartesian coordinates of the basis sites + within the unit cell. Shape: (num_basis_sites, dimensionality). + For a simple Bravais lattice, this would be `np.array([[0, 0]])`. + :type basis_coords: np.ndarray + :param size: A tuple specifying the number of unit cells + to generate in each lattice vector direction (e.g., (Nx, Ny)). + :type size: Tuple[int, ...] + :param pbc: Specifies whether + periodic boundary conditions are applied. Can be a single boolean + for all dimensions or a tuple of booleans for each dimension + individually. Defaults to True. + :type pbc: Union[bool, Tuple[bool, ...]], optional + :param precompute_neighbors: If specified, pre-computes neighbor relationships + up to the given order `k` upon initialization. Defaults to None. + :type precompute_neighbors: Optional[int], optional + + """ + + def __init__( + self, + dimensionality: int, + lattice_vectors: Coordinates, + basis_coords: Coordinates, + size: Tuple[int, ...], + pbc: Union[bool, Tuple[bool, ...]] = True, + precompute_neighbors: Optional[int] = None, + ): + """Initializes the Translationally Invariant Lattice.""" + super().__init__(dimensionality) + assert lattice_vectors.shape == ( + dimensionality, + dimensionality, + ), "Lattice vectors shape mismatch" + assert ( + basis_coords.shape[1] == dimensionality + ), "Basis coordinates dimension mismatch" + assert len(size) == dimensionality, "Size tuple length mismatch" + + self.lattice_vectors = lattice_vectors + self.basis_coords = basis_coords + self.num_basis = basis_coords.shape[0] + self.size = size + if isinstance(pbc, bool): + self.pbc = tuple([pbc] * dimensionality) + else: + assert len(pbc) == dimensionality, "PBC tuple length mismatch" + self.pbc = tuple(pbc) + + # Build the lattice sites and their neighbor relationships + self._build_lattice() + if precompute_neighbors is not None and precompute_neighbors > 0: + logger.info(f"Pre-computing neighbors up to k={precompute_neighbors}...") + self._build_neighbors(max_k=precompute_neighbors) + + def _build_lattice(self) -> None: + """Generates all site information for the periodic lattice. + + This method iterates through each unit cell defined by `self.size`, + and for each unit cell, it iterates through all basis sites. It then + calculates the real-space coordinates and creates a unique identifier + for each site, populating the internal lattice data structures. + """ + current_index = 0 + + # Iterate over all unit cell coordinates elegantly using np.ndindex + for cell_coord in np.ndindex(self.size): + cell_coord_arr = np.array(cell_coord) + # R = n1*a1 + n2*a2 + ... + cell_vector = np.dot(cell_coord_arr, self.lattice_vectors) + + # Iterate over the basis sites within the unit cell + for basis_index in range(self.num_basis): + basis_vec = self.basis_coords[basis_index] + + # Calculate the real-space coordinate + coord = cell_vector + basis_vec + # Create a structured identifier + identifier = cell_coord + (basis_index,) + + # Store site information + self._indices.append(current_index) + self._identifiers.append(identifier) + self._coordinates.append(coord) + self._ident_to_idx[identifier] = current_index + current_index += 1 + + def _get_distance_matrix_with_mic(self) -> Coordinates: + """ + Computes the full N x N distance matrix, correctly applying the + Minimum Image Convention (MIC) for all periodic dimensions. + """ + all_coords = np.array(self._coordinates) + size_arr = np.array(self.size) + system_vectors = self.lattice_vectors * size_arr[:, np.newaxis] + + # Generate translation vectors ONLY for periodic dimensions + pbc_dims = [d for d in range(self.dimensionality) if self.pbc[d]] + translations = [np.zeros(self.dimensionality)] + if pbc_dims: + num_pbc_dims = len(pbc_dims) + pbc_system_vectors = system_vectors[pbc_dims, :] + + # Create all 3^k - 1 non-zero shifts for k periodic dimensions + shift_options = [np.array([-1, 0, 1])] * num_pbc_dims + shifts_grid = np.meshgrid(*shift_options, indexing="ij") + all_shifts = np.stack(shifts_grid, axis=-1).reshape(-1, num_pbc_dims) + all_shifts = all_shifts[np.any(all_shifts != 0, axis=1)] + + pbc_translations = all_shifts @ pbc_system_vectors + translations.extend(pbc_translations) + + translations_arr = np.array(translations, dtype=float) + + # Calculate the distance matrix applying MIC + dist_matrix_sq = np.full((self.num_sites, self.num_sites), np.inf, dtype=float) + for i in range(self.num_sites): + displacements = all_coords - all_coords[i] + image_displacements = ( + displacements[:, np.newaxis, :] - translations_arr[np.newaxis, :, :] + ) + image_d_sq = np.sum(image_displacements**2, axis=2) + dist_matrix_sq[i, :] = np.min(image_d_sq, axis=1) + + return cast(Coordinates, np.sqrt(dist_matrix_sq)) + + def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None: + """Calculates neighbor relationships for the periodic lattice. + + This method calculates neighbor relationships by computing the full N x N + distance matrix. It robustly handles all boundary conditions (fully + periodic, open, or mixed) by applying the Minimum Image Convention + (MIC) only to the periodic dimensions. + + From this distance matrix, it identifies unique neighbor shells up to + the specified `max_k` and populates the neighbor maps. The computed + distance matrix is then cached for future use. + + :param max_k: The maximum number of neighbor shells to + calculate. Defaults to 2. + :type max_k: int, optional + :param tol: The numerical tolerance for distance + comparisons. Defaults to 1e-6. + :type tol: float, optional + """ + tol = kwargs.get("tol", 1e-6) + dist_matrix = self._get_distance_matrix_with_mic() + dist_matrix_sq = dist_matrix**2 + self._distance_matrix = dist_matrix + all_distances_sq = dist_matrix_sq.flatten() + dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol) + + self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)} + for k_idx, target_d_sq in enumerate(dist_shells_sq): + k = k_idx + 1 + current_k_map: Dict[int, List[int]] = {} + match_indices = np.where( + np.isclose(dist_matrix_sq, target_d_sq, rtol=0, atol=tol**2) + ) + for i, j in zip(*match_indices): + if i == j: + continue + if i not in current_k_map: + current_k_map[i] = [] + current_k_map[i].append(j) + + for i in current_k_map: + current_k_map[i].sort() + + self._neighbor_maps[k] = current_k_map + + def _compute_distance_matrix(self) -> Coordinates: + """Computes the distance matrix using the Minimum Image Convention.""" + return self._get_distance_matrix_with_mic() + + +class SquareLattice(TILattice): + """A 2D square lattice. + + This is a concrete implementation of a translationally invariant lattice + representing a simple square grid. It is a Bravais lattice with a + single-site basis. + + :param size: A tuple (Nx, Ny) specifying the number of + unit cells (sites) in the x and y directions. + :type size: Tuple[int, int] + :param lattice_constant: The distance between two adjacent + sites. Defaults to 1.0. + :type lattice_constant: float, optional + :param pbc: Specifies periodic boundary conditions. Can be a single boolean + for all dimensions or a tuple of booleans for each dimension + individually. Defaults to True. + :type pbc: Union[bool, Tuple[bool, bool]], optional + :param precompute_neighbors: If specified, pre-computes neighbor relationships + up to the given order `k` upon initialization. Defaults to None. + :type precompute_neighbors: Optional[int], optional + """ + + def __init__( + self, + size: Tuple[int, int], + lattice_constant: float = 1.0, + pbc: Union[bool, Tuple[bool, bool]] = True, + precompute_neighbors: Optional[int] = None, + ): + """Initializes the SquareLattice.""" + dimensionality = 2 + + # Define lattice vectors for a square lattice + lattice_vectors = np.array([[lattice_constant, 0.0], [0.0, lattice_constant]]) + + # A square lattice has a single site in its basis + basis_coords = np.array([[0.0, 0.0]]) + + # Call the parent TILattice constructor with these parameters + super().__init__( + dimensionality=dimensionality, + lattice_vectors=lattice_vectors, + basis_coords=basis_coords, + size=size, + pbc=pbc, + precompute_neighbors=precompute_neighbors, + ) + + +class HoneycombLattice(TILattice): + """A 2D honeycomb lattice. + + This is a classic example of a composite lattice. It consists of a + two-site basis (sublattices A and B) on an underlying triangular + Bravais lattice. + + :param size: A tuple (Nx, Ny) specifying the number of unit + cells along the two lattice vector directions. + :type size: Tuple[int, int] + :param lattice_constant: The bond length, i.e., the distance + between two nearest neighbor sites. Defaults to 1.0. + :type lattice_constant: float, optional + :param pbc: Specifies periodic + boundary conditions. Defaults to True. + :type pbc: Union[bool, Tuple[bool, bool]], optional + :param precompute_neighbors: If specified, pre-computes neighbor relationships + up to the given order `k` upon initialization. Defaults to None. + :type precompute_neighbors: Optional[int], optional + + """ + + def __init__( + self, + size: Tuple[int, int], + lattice_constant: float = 1.0, + pbc: Union[bool, Tuple[bool, bool]] = True, + precompute_neighbors: Optional[int] = None, + ): + """Initializes the HoneycombLattice.""" + dimensionality = 2 + a = lattice_constant + + # Define the primitive lattice vectors for the underlying triangular lattice + lattice_vectors = a * np.array([[1.5, np.sqrt(3) / 2], [1.5, -np.sqrt(3) / 2]]) + + # Define the coordinates of the two basis sites (A and B) + basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0]]) # Site A # Site B + + super().__init__( + dimensionality=dimensionality, + lattice_vectors=lattice_vectors, + basis_coords=basis_coords, + size=size, + pbc=pbc, + precompute_neighbors=precompute_neighbors, + ) + + +class TriangularLattice(TILattice): + """A 2D triangular lattice. + + This is a Bravais lattice where each site has 6 nearest neighbors. + + :param size: A tuple (Nx, Ny) specifying the number of + unit cells along the two lattice vector directions. + :type size: Tuple[int, int] + :param lattice_constant: The bond length, i.e., the + distance between two nearest neighbor sites. Defaults to 1.0. + :type lattice_constant: float, optional + :param pbc: Specifies periodic + boundary conditions. Defaults to True. + :type pbc: Union[bool, Tuple[bool, bool]], optional + :param precompute_neighbors: If specified, pre-computes neighbor relationships + up to the given order `k` upon initialization. Defaults to None. + :type precompute_neighbors: Optional[int], optional + + """ + + def __init__( + self, + size: Tuple[int, int], + lattice_constant: float = 1.0, + pbc: Union[bool, Tuple[bool, bool]] = True, + precompute_neighbors: Optional[int] = None, + ): + """Initializes the TriangularLattice.""" + dimensionality = 2 + a = lattice_constant + + # Define the primitive lattice vectors for a triangular lattice + lattice_vectors = a * np.array([[1.0, 0.0], [0.5, np.sqrt(3) / 2]]) + + # A triangular lattice is a Bravais lattice, with a single site in its basis + basis_coords = np.array([[0.0, 0.0]]) + + super().__init__( + dimensionality=dimensionality, + lattice_vectors=lattice_vectors, + basis_coords=basis_coords, + size=size, + pbc=pbc, + precompute_neighbors=precompute_neighbors, + ) + + +class ChainLattice(TILattice): + """A 1D chain (simple Bravais lattice). + + :param size: A tuple `(N,)` specifying the number of sites in the chain. + :type size: Tuple[int] + :param lattice_constant: The distance between two adjacent sites. Defaults to 1.0. + :type lattice_constant: float, optional + :param pbc: Specifies if periodic boundary conditions are applied. Defaults to True. + :type pbc: bool, optional + :param precompute_neighbors: If specified, pre-computes neighbor relationships + up to the given order `k` upon initialization. Defaults to None. + :type precompute_neighbors: Optional[int], optional + """ + + def __init__( + self, + size: Tuple[int], + lattice_constant: float = 1.0, + pbc: bool = True, + precompute_neighbors: Optional[int] = None, + ): + dimensionality = 1 + lattice_vectors = np.array([[lattice_constant]]) + basis_coords = np.array([[0.0]]) + super().__init__( + dimensionality=dimensionality, + lattice_vectors=lattice_vectors, + basis_coords=basis_coords, + size=size, + pbc=pbc, + precompute_neighbors=precompute_neighbors, + ) + + +class DimerizedChainLattice(TILattice): + """A 1D chain with an AB sublattice (dimerized chain). + + The unit cell contains two sites, A and B. The bond length is uniform. + + :param size: A tuple `(N,)` specifying the number of **unit cells**. + The total number of sites in the chain will be `2 * N`, as each + unit cell contains two sites. + :type size: Tuple[int] + :param lattice_constant: The distance between two adjacent sites (bond length). Defaults to 1.0. + :type lattice_constant: float, optional + :param pbc: Specifies if periodic boundary conditions are applied. Defaults to True. + :type pbc: bool, optional + :param precompute_neighbors: If specified, pre-computes neighbor relationships + up to the given order `k` upon initialization. Defaults to None. + :type precompute_neighbors: Optional[int], optional + """ + + def __init__( + self, + size: Tuple[int], + lattice_constant: float = 1.0, + pbc: bool = True, + precompute_neighbors: Optional[int] = None, + ): + dimensionality = 1 + # The unit cell vector connects two A sites, spanning length 2*a + lattice_vectors = np.array([[2 * lattice_constant]]) + # Basis has site A at origin, site B at distance 'a' + basis_coords = np.array([[0.0], [lattice_constant]]) + + super().__init__( + dimensionality=dimensionality, + lattice_vectors=lattice_vectors, + basis_coords=basis_coords, + size=size, + pbc=pbc, + precompute_neighbors=precompute_neighbors, + ) + + +class RectangularLattice(TILattice): + """A 2D rectangular lattice. + + This is a generalization of the SquareLattice where the lattice constants + in the x and y directions can be different. + + :param size: A tuple (Nx, Ny) specifying the number of sites in x and y. + :type size: Tuple[int, int] + :param lattice_constants: The distance between adjacent sites + in the x and y directions, e.g., (ax, ay). Defaults to (1.0, 1.0). + :type lattice_constants: Tuple[float, float], optional + :param pbc: Specifies periodic boundary conditions. Defaults to True. + :type pbc: Union[bool, Tuple[bool, bool]], optional + :param precompute_neighbors: If specified, pre-computes neighbor relationships + up to the given order `k` upon initialization. Defaults to None. + :type precompute_neighbors: Optional[int], optional + """ + + def __init__( + self, + size: Tuple[int, int], + lattice_constants: Tuple[float, float] = (1.0, 1.0), + pbc: Union[bool, Tuple[bool, bool]] = True, + precompute_neighbors: Optional[int] = None, + ): + dimensionality = 2 + ax, ay = lattice_constants + lattice_vectors = np.array([[ax, 0.0], [0.0, ay]]) + basis_coords = np.array([[0.0, 0.0]]) + + super().__init__( + dimensionality=dimensionality, + lattice_vectors=lattice_vectors, + basis_coords=basis_coords, + size=size, + pbc=pbc, + precompute_neighbors=precompute_neighbors, + ) + + +class CheckerboardLattice(TILattice): + """A 2D checkerboard lattice (a square lattice with an AB sublattice). + + The unit cell is a square rotated by 45 degrees, containing two sites. + + :param size: A tuple (Nx, Ny) specifying the number of unit cells. Total sites will be 2*Nx*Ny. + :type size: Tuple[int, int] + :param lattice_constant: The bond length between nearest neighbors. Defaults to 1.0. + :type lattice_constant: float, optional + :param pbc: Specifies periodic boundary conditions. Defaults to True. + :type pbc: Union[bool, Tuple[bool, bool]], optional + :param precompute_neighbors: If specified, pre-computes neighbor relationships + up to the given order `k` upon initialization. Defaults to None. + :type precompute_neighbors: Optional[int], optional + """ + + def __init__( + self, + size: Tuple[int, int], + lattice_constant: float = 1.0, + pbc: Union[bool, Tuple[bool, bool]] = True, + precompute_neighbors: Optional[int] = None, + ): + dimensionality = 2 + a = lattice_constant + # Primitive vectors for a square lattice rotated by 45 degrees. + lattice_vectors = a * np.array([[1.0, 1.0], [1.0, -1.0]]) + # Two-site basis + basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0]]) + super().__init__( + dimensionality=dimensionality, + lattice_vectors=lattice_vectors, + basis_coords=basis_coords, + size=size, + pbc=pbc, + precompute_neighbors=precompute_neighbors, + ) + + +class KagomeLattice(TILattice): + """A 2D Kagome lattice. + + This is a lattice with a three-site basis on a triangular Bravais lattice. + + :param size: A tuple (Nx, Ny) specifying the number of unit cells. Total sites will be 3*Nx*Ny. + :type size: Tuple[int, int] + :param lattice_constant: The bond length. Defaults to 1.0. + :type lattice_constant: float, optional + :param pbc: Specifies periodic boundary conditions. Defaults to True. + :type pbc: Union[bool, Tuple[bool, bool]], optional + :param precompute_neighbors: If specified, pre-computes neighbor relationships + up to the given order `k` upon initialization. Defaults to None. + :type precompute_neighbors: Optional[int], optional + """ + + def __init__( + self, + size: Tuple[int, int], + lattice_constant: float = 1.0, + pbc: Union[bool, Tuple[bool, bool]] = True, + precompute_neighbors: Optional[int] = None, + ): + dimensionality = 2 + a = lattice_constant + # Using a rectangular unit cell definition for simplicity + lattice_vectors = a * np.array([[2.0, 0.0], [1.0, np.sqrt(3)]]) + # Three-site basis + basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0], [0.5, np.sqrt(3) / 2.0]]) + super().__init__( + dimensionality=dimensionality, + lattice_vectors=lattice_vectors, + basis_coords=basis_coords, + size=size, + pbc=pbc, + precompute_neighbors=precompute_neighbors, + ) + + +class LiebLattice(TILattice): + """A 2D Lieb lattice. + + This is a lattice with a three-site basis on a square Bravais lattice. + It has sites at the corners and centers of the edges of a square. + + :param size: A tuple (Nx, Ny) specifying the number of unit cells. Total sites will be 3*Nx*Ny. + :type size: Tuple[int, int] + :param lattice_constant: The bond length. Defaults to 1.0. + :type lattice_constant: float, optional + :param pbc: Specifies periodic boundary conditions. Defaults to True. + :type pbc: Union[bool, Tuple[bool, bool]], optional + :param precompute_neighbors: If specified, pre-computes neighbor relationships + up to the given order `k` upon initialization. Defaults to None. + :type precompute_neighbors: Optional[int], optional + """ + + def __init__( + self, + size: Tuple[int, int], + lattice_constant: float = 1.0, + pbc: Union[bool, Tuple[bool, bool]] = True, + precompute_neighbors: Optional[int] = None, + ): + """Initializes the LiebLattice.""" + dimensionality = 2 + # Use a more descriptive name for clarity. In a Lieb lattice, + # the lattice_constant is the bond length between nearest neighbors. + bond_length = lattice_constant + + # The unit cell of a Lieb lattice is a square with side length + # equal to twice the bond length. + unit_cell_side = 2 * bond_length + lattice_vectors = np.array([[unit_cell_side, 0.0], [0.0, unit_cell_side]]) + + # The three-site basis consists of a corner site, a site on the + # center of the horizontal edge, and a site on the center of the vertical edge. + # Their coordinates are defined directly in terms of the physical bond length. + basis_coords = np.array( + [ + [0.0, 0.0], # Corner site + [bond_length, 0.0], # Horizontal edge center + [0.0, bond_length], # Vertical edge center + ] + ) + + super().__init__( + dimensionality=dimensionality, + lattice_vectors=lattice_vectors, + basis_coords=basis_coords, + size=size, + pbc=pbc, + precompute_neighbors=precompute_neighbors, + ) + + +class CubicLattice(TILattice): + """A 3D cubic lattice. + + This is a simple Bravais lattice, the 3D generalization of SquareLattice. + + :param size: A tuple (Nx, Ny, Nz) specifying the number of sites. + :type size: Tuple[int, int, int] + :param lattice_constant: The distance between adjacent sites. Defaults to 1.0. + :type lattice_constant: float, optional + :param pbc: Specifies periodic boundary conditions. Defaults to True. + :type pbc: Union[bool, Tuple[bool, bool, bool]], optional + :param precompute_neighbors: If specified, pre-computes neighbor relationships + up to the given order `k` upon initialization. Defaults to None. + :type precompute_neighbors: Optional[int], optional + """ + + def __init__( + self, + size: Tuple[int, int, int], + lattice_constant: float = 1.0, + pbc: Union[bool, Tuple[bool, bool, bool]] = True, + precompute_neighbors: Optional[int] = None, + ): + dimensionality = 3 + a = lattice_constant + lattice_vectors = np.array([[a, 0, 0], [0, a, 0], [0, 0, a]]) + basis_coords = np.array([[0.0, 0.0, 0.0]]) + super().__init__( + dimensionality=dimensionality, + lattice_vectors=lattice_vectors, + basis_coords=basis_coords, + size=size, + pbc=pbc, + precompute_neighbors=precompute_neighbors, + ) + + +class CustomizeLattice(AbstractLattice): + """A general lattice built from an explicit list of sites and coordinates. + + This class is suitable for creating lattices with arbitrary geometries, + such as finite clusters, disordered systems, or any custom structure + that does not have translational symmetry. The lattice is defined simply + by providing lists of identifiers and coordinates for each site. + + :param dimensionality: The spatial dimension of the lattice. + :type dimensionality: int + :param identifiers: A list of unique, hashable + identifiers for the sites. The length must match `coordinates`. + :type identifiers: List[SiteIdentifier] + :param coordinates: A list of site + coordinates. Each coordinate should be a list of floats or a + NumPy array. + :type coordinates: List[Union[List[float], Coordinates]] + :raises ValueError: If the lengths of `identifiers` and `coordinates` lists + do not match, or if a coordinate's dimension is incorrect. + :param precompute_neighbors: If specified, pre-computes neighbor relationships + up to the given order `k` upon initialization. Defaults to None. + :type precompute_neighbors: Optional[int], optional + + """ + + def __init__( + self, + dimensionality: int, + identifiers: List[SiteIdentifier], + coordinates: List[Union[List[float], Coordinates]], + precompute_neighbors: Optional[int] = None, + ): + """Initializes the CustomizeLattice.""" + super().__init__(dimensionality) + if len(identifiers) != len(coordinates): + raise ValueError( + "Identifiers and coordinates lists must have the same length." + ) + + # The _build_lattice logic is simple enough to be in __init__ + self._identifiers = list(identifiers) + self._coordinates = [np.array(c) for c in coordinates] + self._indices = list(range(len(identifiers))) + self._ident_to_idx = {ident: idx for idx, ident in enumerate(identifiers)} + + # Validate coordinate dimensions + for i, coord in enumerate(self._coordinates): + if coord.shape != (dimensionality,): + raise ValueError( + f"Coordinate at index {i} has shape {coord.shape}, " + f"expected ({dimensionality},)" + ) + + logger.info(f"CustomizeLattice with {self.num_sites} sites created.") + + if precompute_neighbors is not None and precompute_neighbors > 0: + self._build_neighbors(max_k=precompute_neighbors) + + def _build_lattice(self, *args: Any, **kwargs: Any) -> None: + """For CustomizeLattice, lattice data is built during __init__.""" + pass + + def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: + """Calculates neighbors using a KDTree for efficiency. + + This method uses a memory-efficient approach to identify neighbors without + initially computing the full N x N distance matrix. It leverages + `scipy.spatial.distance.pdist` to find unique distance shells and then + a `scipy.spatial.KDTree` for fast radius queries. This approach is + significantly more memory-efficient during the neighbor identification phase. + + After the neighbors are identified, the full distance matrix is computed + from the pairwise distances and cached for potential future use. + + :param max_k: The maximum number of neighbor shells to + calculate. Defaults to 1. + :type max_k: int, optional + :param tol: The numerical tolerance for distance + comparisons. Defaults to 1e-6. + :type tol: float, optional + """ + tol = kwargs.get("tol", 1e-6) + logger.info(f"Building neighbors for CustomizeLattice up to k={max_k}...") + if self.num_sites < 2: + return + + all_coords = np.array(self._coordinates) + + # 1. Use pdist for memory-efficient calculation of pairwise distances + # to robustly identify the distance shells. + all_distances_sq = pdist(all_coords, metric="sqeuclidean") + dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol) + + if not dist_shells_sq: + logger.info("No distinct neighbor shells found.") + return + + # 2. Build the KDTree for efficient querying. + tree = KDTree(all_coords) + self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)} + + # 3. Find neighbors by isolating shells using inclusion-exclusion. + # `found_indices` will store all neighbors within a given radius. + found_indices: List[set[int]] = [] + for k_idx, target_d_sq in enumerate(dist_shells_sq): + radius = np.sqrt(target_d_sq) + tol + # Query for all points within the new, larger radius. + current_shell_indices = tree.query_ball_point( + all_coords, r=radius, return_sorted=True + ) + + # Now, isolate the neighbors for the current shell k + k = k_idx + 1 + current_k_map: Dict[int, List[int]] = {} + for i in range(self.num_sites): + + if k_idx == 0: + co_located_indices = tree.query_ball_point(all_coords[i], r=1e-12) + prev_found = set(co_located_indices) + else: + prev_found = found_indices[i] + + # The new neighbors are those in the current radius shell, + # excluding those already found in smaller shells. + new_neighbors = set(current_shell_indices[i]) - prev_found + + if new_neighbors: + current_k_map[i] = sorted(list(new_neighbors)) + + self._neighbor_maps[k] = current_k_map + found_indices = [ + set(l) for l in current_shell_indices + ] # Update for next iteration + self._distance_matrix = np.sqrt(squareform(all_distances_sq)) + + logger.info("Neighbor building complete using KDTree.") + + def _compute_distance_matrix(self) -> Coordinates: + """Computes the distance matrix from the stored coordinates. + + This implementation uses scipy.pdist for a memory-efficient + calculation of pairwise distances, which is then converted to a + full square matrix. + """ + if self.num_sites < 2: + return cast(Coordinates, np.empty((self.num_sites, self.num_sites))) + + all_coords = np.array(self._coordinates) + # Use pdist for memory-efficiency, then build the full matrix. + all_distances_sq = pdist(all_coords, metric="sqeuclidean") + dist_matrix_sq = squareform(all_distances_sq) + return cast(Coordinates, np.sqrt(dist_matrix_sq)) + + def _reset_computations(self) -> None: + """Resets all cached data that depends on the lattice structure.""" + self._neighbor_maps = {} + self._distance_matrix = None + + @classmethod + def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice": + """Creates a CustomizeLattice instance from any existing lattice object. + + This is useful for 'detaching' a procedurally generated lattice (like + a SquareLattice) into a customizable one for further modifications, + such as adding defects or extra sites. + + :param lattice: An instance of any AbstractLattice subclass. + :type lattice: AbstractLattice + :return: A new CustomizeLattice instance with the same sites. + :rtype: CustomizeLattice + """ + all_sites_info = list(lattice.sites()) + + if not all_sites_info: + return cls( + dimensionality=lattice.dimensionality, identifiers=[], coordinates=[] + ) + + # Unzip the list of tuples into separate lists of identifiers and coordinates + _, identifiers, coordinates = zip(*all_sites_info) + + return cls( + dimensionality=lattice.dimensionality, + identifiers=list(identifiers), + coordinates=list(coordinates), + ) + + def add_sites( + self, + identifiers: List[SiteIdentifier], + coordinates: List[Union[List[float], Coordinates]], + ) -> None: + """Adds new sites to the lattice. + + This operation modifies the lattice in-place. After adding sites, any + previously computed neighbor information is cleared and must be + recalculated. + + :param identifiers: A list of unique, hashable identifiers for the new sites. + :type identifiers: List[SiteIdentifier] + :param coordinates: A list of coordinates for the new sites. + :type coordinates: List[Union[List[float], np.ndarray]] + :raises ValueError: If input lists have mismatched lengths, or if any new + identifier already exists in the lattice. + """ + if len(identifiers) != len(coordinates): + raise ValueError( + "Identifiers and coordinates lists must have the same length." + ) + if not identifiers: + return # Nothing to add + + # Check for duplicate identifiers before making any changes + existing_ids = set(self._identifiers) + new_ids = set(identifiers) + if not new_ids.isdisjoint(existing_ids): + raise ValueError( + f"Duplicate identifiers found: {new_ids.intersection(existing_ids)}" + ) + + for i, coord in enumerate(coordinates): + coord_arr = np.asarray(coord) + if coord_arr.shape != (self.dimensionality,): + raise ValueError( + f"New coordinate at index {i} has shape {coord_arr.shape}, " + f"expected ({self.dimensionality},)" + ) + self._coordinates.append(coord_arr) + self._identifiers.append(identifiers[i]) + + # Rebuild index mappings from scratch + self._indices = list(range(len(self._identifiers))) + self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)} + + # Invalidate any previously computed neighbors or distance matrices + self._reset_computations() + logger.info( + f"{len(identifiers)} sites added. Lattice now has {self.num_sites} sites." + ) + + def remove_sites(self, identifiers: List[SiteIdentifier]) -> None: + """Removes specified sites from the lattice. + + This operation modifies the lattice in-place. After removing sites, + all site indices are re-calculated, and any previously computed + neighbor information is cleared. + + :param identifiers: A list of identifiers for the sites to be removed. + :type identifiers: List[SiteIdentifier] + :raises ValueError: If any of the specified identifiers do not exist. + """ + if not identifiers: + return # Nothing to remove + + ids_to_remove = set(identifiers) + current_ids = set(self._identifiers) + if not ids_to_remove.issubset(current_ids): + raise ValueError( + f"Non-existent identifiers provided for removal: {ids_to_remove - current_ids}" + ) + + # Create new lists containing only the sites to keep + new_identifiers: List[SiteIdentifier] = [] + new_coordinates: List[Coordinates] = [] + for ident, coord in zip(self._identifiers, self._coordinates): + if ident not in ids_to_remove: + new_identifiers.append(ident) + new_coordinates.append(coord) + + # Replace old data with the new, filtered data + self._identifiers = new_identifiers + self._coordinates = new_coordinates + + # Rebuild index mappings + self._indices = list(range(len(self._identifiers))) + self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)} + + # Invalidate caches + self._reset_computations() + logger.info( + f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites." + ) diff --git a/tests/test_lattice.py b/tests/test_lattice.py new file mode 100644 index 00000000..ff5386ef --- /dev/null +++ b/tests/test_lattice.py @@ -0,0 +1,1665 @@ +from unittest.mock import patch +import logging +import time + +import matplotlib + +matplotlib.use("Agg") + + +import pytest +import numpy as np + +from tensorcircuit.templates.lattice import ( + ChainLattice, + CheckerboardLattice, + CubicLattice, + CustomizeLattice, + DimerizedChainLattice, + HoneycombLattice, + KagomeLattice, + LiebLattice, + RectangularLattice, + SquareLattice, + TriangularLattice, +) + + +@pytest.fixture +def simple_square_lattice() -> CustomizeLattice: + """ + Provides a simple 2x2 square CustomizeLattice instance for neighbor tests. + The sites are indexed as follows: + 2--3 + | | + 0--1 + """ + coords = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]] + ids = list(range(len(coords))) + lattice = CustomizeLattice(dimensionality=2, identifiers=ids, coordinates=coords) + # Pre-calculate neighbors up to the 2nd shell for use in tests. + lattice._build_neighbors(max_k=2) + return lattice + + +@pytest.fixture +def kagome_lattice_fragment() -> CustomizeLattice: + """ + Pytest fixture to provide a standard CustomizeLattice instance. + This represents the Kagome fragment from the project requirements, + making it a reusable object for multiple tests. + """ + kag_coords = [ + [0.0, 0.0], + [1.0, 0.0], + [0.5, np.sqrt(3) / 2], # Triangle 1 + [2, 0], + [1.5, np.sqrt(3) / 2], # Triangle 2 (shifted basis) + [1.0, np.sqrt(3)], # Top site + ] + kag_ids = list(range(len(kag_coords))) + return CustomizeLattice( + dimensionality=2, identifiers=kag_ids, coordinates=kag_coords + ) + + +class TestCustomizeLattice: + """ + A test class to group all tests related to the CustomizeLattice. + This helps in organizing the test suite. + """ + + def test_initialization_and_properties(self, kagome_lattice_fragment): + """ + Test case for successful initialization and verification of basic properties. + This test function receives the 'kagome_lattice_fragment' fixture as an argument. + """ + # Arrange: The fixture has already prepared the 'lattice' object for us. + lattice = kagome_lattice_fragment + + # Assert: Check if the object's properties match our expectations. + assert lattice.dimensionality == 2 + assert lattice.num_sites == 6 + assert len(lattice) == 6 # This also tests the __len__ dunder method + + # Verify that coordinates are correctly stored as numpy arrays. + # It's important to use np.testing.assert_array_equal for numpy array comparison. + expected_coord = np.array([0.5, np.sqrt(3) / 2]) + np.testing.assert_array_equal(lattice.get_coordinates(2), expected_coord) + + # Verify that the mapping between identifiers and indices is correct. + assert lattice.get_identifier(4) == 4 + assert lattice.get_index(4) == 4 + + def test_input_validation_mismatched_lengths(self): + """ + Tests that a ValueError is raised if identifiers and coordinates + lists have mismatched lengths. + """ + # Arrange: Prepare invalid inputs. + coords = [[0.0, 0.0], [1.0, 0.0]] # 2 coordinates + ids = [0, 1, 2] # 3 identifiers + + # Act & Assert: Use pytest.raises as a context manager to ensure + # the specified exception is raised within the 'with' block. + with pytest.raises( + ValueError, + match="Identifiers and coordinates lists must have the same length.", + ): + CustomizeLattice(dimensionality=2, identifiers=ids, coordinates=coords) + + def test_input_validation_wrong_dimension(self): + """ + Tests that a ValueError is raised if a coordinate's dimension + does not match the lattice's specified dimensionality. + """ + # Arrange: Prepare coordinates with mixed dimensions for a 2D lattice. + coords_wrong_dim = [[0.0, 0.0], [1.0, 0.0, 0.0]] # A mix of 2D and 3D + ids_ok = [0, 1] + + # Act & Assert: Check for the specific error message. The 'r' before the string + # indicates a raw string, which is good practice for regex patterns. + with pytest.raises( + ValueError, match=r"Coordinate at index 1 has shape \(3,\), expected \(2,\)" + ): + CustomizeLattice( + dimensionality=2, identifiers=ids_ok, coordinates=coords_wrong_dim + ) + + def test_neighbor_finding(self, simple_square_lattice): + """ + Tests the k-th nearest neighbor finding functionality (_build_neighbors + and get_neighbors). + """ + # Arrange: The fixture provides the lattice with pre-built neighbors. + lattice = simple_square_lattice + + # --- Assertions for k=1 (Nearest Neighbors) --- + # We use set() for comparison to ignore the order of neighbors. + assert set(lattice.get_neighbors(0, k=1)) == {1, 2} + assert set(lattice.get_neighbors(1, k=1)) == {0, 3} + assert set(lattice.get_neighbors(2, k=1)) == {0, 3} + assert set(lattice.get_neighbors(3, k=1)) == {1, 2} + + # --- Assertions for k=2 (Next-Nearest Neighbors) --- + # These should be the diagonal sites. + assert set(lattice.get_neighbors(0, k=2)) == {3} + assert set(lattice.get_neighbors(1, k=2)) == {2} + + def test_neighbor_pairs(self, simple_square_lattice): + """ + Tests the retrieval of unique neighbor pairs (bonds) using + get_neighbor_pairs. + """ + # Arrange: Use the same fixture. + lattice = simple_square_lattice + + # --- Test for k=1 (Nearest Neighbor bonds) --- + # Act: Get unique nearest neighbor pairs. + nn_pairs = lattice.get_neighbor_pairs(k=1, unique=True) + + # Assert: The set of pairs should match the expected bonds. + # We convert the list of pairs to a set of tuples for order-independent comparison. + expected_nn_pairs = {(0, 1), (0, 2), (1, 3), (2, 3)} + assert set(map(tuple, nn_pairs)) == expected_nn_pairs + + # --- Test for k=2 (Next-Nearest Neighbor bonds) --- + # Act: Get unique next-nearest neighbor pairs. + nnn_pairs = lattice.get_neighbor_pairs(k=2, unique=True) + + # Assert: + expected_nnn_pairs = {(0, 3), (1, 2)} + assert set(map(tuple, nnn_pairs)) == expected_nnn_pairs + + def test_neighbor_pairs_non_unique(self, simple_square_lattice): + """ + Tests get_neighbor_pairs with unique=False to ensure all + directed pairs (bonds) are returned. + """ + # Arrange: Use the same 2x2 square lattice fixture. + # 2--3 + # | | + # 0--1 + lattice = simple_square_lattice + + # Act: Get NON-unique nearest neighbor pairs. + nn_pairs = lattice.get_neighbor_pairs(k=1, unique=False) + + # Assert: + # There are 4 bonds, so we expect 4 * 2 = 8 directed pairs. + assert len(nn_pairs) == 8 + + # Your source code sorts the output, so we can compare against a + # sorted list for a precise match. + expected_pairs = sorted( + [(0, 1), (1, 0), (0, 2), (2, 0), (1, 3), (3, 1), (2, 3), (3, 2)] + ) + + assert nn_pairs == expected_pairs + + @patch("matplotlib.pyplot.show") + def test_show_method_runs_and_calls_plt_show( + self, mock_show, simple_square_lattice + ): + """ + Smoke test for the .show() method. + It verifies that the method runs without raising an exception and that it + triggers a call to matplotlib's show() function. + We use @patch to "mock" the show function, preventing a plot window + from actually appearing during tests. + """ + # Arrange: Get the lattice instance from the fixture + lattice = simple_square_lattice + + # Act: Call the .show() method. + # We wrap it in a try...except block to give a more specific error + # if the method fails for any reason. + try: + lattice.show() + except Exception as e: + pytest.fail(f".show() method raised an unexpected exception: {e}") + + # Assert: Check that our mocked matplotlib.pyplot.show was called exactly once. + mock_show.assert_called_once() + + def test_sites_iterator(self, simple_square_lattice): + """ + Tests the sites() iterator to ensure it yields all sites correctly. + """ + # Arrange + lattice = simple_square_lattice + expected_num_sites = 4 + + # Act + # The sites() method returns an iterator, we convert it to a list to check its length. + all_sites = list(lattice.sites()) + + # Assert + assert len(all_sites) == expected_num_sites + + # For a more thorough check, verify the content of one of the yielded tuples. + # For the simple_square_lattice fixture, site 3 has identifier 3 and coords [1, 1]. + idx, ident, coords = all_sites[3] + assert idx == 3 + assert ident == 3 + np.testing.assert_array_equal(coords, np.array([1, 1])) + + def test_get_site_info_with_identifier(self, simple_square_lattice): + """ + Tests the get_site_info() method using a site identifier instead of an index. + This covers the 'else' branch of the type check in the method. + """ + # Arrange + lattice = simple_square_lattice + # In this fixture, the identifier for the site at index 2 is also the integer 2. + identifier_to_test = 2 + expected_index = 2 + expected_coords = np.array([0, 1]) + + # Act + idx, ident, coords = lattice.get_site_info(identifier_to_test) + + # Assert + assert idx == expected_index + assert ident == identifier_to_test + np.testing.assert_array_equal(coords, expected_coords) + + @patch("matplotlib.pyplot.show") + def test_show_method_with_labels(self, mock_show, simple_square_lattice): + """ + Tests that the .show() method runs without error when label-related + options are enabled. This covers the logic inside the + 'if show_indices or show_identifiers:' block. + """ + # Arrange + lattice = simple_square_lattice + + # Act & Assert + try: + # Call .show() with options to display indices and identifiers. + lattice.show(show_indices=True, show_identifiers=True) + except Exception as e: + pytest.fail( + f".show() with label options raised an unexpected exception: {e}" + ) + + # Ensure the plotting function is still called. + mock_show.assert_called_once() + + def test_get_neighbors_logs_info_for_uncached_k( + self, simple_square_lattice, caplog + ): + """ + Tests that an INFO message is logged when get_neighbors is called for a 'k' + that has not been pre-calculated, triggering on-demand computation. + """ + # Arrange + lattice = simple_square_lattice # This fixture builds neighbors up to k=2 + k_to_test = 99 # A value that is clearly not cached + caplog.set_level(logging.INFO) # Ensure INFO logs are captured + + # Act + # This will now trigger the on-demand computation + _ = lattice.get_neighbors(0, k=k_to_test) + + # Assert + # Check that the correct INFO message about on-demand building was logged. + expected_log = ( + f"Neighbors for k={k_to_test} not pre-computed. " + f"Building now up to max_k={k_to_test}." + ) + assert expected_log in caplog.text + + @patch("matplotlib.pyplot.show") + def test_show_prints_warning_for_uncached_bonds( + self, mock_show, simple_square_lattice, caplog + ): + """ + Tests that a warning is printed when .show() is asked to draw a bond layer 'k' + that has not been pre-calculated. + """ + # Arrange + lattice = simple_square_lattice # This fixture builds neighbors up to k=2 + k_to_test = 99 # A value that is clearly not cached + + # Act + lattice.show(show_bonds_k=k_to_test) + + # Assert + assert ( + f"Cannot draw bonds. k={k_to_test} neighbors have not been calculated" + in caplog.text + ) + + @patch("matplotlib.pyplot.show") + def test_show_method_for_3d_lattice(self, mock_show): + """ + Tests that the .show() method can handle a 3D lattice without + crashing. This covers the 'if self.dimensionality == 3:' branches. + """ + # Arrange: Create a simple 2-site lattice in 3D space. + coords_3d = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]] + ids_3d = [0, 1] + lattice_3d = CustomizeLattice( + dimensionality=3, identifiers=ids_3d, coordinates=coords_3d + ) + + # Assert basic property + assert lattice_3d.dimensionality == 3 + + # Act & Assert + # We just need to ensure that calling .show() on a 3D object + # executes the 3D plotting logic without raising an exception. + try: + lattice_3d.show(show_indices=True, show_bonds_k=None) + except Exception as e: + pytest.fail(f".show() for 3D lattice raised an unexpected exception: {e}") + + # Verify that the plotting pipeline was completed. + mock_show.assert_called_once() + + @patch("matplotlib.pyplot.subplots") + def test_show_method_actually_draws_2d_labels( + self, mock_subplots, simple_square_lattice + ): + """ + Tests if ax.text is actually called for a 2D lattice when labels are enabled. + """ + # Arrange: + # 1. Prepare mock Figure and Axes objects that `matplotlib.pyplot.subplots` will return. + # This allows us to inspect calls to the `ax` object. + mock_fig = matplotlib.figure.Figure() + mock_ax = matplotlib.axes.Axes(mock_fig, [0.0, 0.0, 1.0, 1.0]) + mock_subplots.return_value = (mock_fig, mock_ax) + + # 2. Mock the text method on our mock Axes object to monitor its calls. + with patch.object(mock_ax, "text") as mock_text_method: + lattice = simple_square_lattice + + # Act: + # Call the show method. It will now operate on our mock_ax object. + lattice.show(show_indices=True) + + # Assert: + # Check if the ax.text method was called. For a 4-site lattice, it should be called 4 times. + assert mock_text_method.call_count == lattice.num_sites + + def test_custom_irregular_geometry_neighbors(self): + """ + Tests neighbor finding on a more complex, non-grid-like custom geometry + to stress-test the distance shell and KDTree logic. + """ + # Arrange: A "star-shaped" lattice with a central point, + # an inner shell, and an outer shell. + coords = [ + [0.0, 0.0], # Site 0: Center + [1.0, 0.0], + [0.0, 1.0], + [-1.0, 0.0], + [0.0, -1.0], # Sites 1-4: Inner shell (dist=1) + [2.0, 0.0], + [0.0, 2.0], + [-2.0, 0.0], + [0.0, -2.0], # Sites 5-8: Outer shell (dist=2) + ] + ids = list(range(len(coords))) + lattice = CustomizeLattice( + dimensionality=2, identifiers=ids, coordinates=coords + ) + lattice._build_neighbors(max_k=3) + + # Assert 1: Neighbors of the central point (0) should be the distinct shells. + assert set(lattice.get_neighbors(0, k=1)) == {1, 2, 3, 4} + # The shell at dist=2.0 (d_sq=4.0) is the 3rd global shell, so we check k=3. + assert set(lattice.get_neighbors(0, k=3)) == {5, 6, 7, 8} + + assert lattice.get_neighbors(0, k=2) == [] + + # Assert 2: Neighbors of a point on the inner shell, e.g., site 1 ([1.0, 0.0]). + # Its nearest neighbors (k=1) are the center (0) and the closest point on the outer shell (5). + # Both are at distance 1.0. + assert set(lattice.get_neighbors(1, k=1)) == {0, 5} + + # Its next-nearest neighbors (k=2) are the other two points on the inner shell (2 and 4), + # both at distance sqrt(2). + assert set(lattice.get_neighbors(1, k=2)) == {2, 4} + + def test_customizelattice_max_k_precomputation_and_ondemand(self): + """ + A robust test to verify `precompute_neighbors` (max_k) for CustomizeLattice. + This test is designed to FAIL on the buggy code. + """ + coords = [ + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [-1.0, 0.0], + [0.0, -1.0], + [1.0, 1.0], + [-1.0, 1.0], + [-1.0, -1.0], + [1.0, -1.0], + [2.0, 0.0], + [0.0, 2.0], + [-2.0, 0.0], + [0.0, -2.0], + ] + ids = list(range(len(coords))) + k_precompute = 2 + + lattice = CustomizeLattice( + dimensionality=2, + identifiers=ids, + coordinates=coords, + precompute_neighbors=k_precompute, + ) + + computed_shells = sorted(list(lattice._neighbor_maps.keys())) + expected_shells = list(range(1, k_precompute + 1)) + + assert computed_shells == expected_shells, ( + f"TEST FAILED for CustomizeLattice with k={k_precompute}. " + f"Expected shells {expected_shells}, but found {computed_shells}." + ) + + k_ondemand = 3 + _ = lattice.get_neighbors(0, k=k_ondemand) + + computed_shells_after = sorted(list(lattice._neighbor_maps.keys())) + expected_shells_after = list(range(1, k_ondemand + 1)) + + assert computed_shells_after == expected_shells_after, ( + f"ON-DEMAND TEST FAILED for CustomizeLattice. " + f"Expected shells {expected_shells_after} after demanding k={k_ondemand}, " + f"but found {computed_shells_after}." + ) + + +@pytest.fixture +def obc_square_lattice() -> SquareLattice: + """Provides a 3x3 SquareLattice with Open Boundary Conditions.""" + return SquareLattice(size=(3, 3), pbc=False) + + +@pytest.fixture +def pbc_square_lattice() -> SquareLattice: + """Provides a 3x3 SquareLattice with Periodic Boundary Conditions.""" + return SquareLattice(size=(3, 3), pbc=True) + + +class TestSquareLattice: + """ + Groups all tests for the SquareLattice class, which implicitly tests + the core functionality of its parent, TILattice. + """ + + def test_initialization_and_properties(self, obc_square_lattice): + """ + Tests the basic properties of a SquareLattice instance. + """ + lattice = obc_square_lattice + assert lattice.dimensionality == 2 + assert lattice.num_sites == 9 # A 3x3 lattice should have 9 sites. + assert len(lattice) == 9 + + def test_site_info_and_identifiers(self, obc_square_lattice): + """ + Tests that site information (coordinates, identifiers) is correct. + """ + lattice = obc_square_lattice + center_idx = lattice.get_index((1, 1, 0)) + assert center_idx == 4 + + _, ident, coords = lattice.get_site_info(center_idx) + assert ident == (1, 1, 0) + np.testing.assert_array_equal(coords, np.array([1.0, 1.0])) + + corner_idx = 0 + _, ident, coords = lattice.get_site_info(corner_idx) + assert ident == (0, 0, 0) + np.testing.assert_array_equal(coords, np.array([0.0, 0.0])) + + def test_neighbors_with_open_boundaries(self, obc_square_lattice): + """ + Tests neighbor finding with Open Boundary Conditions (OBC) using specific + neighbor identities. + """ + lattice = obc_square_lattice + # Site indices for a 3x3 grid (row-major order): + # 0 1 2 + # 3 4 5 + # 6 7 8 + center_idx = 4 # (1, 1, 0) + corner_idx = 0 # (0, 0, 0) + edge_idx = 3 # (1, 0, 0) + + # Assert center site (4) has neighbors 1, 3, 5, 7 + assert set(lattice.get_neighbors(center_idx, k=1)) == {1, 3, 5, 7} + # Assert corner site (0) has neighbors 1, 3 + assert set(lattice.get_neighbors(corner_idx, k=1)) == {1, 3} + # Assert edge site (3) has neighbors 0, 4, 6 + assert set(lattice.get_neighbors(edge_idx, k=1)) == {0, 4, 6} + + def test_neighbors_with_periodic_boundaries(self, pbc_square_lattice): + """ + Tests neighbor finding with Periodic Boundary Conditions (PBC). + """ + lattice = pbc_square_lattice + corner_idx = lattice.get_index((0, 0, 0)) + + neighbors = lattice.get_neighbors(corner_idx, k=1) + neighbor_idents = {lattice.get_identifier(i) for i in neighbors} + expected_neighbor_idents = {(1, 0, 0), (0, 1, 0), (2, 0, 0), (0, 2, 0)} + assert neighbor_idents == expected_neighbor_idents + + nnn_neighbors = lattice.get_neighbors(corner_idx, k=2) + nnn_neighbor_idents = {lattice.get_identifier(i) for i in nnn_neighbors} + expected_nnn_idents = {(1, 1, 0), (2, 1, 0), (1, 2, 0), (2, 2, 0)} + assert nnn_neighbor_idents == expected_nnn_idents + + +# --- Tests for HoneycombLattice --- + + +@pytest.fixture +def pbc_honeycomb_lattice() -> HoneycombLattice: + """Provides a 2x2 HoneycombLattice with Periodic Boundary Conditions.""" + return HoneycombLattice(size=(2, 2), pbc=True) + + +class TestHoneycombLattice: + """ + Tests the HoneycombLattice class, focusing on its two-site basis. + """ + + def test_initialization_and_properties(self, pbc_honeycomb_lattice): + """ + Tests that the total number of sites is correct for a composite lattice. + """ + lattice = pbc_honeycomb_lattice + assert lattice.num_sites == 8 + assert lattice.num_basis == 2 + + def test_honeycomb_neighbors(self, pbc_honeycomb_lattice): + """ + Tests that every site in a honeycomb lattice has 3 nearest neighbors. + """ + lattice = pbc_honeycomb_lattice + site_a_idx = lattice.get_index((0, 0, 0)) + assert len(lattice.get_neighbors(site_a_idx, k=1)) == 3 + + site_b_idx = lattice.get_index((0, 0, 1)) + assert len(lattice.get_neighbors(site_b_idx, k=1)) == 3 + + +# --- Tests for TriangularLattice --- + + +@pytest.fixture +def pbc_triangular_lattice() -> TriangularLattice: + """ + Provides a 3x3 TriangularLattice with Periodic Boundary Conditions. + A 3x3 size is used to ensure all 6 nearest neighbors are unique sites. + """ + return TriangularLattice(size=(3, 3), pbc=True) + + +class TestTriangularLattice: + """ + Tests the TriangularLattice class, focusing on its coordination number. + """ + + def test_initialization_and_properties(self, pbc_triangular_lattice): + """ + Tests the basic properties of the triangular lattice. + """ + lattice = pbc_triangular_lattice + assert lattice.num_sites == 9 # 3 * 3 = 9 sites for a 3x3 grid + + def test_triangular_neighbors(self, pbc_triangular_lattice): + """ + Tests that every site in a triangular lattice has 6 nearest neighbors. + """ + lattice = pbc_triangular_lattice + site_idx = 0 + assert len(lattice.get_neighbors(site_idx, k=1)) == 6 + + +# --- Tests for New TILattice Implementations --- + + +class TestRectangularLattice: + """Tests for the 2D RectangularLattice.""" + + def test_rectangular_properties_and_neighbors(self): + """Tests neighbor counts for an OBC rectangular lattice.""" + lattice = RectangularLattice(size=(3, 4), pbc=False) + assert lattice.num_sites == 12 + assert lattice.dimensionality == 2 + + # Test neighbor counts for different site types + center_idx = lattice.get_index((1, 1, 0)) + corner_idx = lattice.get_index((0, 0, 0)) + edge_idx = lattice.get_index((0, 1, 0)) + + assert len(lattice.get_neighbors(center_idx, k=1)) == 4 + assert len(lattice.get_neighbors(corner_idx, k=1)) == 2 + assert len(lattice.get_neighbors(edge_idx, k=1)) == 3 + + +class TestTILatticeEdgeCases: + """ + A dedicated class for testing the behavior of TILattice and its + subclasses under less common, "edge-case" conditions. + """ + + @pytest.fixture + def obc_1d_chain(self) -> ChainLattice: + """ + Provides a 5-site 1D chain with Open Boundary Conditions. + """ + # 0--1--2--3--4 + return ChainLattice(size=(5,), pbc=False) + + def test_1d_chain_properties_and_neighbors(self, obc_1d_chain): + # Arrange + lattice = obc_1d_chain + + # Assert basic properties + assert lattice.num_sites == 5 + assert lattice.dimensionality == 1 + + # Assert neighbor counts for different positions + # Endpoint (site 0) should have 1 neighbor (site 1) + endpoint_idx = lattice.get_index((0, 0)) + assert lattice.get_neighbors(endpoint_idx, k=1) == [1] + + # Middle point (site 2) should have 2 neighbors (sites 1 and 3) + middle_idx = lattice.get_index((2, 0)) + assert len(lattice.get_neighbors(middle_idx, k=1)) == 2 + assert set(lattice.get_neighbors(middle_idx, k=1)) == {1, 3} + + @pytest.fixture + def nonsquare_lattice(self) -> SquareLattice: + """Provides a non-square 2x3 lattice to test indexing.""" + return SquareLattice(size=(2, 3), pbc=False) + + def test_nonsquare_lattice_indexing(self, nonsquare_lattice): + """ + Tests site indexing and coordinate generation on a non-square (2x3) lattice. + This ensures the logic correctly handles different dimension lengths. + The lattice sites are indexed row by row: + (0,0) (0,1) (0,2) -> indices 0, 1, 2 + (1,0) (1,1) (1,2) -> indices 3, 4, 5 + """ + # Arrange + lattice = nonsquare_lattice + + # Assert properties + assert lattice.num_sites == 6 # 2 * 3 = 6 + + # Act & Assert: Check a non-trivial site, e.g., the last one. + # The identifier for the site in the last row and last column. + ident = (1, 2, 0) + expected_idx = 5 + expected_coords = np.array([1.0, 2.0]) + + # Get index from identifier + idx = lattice.get_index(ident) + assert idx == expected_idx + + # Get info from index + _, _, coords = lattice.get_site_info(idx) + np.testing.assert_array_equal(coords, expected_coords) + + @patch("matplotlib.pyplot.show") + def test_show_method_for_1d_lattice(self, mock_show, obc_1d_chain): + """ + Tests that the .show() method can handle a 1D lattice (chain) + without crashing. This covers the 'if self.dimensionality == 1:' branches. + """ + # Arrange + lattice_1d = obc_1d_chain + + # Assert basic property + assert lattice_1d.num_sites == 5 + + # Act & Assert + try: + # Call .show() on the 1D lattice to execute the 1D plotting logic. + lattice_1d.show(show_indices=True) + except Exception as e: + pytest.fail(f".show() for 1D lattice raised an unexpected exception: {e}") + + # Verify that the plotting pipeline was completed. + mock_show.assert_called_once() + + +# --- Tests for API Robustness / Negative Cases --- + + +class TestApiRobustness: + """ + Groups tests that verify the API's behavior with invalid inputs. + This ensures the lattice classes fail gracefully and predictably. + """ + + def test_access_with_out_of_bounds_index(self, simple_square_lattice): + """ + Tests that an IndexError is raised when accessing a site index + that is out of the valid range (0 to num_sites-1). + """ + # Arrange + lattice = simple_square_lattice # This lattice has 4 sites (indices 0, 1, 2, 3) + invalid_index = 999 + + # Act & Assert + # We use pytest.raises to confirm that the expected exception is thrown. + with pytest.raises(IndexError): + lattice.get_coordinates(invalid_index) + + with pytest.raises(IndexError): + lattice.get_identifier(invalid_index) + + with pytest.raises(IndexError): + # get_site_info should also raise IndexError for an invalid index + lattice.get_site_info(invalid_index) + + def test_empty_lattice_handles_gracefully(self, caplog): + """ + Tests that an empty lattice initializes correctly and that methods + like .show() and ._build_neighbors() handle the zero-site case + gracefully without crashing. + """ + # Arrange: Create an empty CustomizeLattice instance. + empty_lattice = CustomizeLattice( + dimensionality=2, identifiers=[], coordinates=[] + ) + + # Assert: Verify basic properties. + assert empty_lattice.num_sites == 0 + assert len(empty_lattice) == 0 + + # Act & Assert for .show(): Verify it prints the expected message without crashing. + caplog.set_level(logging.INFO) + + empty_lattice.show() + assert "Lattice is empty, nothing to show." in caplog.text + + # Act & Assert for neighbor finding: Verify these calls run without errors. + empty_lattice._build_neighbors() + assert empty_lattice.get_neighbor_pairs(k=1) == [] + + def test_single_site_lattice_handles_gracefully(self): + """ + Tests that a lattice with a single site correctly handles neighbor + finding (i.e., returns no neighbors). + """ + # Arrange: Create a CustomizeLattice with a single site. + single_site_lattice = CustomizeLattice( + dimensionality=2, identifiers=[0], coordinates=[[0.0, 0.0]] + ) + + # Assert: Verify basic properties. + assert single_site_lattice.num_sites == 1 + + # Act: Attempt to build neighbor relationships. + single_site_lattice._build_neighbors(max_k=1) + + # Assert: The single site should have no neighbors. + assert single_site_lattice.get_neighbors(0, k=1) == [] + + def test_access_with_non_existent_identifier(self, simple_square_lattice): + """ + Tests that a ValueError is raised when accessing a site + with an identifier that does not exist in the lattice. + """ + # Arrange + lattice = simple_square_lattice + invalid_identifier = "non_existent_site" + + # Act & Assert + # Your code raises a ValueError with a specific message. We can even + # use the 'match' parameter to check if the error message is correct. + with pytest.raises(ValueError, match="not found in the lattice"): + lattice.get_index(invalid_identifier) + + with pytest.raises(ValueError, match="not found in the lattice"): + lattice.get_site_info(invalid_identifier) + + def test_show_warning_for_unsupported_dimension(self, caplog): + """ + Tests that .show() prints a warning when called on a lattice with a + dimensionality that it does not support for plotting (e.g., 4D). + """ + # Arrange: Create a simple lattice with an unsupported dimension. + lattice_4d = CustomizeLattice( + dimensionality=4, identifiers=[0], coordinates=[[0, 0, 0, 0]] + ) + + # Act + lattice_4d.show() + + # Assert: Check that the appropriate warning was printed to stdout. + assert "show() is not implemented for 4D lattices." in caplog.text + + def test_disconnected_lattice_neighbor_finding(self): + """ + Tests that neighbor finding algorithms work correctly for a lattice + composed of multiple, physically disconnected components. + """ + # Arrange: Create a lattice with two disconnected 2x2 squares, + # separated by a large distance. + # Component 1: sites with indices 0, 1, 2, 3 + # Component 2: sites with indices 4, 5, 6, 7 + coords = [ + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [1.0, 1.0], # Square 1 + [100.0, 0.0], + [101.0, 0.0], + [100.0, 1.0], + [101.0, 1.0], # Square 2 + ] + ids = list(range(len(coords))) + lattice = CustomizeLattice( + dimensionality=2, identifiers=ids, coordinates=coords + ) + lattice._build_neighbors(max_k=1) # Explicitly build neighbors + + # --- Test 1: get_neighbors() --- + # Act: Get neighbors for a site in the first component. + neighbors_of_site_0 = lattice.get_neighbors(0, k=1) + + # Assert: Its neighbors must only be within the first component. + assert set(neighbors_of_site_0) == {1, 2} + + # --- Test 2: get_neighbor_pairs() --- + # Act: Get all unique bonds for the entire lattice. + all_bonds = lattice.get_neighbor_pairs(k=1, unique=True) + + # Assert: No bond should connect a site from Component 1 to Component 2. + for i, j in all_bonds: + # A bond is valid only if both its sites are in the same component. + # We check this by seeing if their indices fall in the same range. + is_in_first_component = i < 4 and j < 4 + is_in_second_component = i >= 4 and j >= 4 + + assert is_in_first_component or is_in_second_component, ( + f"Found an invalid bond { (i,j) } that incorrectly connects " + "two separate components of the lattice." + ) + + def test_lattice_with_duplicate_coordinates(self): + """ + Tests a pathological case where multiple sites share the exact same coordinates. + The neighbor-finding logic must still treat them as distinct sites and + correctly identify neighbors based on other non-overlapping sites. + """ + # Arrange + # Site 'A' and 'B' are at the same position (0,0). + # Site 'C' is at (1,0), which should be a neighbor to both 'A' and 'B'. + ids = ["A", "B", "C"] + coords = [[0.0, 0.0], [0.0, 0.0], [1.0, 0.0]] + + lattice = CustomizeLattice( + dimensionality=2, identifiers=ids, coordinates=coords + ) + lattice._build_neighbors(max_k=1) # Build nearest neighbors + + # Act + idx_A = lattice.get_index("A") + idx_B = lattice.get_index("B") + idx_C = lattice.get_index("C") + + neighbors_A = lattice.get_neighbors(idx_A, k=1) + neighbors_B = lattice.get_neighbors(idx_B, k=1) + + # Assert + # 1. The distance between the overlapping points 'A' and 'B' is 0, + # so they should NOT be considered neighbors of each other. + assert ( + idx_B not in neighbors_A + ), "Overlapping sites should not be their own neighbors." + assert ( + idx_A not in neighbors_B + ), "Overlapping sites should not be their own neighbors." + + # 2. Both 'A' and 'B' should correctly identify 'C' as their neighbor. + # This is the key test of robustness. + assert neighbors_A == [ + idx_C + ], "Site 'A' failed to find its correct neighbor 'C'." + assert neighbors_B == [ + idx_C + ], "Site 'B' failed to find its correct neighbor 'C'." + + # 3. Conversely, 'C' should identify both 'A' and 'B' as its neighbors. + neighbors_C = lattice.get_neighbors(idx_C, k=1) + assert set(neighbors_C) == { + idx_A, + idx_B, + }, "Site 'C' failed to find both overlapping neighbors." + + def test_neighbor_shells_with_tiny_separation(self): + """ + Tests the numerical stability of neighbor shell identification. + Creates a lattice where the k=1 and k=2 shells are separated by a + distance much smaller than the default tolerance, and verifies that they + are still correctly identified as distinct shells. + """ + # Arrange + # Let d1 be the distance to the first neighbor shell. + d1 = 1.0 + # Let d2 be the distance to the second shell, which is extremely close to d1. + epsilon = 1e-8 # A tiny separation + d2 = d1 + epsilon + + # Create a 1D lattice with these specific distances. + # Site 0 is origin. Site 1 is at d1. Site 2 is at d2. + ids = [0, 1, 2] + coords = [[0.0], [d1], [d2]] + + # We explicitly use a tolerance LARGER than the separation, + # which SHOULD cause the shells to merge. + lattice_merged = CustomizeLattice( + dimensionality=1, identifiers=ids, coordinates=coords + ) + # Use a tolerance that cannot distinguish d1 and d2. + lattice_merged._build_neighbors(max_k=2, tol=1e-7) + + # Now, use a tolerance SMALLER than the separation, + # which SHOULD correctly distinguish the shells. + lattice_distinct = CustomizeLattice( + dimensionality=1, identifiers=ids, coordinates=coords + ) + lattice_distinct._build_neighbors(max_k=2, tol=1e-9) + + # Assert for the merged case + # With a large tolerance, site 1 and 2 should both be in the k=1 shell. + merged_neighbors_k1 = lattice_merged.get_neighbors(0, k=1) + assert set(merged_neighbors_k1) == { + 1, + 2, + }, "Shells were not merged with a large tolerance." + # There should be no k=2 shell. + merged_neighbors_k2 = lattice_merged.get_neighbors(0, k=2) + assert ( + merged_neighbors_k2 == [] + ), "A k=2 shell should not exist when shells are merged." + + # Assert for the distinct case + # With a small tolerance, only site 1 should be in the k=1 shell. + distinct_neighbors_k1 = lattice_distinct.get_neighbors(0, k=1) + assert distinct_neighbors_k1 == [ + 1 + ], "k=1 shell is incorrect with a small tolerance." + # Site 2 should now be in its own k=2 shell. + distinct_neighbors_k2 = lattice_distinct.get_neighbors(0, k=2) + assert distinct_neighbors_k2 == [ + 2 + ], "k=2 shell is incorrect with a small tolerance." + + +class TestTILattice: + """ + A dedicated class for testing the Translationally Invariant Lattice (TILattice) + and its subclasses like SquareLattice. + """ + + def test_init_with_mismatched_shapes_raises_error(self): + """ + Tests that TILattice raises AssertionError if the 'size' parameter's + length does not match the dimensionality. + """ + # Act & Assert: + # Pass a 'size' tuple with 3 elements to a 2D SquareLattice. + # This should trigger the AssertionError from the parent TILattice class. + with pytest.raises(AssertionError, match="Size tuple length mismatch"): + SquareLattice(size=(2, 2, 2)) + + def test_init_with_tuple_pbc(self): + """ + Tests that TILattice correctly handles a tuple input for the 'pbc' + (periodic boundary conditions) parameter. This covers the 'else' branch. + """ + # Arrange + pbc_tuple = (True, False) + + # Act + # Initialize a lattice with a tuple for pbc. + lattice = SquareLattice(size=(3, 3), pbc=pbc_tuple) + + # Assert + # The public 'pbc' attribute should be identical to the tuple we passed. + assert lattice.pbc == pbc_tuple + + @pytest.mark.parametrize( + "LatticeClass, init_args, k_precompute", + [ + (HoneycombLattice, {"size": (4, 5), "pbc": True}, 1), + (SquareLattice, {"size": (5, 5), "pbc": True}, 2), + (SquareLattice, {"size": (5, 5), "pbc": False}, 1), + (KagomeLattice, {"size": (3, 3), "pbc": True}, 1), + ], + ) + def test_tilattice_max_k_precomputation_and_ondemand( + self, LatticeClass, init_args, k_precompute + ): + """ + A robust, parameterized test to verify that `precompute_neighbors` (max_k) + works correctly across various TILattice types and conditions. + This test is designed to FAIL on the buggy code. + """ + lattice = LatticeClass(**init_args, precompute_neighbors=k_precompute) + + computed_shells = sorted(list(lattice._neighbor_maps.keys())) + expected_shells = list(range(1, k_precompute + 1)) + + assert computed_shells == expected_shells, ( + f"TEST FAILED for {LatticeClass.__name__} with k={k_precompute}. " + f"Expected shells {expected_shells}, but found {computed_shells}." + ) + + k_ondemand = k_precompute + 1 + + _ = lattice.get_neighbors(0, k=k_ondemand) + + computed_shells_after = sorted(list(lattice._neighbor_maps.keys())) + expected_shells_after = list(range(1, k_ondemand + 1)) + + assert computed_shells_after == expected_shells_after, ( + f"ON-DEMAND TEST FAILED for {LatticeClass.__name__}. " + f"Expected shells {expected_shells_after} after demanding k={k_ondemand}, " + f"but found {computed_shells_after}." + ) + + +class TestLongRangeNeighborFinding: + """ + Tests neighbor finding on larger lattices and for longer-range interactions (large k), + addressing suggestions from code review. + """ + + @pytest.fixture(scope="class") + def large_pbc_square_lattice(self) -> SquareLattice: + """ + Provides a single 6x8 SquareLattice with PBC for all tests in this class. + Using scope="class" makes it more efficient as it's created only once. + """ + # We choose a non-square size to catch potential bugs with non-uniform dimensions. + return SquareLattice(size=(7, 9), pbc=True) + + def test_neighbor_shell_structure_on_large_lattice(self, large_pbc_square_lattice): + """ + Tests the coordination number of various neighbor shells (k) on a large + periodic lattice. In a PBC square lattice, every site is identical, so + the number of neighbors for each shell k should be the same for all sites. + + Shell distances squared and their coordination numbers for a 2D square lattice: + - k=1: dist_sq=1 (e.g., (1,0)) -> 4 neighbors + - k=2: dist_sq=2 (e.g., (1,1)) -> 4 neighbors + - k=3: dist_sq=4 (e.g., (2,0)) -> 4 neighbors + - k=4: dist_sq=5 (e.g., (2,1)) -> 8 neighbors + - k=5: dist_sq=8 (e.g., (2,2)) -> 4 neighbors + - k=6: dist_sq=9 (e.g., (3,0)) -> 4 neighbors + - k=7: dist_sq=10 (e.g., (3,1)) -> 8 neighbors + """ + lattice = large_pbc_square_lattice + # Pick an arbitrary site, e.g., index 0. + site_idx = 0 + + # Expected coordination numbers for the first few shells. + expected_coordinations = {1: 4, 2: 4, 3: 4, 4: 8, 5: 4, 6: 4, 7: 8} + + for k, expected_count in expected_coordinations.items(): + neighbors = lattice.get_neighbors(site_idx, k=k) + assert ( + len(neighbors) == expected_count + ), f"Failed for k={k}. Expected {expected_count}, got {len(neighbors)}" + + def test_requesting_k_beyond_max_possible_shell(self, large_pbc_square_lattice): + """ + Tests that requesting a neighbor shell 'k' that is larger than any + possible shell in the finite lattice returns an empty list, and does + not raise an error. + """ + lattice = large_pbc_square_lattice + site_idx = 0 + + # 1. First, find out the maximum number of shells that *do* exist. + # We do this by calling _build_neighbors with a very large max_k. + # This is a bit of "white-box" testing but necessary to find the true max k. + lattice._build_neighbors(max_k=100) + max_k_found = len(lattice._neighbor_maps) + + # 2. Assert that the last valid shell is not empty. + last_shell_neighbors = lattice.get_neighbors(site_idx, k=max_k_found) + assert len(last_shell_neighbors) > 0 + + # 3. Assert that requesting a shell just beyond the last valid one returns empty. + # This is the core of the test. + non_existent_shell_neighbors = lattice.get_neighbors( + site_idx, k=max_k_found + 1 + ) + assert non_existent_shell_neighbors == [] + + @patch("matplotlib.pyplot.subplots") + def test_show_method_with_custom_bond_kwargs( + self, mock_subplots, simple_square_lattice + ): + """ + Tests that .show() correctly uses the `bond_kwargs` parameter + to customize the appearance of neighbor bonds. + """ + # Arrange: + # 1. Set up mock Figure and Axes objects, similar to other show() tests. + mock_fig = matplotlib.figure.Figure() + mock_ax = matplotlib.axes.Axes(mock_fig, [0.0, 0.0, 1.0, 1.0]) + mock_subplots.return_value = (mock_fig, mock_ax) + + # 2. Define our custom styles and the expected final styles. + lattice = simple_square_lattice + custom_bond_kwargs = {"color": "red", "linestyle": ":", "linewidth": 2} + + # The final dictionary should contain the defaults updated by our custom arguments. + expected_plot_kwargs = { + "color": "red", # Overridden + "linestyle": ":", # Overridden + "linewidth": 2, # A new key + "alpha": 0.6, # From default + "zorder": 1, # From default + } + + # 3. We specifically mock the `plot` method on our mock `ax` object. + with patch.object(mock_ax, "plot") as mock_plot_method: + # Act: + # Call the show method with our custom bond styles. + lattice.show(show_bonds_k=1, bond_kwargs=custom_bond_kwargs) + + # Assert: + # Check that the plot method was called. For a 2x2 square, there are 4 NN bonds. + assert mock_plot_method.call_count == 4 + + # Get the keyword arguments from the very first call to plot(). + # Note: call_args is a tuple (positional_args, keyword_args). We need the second element. + actual_kwargs = mock_plot_method.call_args[1] + + # Verify that the keyword arguments used for plotting match our expectations. + assert actual_kwargs == expected_plot_kwargs + + def test_mixed_boundary_conditions(self): + """ + Tests neighbor finding with mixed PBC (periodic in x, open in y). + This verifies that the neighbor finding logic correctly handles + anisotropy in periodic boundary conditions and returns sorted indices. + """ + # Arrange: Create a 3x3 square lattice, periodic in x, open in y. + lattice = SquareLattice(size=(3, 3), pbc=(True, False)) + + # We will test a site on the corner of the open boundary: (0, 0) + corner_site_idx = lattice.get_index((0, 0, 0)) + + # --- Test corner site (0, 0, 0), which is index 0 --- + # Act + corner_neighbors = lattice.get_neighbors(corner_site_idx, k=1) + + # Assert: The expected neighbors are (1,0,0), (2,0,0) [periodic], and (0,1,0) + # We get their indices and sort them to create the expected output. + expected_indices = sorted( + [ + lattice.get_index((1, 0, 0)), # Right neighbor + lattice.get_index((2, 0, 0)), # "Left" neighbor (wraps around) + lattice.get_index((0, 1, 0)), # "Up" neighbor + ] + ) + + # The list returned by get_neighbors should be identical to our sorted list. + assert ( + corner_neighbors == expected_indices + ), "Failed for corner site with mixed BC." + + # --- Test middle site on the edge (1, 0, 0), which is index 1 --- + edge_site_idx = lattice.get_index((1, 0, 0)) + + # Act + edge_neighbors = lattice.get_neighbors(edge_site_idx, k=1) + + # Assert + expected_edge_indices = sorted( + [ + lattice.get_index((0, 0, 0)), # Left neighbor + lattice.get_index((2, 0, 0)), # Right neighbor + lattice.get_index((1, 1, 0)), # "Up" neighbor + ] + ) + assert ( + edge_neighbors == expected_edge_indices + ), "Failed for edge site with mixed BC." + + +class TestAllTILattices: + """ + A parameterized test class to verify the basic properties and coordination + numbers for all implemented TILattice subclasses. This avoids code duplication. + """ + + # --- Test data in a structured and readable format --- + # Format: + # ( + # LatticeClass, # The lattice class to test + # {"size": ..., ...}, # Arguments for the constructor + # expected_num_sites, # Expected total number of sites + # expected_num_basis, # Expected number of sites in the basis + # {site_repr: count} # Dict of {representative_site: neighbor_count} + # ) + # For `site_repr`: + # - For simple lattices (basis=1), it's the integer index of the site. + # - For composite lattices (basis>1), it's the *basis index* to test. + lattice_test_cases = [ + # 1D Lattices + (ChainLattice, {"size": (5,), "pbc": True}, 5, 1, {0: 2, 2: 2}), + (ChainLattice, {"size": (5,), "pbc": False}, 5, 1, {0: 1, 2: 2}), + (DimerizedChainLattice, {"size": (3,), "pbc": True}, 6, 2, {0: 2, 1: 2}), + # 2D Lattices + ( + RectangularLattice, + {"size": (3, 4), "pbc": False}, + 12, + 1, + {5: 4, 0: 2, 4: 3}, + ), # center, corner, edge + (HoneycombLattice, {"size": (2, 2), "pbc": True}, 8, 2, {0: 3, 1: 3}), + (TriangularLattice, {"size": (3, 3), "pbc": True}, 9, 1, {0: 6}), + (CheckerboardLattice, {"size": (2, 2), "pbc": True}, 8, 2, {0: 4, 1: 4}), + (KagomeLattice, {"size": (2, 2), "pbc": True}, 12, 3, {0: 4, 1: 4, 2: 4}), + (LiebLattice, {"size": (2, 2), "pbc": True}, 12, 3, {0: 4, 1: 2, 2: 2}), + # 3D Lattices + (CubicLattice, {"size": (3, 3, 3), "pbc": True}, 27, 1, {0: 6, 13: 6}), + ] + + @pytest.mark.parametrize( + "LatticeClass, init_args, num_sites, num_basis, coordination_numbers", + lattice_test_cases, + ) + def test_lattice_properties_and_coordination( + self, + LatticeClass, + init_args, + num_sites, + num_basis, + coordination_numbers, + ): + """ + A single, parameterized test to validate all TILattice types. + """ + # --- Arrange --- + # Create the lattice instance dynamically from the test data. + lattice = LatticeClass(**init_args) + + # --- Assert: Basic properties --- + assert lattice.num_sites == num_sites + assert lattice.num_basis == num_basis + assert lattice.dimensionality == len(init_args["size"]) + + # --- Assert: Coordination numbers (nearest neighbors, k=1) --- + for site_repr, expected_count in coordination_numbers.items(): + # This logic correctly gets the site index to test, + # whether it's a simple or composite lattice. + if lattice.num_basis > 1: + # For composite lattices, site_repr is the basis_index. + # We find the index of this basis site in the first unit cell. + uc_coord = (0,) * lattice.dimensionality + test_site_idx = lattice.get_index(uc_coord + (site_repr,)) + else: + # For simple lattices, site_repr is the absolute site index. + test_site_idx = site_repr + + neighbors = lattice.get_neighbors(test_site_idx, k=1) + assert len(neighbors) == expected_count + if isinstance(LatticeClass, ChainLattice) and not init_args.get("pbc"): + if test_site_idx == 0: + assert 1 in neighbors + + +class TestCustomizeLatticeDynamic: + """Tests the dynamic modification capabilities of CustomizeLattice.""" + + @pytest.fixture + def initial_lattice(self) -> CustomizeLattice: + """Provides a basic 3-site lattice for modification tests.""" + return CustomizeLattice( + dimensionality=2, + identifiers=["A", "B", "C"], + coordinates=[[0, 0], [1, 0], [0, 1]], + ) + + def test_from_lattice_conversion(self): + """Tests creating a CustomizeLattice from a TILattice.""" + # Arrange + sq_lattice = SquareLattice(size=(2, 2), pbc=False) + + # Act + custom_lattice = CustomizeLattice.from_lattice(sq_lattice) + + # Assert + assert isinstance(custom_lattice, CustomizeLattice) + assert custom_lattice.num_sites == sq_lattice.num_sites + assert custom_lattice.dimensionality == sq_lattice.dimensionality + # Verify a site to be sure + np.testing.assert_array_equal( + custom_lattice.get_coordinates(3), sq_lattice.get_coordinates(3) + ) + assert custom_lattice.get_identifier(3) == sq_lattice.get_identifier(3) + + def test_add_sites_successfully(self, initial_lattice): + """Tests adding new, valid sites to the lattice.""" + # Arrange + lat = initial_lattice + assert lat.num_sites == 3 + + # Act + lat.add_sites(identifiers=["D", "E"], coordinates=[[1, 1], [2, 2]]) + + # Assert + assert lat.num_sites == 5 + assert lat.get_identifier(4) == "E" + np.testing.assert_array_equal(lat.get_coordinates(3), np.array([1, 1])) + assert "E" in lat._ident_to_idx + + def test_remove_sites_successfully(self, initial_lattice): + """Tests removing existing sites from the lattice.""" + # Arrange + lat = initial_lattice + assert lat.num_sites == 3 + + # Act + lat.remove_sites(identifiers=["A", "C"]) + + # Assert + assert lat.num_sites == 1 + assert lat.get_identifier(0) == "B" # Site 'B' is now at index 0 + assert "A" not in lat._ident_to_idx + np.testing.assert_array_equal(lat.get_coordinates(0), np.array([1, 0])) + + def test_add_duplicate_identifier_raises_error(self, initial_lattice): + """Tests that adding a site with an existing identifier fails.""" + with pytest.raises(ValueError, match="Duplicate identifiers found"): + initial_lattice.add_sites(identifiers=["A"], coordinates=[[9, 9]]) + + def test_remove_nonexistent_identifier_raises_error(self, initial_lattice): + """Tests that removing a non-existent site fails.""" + with pytest.raises(ValueError, match="Non-existent identifiers provided"): + initial_lattice.remove_sites(identifiers=["Z"]) + + def test_modification_clears_neighbor_cache(self, initial_lattice): + """ + Tests that add_sites and remove_sites correctly invalidate the + pre-computed neighbor map. + """ + # Arrange: Pre-compute neighbors on the initial lattice + initial_lattice._build_neighbors(max_k=1) + assert 0 in initial_lattice._neighbor_maps[1] # Check that neighbors exist + + # Act 1: Add a site + initial_lattice.add_sites(identifiers=["D"], coordinates=[[5, 5]]) + + # Assert 1: The neighbor map should now be empty + assert not initial_lattice._neighbor_maps + + # Arrange 2: Re-compute neighbors and then remove a site + initial_lattice._build_neighbors(max_k=1) + assert 0 in initial_lattice._neighbor_maps[1] + + # Act 2: Remove a site + initial_lattice.remove_sites(identifiers=["A"]) + + # Assert 2: The neighbor map should be empty again + assert not initial_lattice._neighbor_maps + + def test_modification_clears_distance_matrix_cache(self, initial_lattice): + """ + Tests that add_sites and remove_sites correctly invalidate the + cached distance matrix and that the recomputed matrix is correct. + """ + # Arrange 1: Compute, cache, and perform a meaningful check on the original matrix. + lat = initial_lattice + original_matrix = lat.distance_matrix + assert lat._distance_matrix is not None + assert original_matrix.shape == (3, 3) + # Meaningful check: distance from 'A'(idx 0) to 'B'(idx 1) should be 1.0 + np.testing.assert_allclose(original_matrix[0, 1], 1.0) + + # Act 1: Add a site. This should invalidate the cache. + lat.add_sites(identifiers=["D"], coordinates=[[1, 1]]) + + # Assert 1: Check cache is cleared and the new matrix is correct. + assert lat._distance_matrix is None # Verify cache invalidation + new_matrix_added = lat.distance_matrix + assert new_matrix_added.shape == (4, 4) + # Meaningful check: distance from 'B'(idx 1) to new site 'D'(idx 3) should be 1.0 + # Coords: B=[1,0], D=[1,1] + np.testing.assert_allclose(new_matrix_added[1, 3], 1.0) + + # Act 2: Remove a site. This should also invalidate the cache. + lat.remove_sites(identifiers=["A"]) + + # Assert 2: Check cache is cleared again and the final matrix is correct. + assert lat._distance_matrix is None # Verify cache invalidation + final_matrix = lat.distance_matrix + assert final_matrix.shape == (3, 3) # Now has 3 sites again + # Meaningful check: After removing 'A', the sites are B, C, D. + # 'B' is now at index 0 (coords [1,0]) + # 'C' is now at index 1 (coords [0,1]) + # 'D' is now at index 2 (coords [1,1]) + # Distance from new 'B' (idx 0) to new 'D' (idx 2) should be 1.0 + np.testing.assert_allclose(final_matrix[0, 2], 1.0) + + def test_neighbor_finding_returns_sorted_list(self, simple_square_lattice): + """ + Ensures that the list of neighbors returned by get_neighbors is always sorted. + This provides a stricter check than set-based comparisons. + """ + # Arrange + lattice = simple_square_lattice + + # Act + # Get neighbors for the central site (index 1 in a 2x2 grid) + # Expected neighbors are 0, 3. + neighbors = lattice.get_neighbors(1, k=1) + + # Assert + # We compare directly against a pre-sorted list, not a set. + # This will fail if the implementation returns [3, 0] instead of [0, 3]. + assert neighbors == [ + 0, + 3, + ], "The neighbor list should be sorted in ascending order." + + +class TestDistanceMatrix: + + # This is the upgraded, parameterized test. + @pytest.mark.parametrize( + # We define test scenarios as tuples: + # (build_k, check_site_identifier, expected_dist_sq) + # build_k: The number of neighbor shells to pre-build. + # check_site_identifier: The identifier of a site whose distance from the origin we will check. + # expected_dist_sq: The expected squared distance to that site. + "build_k, check_site_identifier, expected_dist_sq", + [ + # Scenario 1: The most common case. Build only NN (k=1), but check a NNN (k=2) distance. + # A buggy cache would fail this. + (1, (1, 1, 0), 2.0), + # Scenario 2: Build up to k=2, but check a k=3 distance. + (2, (2, 0, 0), 4.0), + # Scenario 3: Build up to k=3, but check a k=4 distance. + (3, (2, 1, 0), 5.0), + # Scenario 4: A more complex, higher-order neighbor. + (5, (3, 1, 0), 10.0), + ], + ) + def test_tilattice_full_pbc_distance_matrix_is_correct_regardless_of_build_k( + self, build_k, check_site_identifier, expected_dist_sq + ): + """ + Tests that the distance matrix for a fully periodic TILattice is + always fully correct, no matter how many neighbor shells were pre-calculated. + This is a high-strength test designed to catch subtle caching bugs where + the cached matrix might only contain partial information. + """ + # Arrange + # Using a larger, non-square lattice to avoid accidental symmetries + lat = SquareLattice(size=(7, 9), pbc=True) + + # Act + # Step 1: Pre-build neighbors. This is where a faulty caching + # mechanism in the source code might be triggered. + lat._build_neighbors(max_k=build_k) + + # Step 2: Access the distance_matrix property. A correct implementation + # will return a fully valid matrix. + dist_matrix = lat.distance_matrix + + # Assert + # Find the indices for the sites we want to check. + origin_idx = lat.get_index((0, 0, 0)) + check_site_idx = lat.get_index(check_site_identifier) + + # The core assertion: check the distance. + actual_dist_sq = dist_matrix[origin_idx, check_site_idx] ** 2 + + error_message = ( + f"Distance matrix failed when building k={build_k}. " + f"Checking distance to site {check_site_identifier} (expected sq={expected_dist_sq}) " + f"but got sq={actual_dist_sq} instead." + ) + + np.testing.assert_allclose( + actual_dist_sq, expected_dist_sq, err_msg=error_message + ) + + def test_tilattice_mixed_bc_distance_matrix_is_correct(self): + """ + Tests that the distance matrix is correctly calculated for a TILattice + with mixed boundary conditions (e.g., periodic in x, open in y). + """ + # Arrange + # pbc=(True, False) means periodic along x-axis, open along y-axis. + lat = SquareLattice(size=(5, 5), pbc=(True, False)) + + # Pre-build neighbors to engage the caching logic. + lat._build_neighbors(max_k=2) + dist_matrix = lat.distance_matrix + + # Assert + origin_idx = lat.get_index((0, 0, 0)) + + # 1. Test a distance affected by the periodic boundary (x-direction) + # The distance between (0,0) and (4,0) should be 1.0 due to PBC wrap-around. + pbc_neighbor_idx = lat.get_index((4, 0, 0)) + np.testing.assert_allclose(dist_matrix[origin_idx, pbc_neighbor_idx], 1.0) + + # 2. Test a distance affected by the open boundary (y-direction) + # The distance between (0,0) and (0,4) should be 4.0 as there's no wrap-around. + obc_neighbor_idx = lat.get_index((0, 4, 0)) + np.testing.assert_allclose(dist_matrix[origin_idx, obc_neighbor_idx], 4.0) + + # 3. Test a general, off-axis point. + # Distance from (0,0) to (3,3) with x-pbc. The x-distance is min(3, 5-3=2) = 2. + # The y-distance is 3. So total distance is sqrt(2^2 + 3^2) = sqrt(13). + general_neighbor_idx = lat.get_index((3, 3, 0)) + np.testing.assert_allclose( + dist_matrix[origin_idx, general_neighbor_idx], np.sqrt(13) + ) + + # --- This list and the following test are now at the correct indentation level --- + lattice_instances_for_invariant_test = [ + SquareLattice(size=(4, 4), pbc=True), + SquareLattice(size=(4, 3), pbc=(True, False)), # Mixed BC, non-square + HoneycombLattice(size=(3, 3), pbc=True), + TriangularLattice(size=(4, 4), pbc=False), + CustomizeLattice( + dimensionality=2, + identifiers=list(range(4)), + coordinates=[[0, 0], [1, 1], [0, 1], [1, 0]], + ), + ] + + @pytest.mark.parametrize("lattice", lattice_instances_for_invariant_test) + def test_distance_matrix_invariants_for_all_lattice_types(self, lattice): + """ + Tests that the distance matrix for any lattice type adheres to + fundamental mathematical properties (invariants): symmetry, zero diagonal, + and positive off-diagonal elements. + """ + # Arrange + n = lattice.num_sites + if n < 2: + pytest.skip("Invariant test requires at least 2 sites.") + + # Act + # We call the property directly, without building neighbors first, + # to test the on-demand computation path. + matrix = lattice.distance_matrix + + # Assert + # 1. Symmetry: The matrix must be equal to its transpose. + np.testing.assert_allclose( + matrix, + matrix.T, + err_msg=f"Distance matrix for {type(lattice).__name__} is not symmetric.", + ) + + # 2. Zero Diagonal: All diagonal elements must be zero. + np.testing.assert_allclose( + np.diag(matrix), + np.zeros(n), + err_msg=f"Diagonal of distance matrix for {type(lattice).__name__} is not zero.", + ) + + # 3. Positive Off-diagonal: All non-diagonal elements must be > 0. + # We create a boolean mask for the off-diagonal elements. + off_diagonal_mask = ~np.eye(n, dtype=bool) + assert np.all( + matrix[off_diagonal_mask] > 1e-9 + ), f"Found non-positive off-diagonal elements in distance matrix for {type(lattice).__name__}." + + +@pytest.mark.slow +class TestPerformance: + def test_pbc_implementation_is_not_significantly_slower_than_obc(self): + """ + A performance regression test. + It ensures that the specialized implementation for fully periodic + lattices (pbc=True) is not substantially slower than the general + implementation used for open boundaries (pbc=False). + This test will FAIL with the current code, exposing the performance bug. + """ + # Arrange: Use a large-enough lattice to make performance differences apparent + size = (30, 30) + k = 1 + + # Act 1: Measure the execution time of the general (OBC) implementation + start_time_obc = time.time() + _ = SquareLattice(size=size, pbc=False, precompute_neighbors=k) + duration_obc = time.time() - start_time_obc + + # Act 2: Measure the execution time of the specialized (PBC) implementation + start_time_pbc = time.time() + _ = SquareLattice(size=size, pbc=True, precompute_neighbors=k) + duration_pbc = time.time() - start_time_pbc + + print( + f"\n[Performance] OBC ({size}): {duration_obc:.4f}s | PBC ({size}): {duration_pbc:.4f}s" + ) + + # Assert: The PBC implementation should not be drastically slower. + # We allow it to be up to 3 times slower to account for minor overheads, + # but this will catch the current 10x+ regression. + # THIS ASSERTION WILL FAIL with the current buggy code. + assert duration_pbc < duration_obc * 5, ( + "The specialized PBC implementation is significantly slower " + "than the general-purpose implementation." + )