From 9e01be8e94f9381a21e033fefd8edbe69666b396 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Sun, 3 Aug 2025 16:40:01 +0800 Subject: [PATCH 01/16] feat(lattice): Make lattice geometries differentiable and backend-agnostic --- examples/lennard_jones_optimization.py | 116 +++ tensorcircuit/backends/abstract_backend.py | 115 +++ tensorcircuit/backends/jax_backend.py | 35 +- tensorcircuit/backends/numpy_backend.py | 33 +- tensorcircuit/backends/pytorch_backend.py | 35 + tensorcircuit/backends/tensorflow_backend.py | 33 + tensorcircuit/templates/hamiltonians.py | 19 +- tensorcircuit/templates/lattice.py | 871 +++++++++++------- tests/test_hamiltonians.py | 32 + tests/test_lattice.py | 920 +++++++++++++++++-- 10 files changed, 1828 insertions(+), 381 deletions(-) create mode 100644 examples/lennard_jones_optimization.py diff --git a/examples/lennard_jones_optimization.py b/examples/lennard_jones_optimization.py new file mode 100644 index 00000000..400d1420 --- /dev/null +++ b/examples/lennard_jones_optimization.py @@ -0,0 +1,116 @@ +import optax +import numpy as np +import jax.numpy as jnp +import matplotlib.pyplot as plt +import jax +import tensorcircuit as tc + + +jax.config.update("jax_enable_x64", True) +K = tc.set_backend("jax") + + +def calculate_potential(log_a, base_distance_matrix, epsilon=0.5, sigma=1.0): + """ + Calculate the total Lennard-Jones potential energy for a given logarithm of the lattice constant (log_a). + """ + lattice_constant = jnp.exp(log_a) + d = base_distance_matrix * lattice_constant + d_safe = jnp.where(d > 1e-9, d, 1e-9) + + term12 = (sigma / d_safe) ** 12 + term6 = (sigma / d_safe) ** 6 + potential_matrix = 4 * epsilon * (term12 - term6) + + num_sites = d.shape[0] + potential_matrix = potential_matrix * ( + 1 - K.eye(num_sites, dtype=potential_matrix.dtype) + ) + + potential_energy = K.sum(potential_matrix) / 2.0 + + return potential_energy + + +# Pre-calculate the base distance matrix (for lattice_constant=1.0) +size = (10, 10) +lat_base = tc.templates.lattice.SquareLattice(size, lattice_constant=1.0, pbc=True) +base_distance_matrix = lat_base.distance_matrix + +# Create a lambda function to pass the base distance matrix to the potential function +potential_fun_for_grad = lambda log_a: calculate_potential(log_a, base_distance_matrix) +value_and_grad_fun = K.jit(K.value_and_grad(potential_fun_for_grad)) + +optimizer = optax.adam(learning_rate=0.01) + +log_a = K.convert_to_tensor(jnp.log(1.1)) + +opt_state = optimizer.init(log_a) + +history = {"a": [], "energy": []} + +print("Starting optimization of lattice constant...") +for i in range(200): + energy, grad = value_and_grad_fun(log_a) + + history["a"].append(jnp.exp(log_a)) + history["energy"].append(energy) + + if jnp.isnan(grad): + print(f"Gradient became NaN at iteration {i+1}. Stopping optimization.") + print(f"Current energy: {energy}, Current log_a: {log_a}") + break + + updates, opt_state = optimizer.update(grad, opt_state) + log_a = optax.apply_updates(log_a, updates) + + if (i + 1) % 20 == 0: + current_a = jnp.exp(log_a) + print( + f"Iteration {i+1}/200: Total Energy = {energy:.4f}, Lattice Constant = {current_a:.4f}" + ) + +final_a = jnp.exp(log_a) +final_energy = calculate_potential(K.convert_to_tensor(log_a), base_distance_matrix) + +if not jnp.isnan(final_energy): + print("\nOptimization finished!") + print(f"Final optimized lattice constant: {final_a:.6f}") + print(f"Corresponding minimum total energy: {final_energy:.6f}") + + # Vectorized calculation for the potential curve + a_vals = np.linspace(0.8, 1.5, 200) + log_a_vals = np.log(a_vals) + + # Use vmap to create a vectorized version of the potential function + vmap_potential = jax.vmap(lambda la: calculate_potential(la, base_distance_matrix)) + potential_curve = vmap_potential(K.convert_to_tensor(log_a_vals)) + + plt.figure(figsize=(10, 6)) + plt.plot(a_vals, potential_curve, label="Lennard-Jones Potential", color="blue") + plt.scatter( + history["a"], + history["energy"], + color="red", + s=20, + zorder=5, + label="Optimization Steps", + ) + plt.scatter( + final_a, + final_energy, + color="green", + s=100, + zorder=6, + marker="*", + label="Final Optimized Point", + ) + + plt.title("Lennard-Jones Potential Optimization") + plt.xlabel("Lattice Constant (a)") + plt.ylabel("Total Potential Energy") + plt.legend() + plt.grid(True) + plt.show() +else: + print("\nOptimization failed. Final energy is NaN.") diff --git a/tensorcircuit/backends/abstract_backend.py b/tensorcircuit/backends/abstract_backend.py index adbad83f..9e6de0fe 100644 --- a/tensorcircuit/backends/abstract_backend.py +++ b/tensorcircuit/backends/abstract_backend.py @@ -596,6 +596,84 @@ def argsort(self: Any, a: Tensor, axis: int = -1) -> Tensor: "Backend '{}' has not implemented `argsort`.".format(self.name) ) + def sort(self: Any, a: Tensor, axis: int = -1) -> Tensor: + """ + Sort a tensor along the given axis. + + :param a: [description] + :type a: Tensor + :param axis: [description], defaults to -1 + :type axis: int, optional + :return: [description] + :rtype: Tensor + """ + raise NotImplementedError( + "Backend '{}' has not implemented `sort`.".format(self.name) + ) + + def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: + """ + Test whether all array elements along a given axis evaluate to True. + + :param a: Input tensor + :type a: Tensor + :param axis: Axis or axes along which a logical AND reduction is performed, + defaults to None + :type axis: Optional[Sequence[int]], optional + :return: A new boolean or tensor resulting from the AND reduction + :rtype: Tensor + """ + raise NotImplementedError( + "Backend '{}' has not implemented `all`.".format(self.name) + ) + + def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any: + """ + Return coordinate matrices from coordinate vectors. + + :param args: coordinate vectors + :type args: Any + :param kwargs: keyword arguments for meshgrid + :type kwargs: Any + :return: list of coordinate matrices + :rtype: Any + """ + raise NotImplementedError( + "Backend '{}' has not implemented `meshgrid`.".format(self.name) + ) + + def expand_dims(self: Any, a: Tensor, axis: int) -> Tensor: + """ + Expand the shape of a tensor. + Insert a new axis that will appear at the `axis` position in the expanded + tensor shape. + + :param a: Input tensor + :type a: Tensor + :param axis: Position in the expanded axes where the new axis is placed + :type axis: int + :return: Output tensor with the number of dimensions increased by one. + :rtype: Tensor + """ + raise NotImplementedError( + "Backend '{}' has not implemented `expand_dims`.".format(self.name) + ) + + def power(self: Any, a: Tensor, b: Union[Tensor, float]) -> Tensor: + """ + First array elements raised to powers from second array, element-wise. + + :param a: The bases + :type a: Tensor + :param b: The exponents + :type b: Union[Tensor, float] + :return: The bases in `a` raised to the powers in `b`. + :rtype: Tensor + """ + raise NotImplementedError( + "Backend '{}' has not implemented `power`.".format(self.name) + ) + def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: """ Find the unique elements and their corresponding counts of the given tensor ``a``. @@ -1404,6 +1482,43 @@ def cond( "Backend '{}' has not implemented `cond`.".format(self.name) ) + def where( + self: Any, + condition: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + ) -> Tensor: + """ + Return a tensor of elements selected from either x or y, depending on condition. + + :param condition: Where True, yield x, otherwise yield y. + :type condition: Tensor (bool) + :param x: Values from which to choose when condition is True. + :type x: Tensor + :param y: Values from which to choose when condition is False. + :type y: Tensor + :return: A tensor with elements from x where condition is True, and y otherwise. + :rtype: Tensor + """ + raise NotImplementedError( + "Backend '{}' has not implemented `where`.".format(self.name) + ) + + def equal(self: Any, x1: Tensor, x2: Tensor) -> Tensor: + """ + Return the truth value of (x1 == x2) element-wise. + + :param x1: Input tensor. + :type x1: Tensor + :param x2: Input tensor. + :type x2: Tensor + :return: Output tensor, element-wise truth value of (x1 == x2). + :rtype: Tensor (bool) + """ + raise NotImplementedError( + "Backend '{}' has not implemented `equal`.".format(self.name) + ) + def switch( self: Any, index: Tensor, branches: Sequence[Callable[[], Tensor]] ) -> Tensor: diff --git a/tensorcircuit/backends/jax_backend.py b/tensorcircuit/backends/jax_backend.py index a9d17b96..eeaf1d63 100644 --- a/tensorcircuit/backends/jax_backend.py +++ b/tensorcircuit/backends/jax_backend.py @@ -243,8 +243,10 @@ def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor: def copy(self, tensor: Tensor) -> Tensor: return jnp.array(tensor, copy=True) - def convert_to_tensor(self, tensor: Tensor) -> Tensor: + def convert_to_tensor(self, tensor: Tensor, **kwargs: Any) -> Tensor: result = jnp.asarray(tensor) + if "dtype" in kwargs and kwargs["dtype"] is not None: + result = self.cast(result, kwargs["dtype"]) return result def abs(self, a: Tensor) -> Tensor: @@ -353,6 +355,9 @@ def expm(self, a: Tensor) -> Tensor: # currently expm in jax doesn't support AD, it will raise an AssertError, # see https://github.com/google/jax/issues/2645 + def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: + return jnp.power(a, b) + def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return jnp.stack(a, axis=axis) @@ -390,6 +395,9 @@ def argmin(self, a: Tensor, axis: int = 0) -> Tensor: def argsort(self, a: Tensor, axis: int = -1) -> Tensor: return jnp.argsort(a, axis=axis) + def sort(self, a: Tensor, axis: int = -1) -> Tensor: + return jnp.sort(a, axis=axis) + def unique_with_counts( # type: ignore self, a: Tensor, *, size: Optional[int] = None, fill_value: Optional[int] = None ) -> Tuple[Tensor, Tensor]: @@ -410,6 +418,12 @@ def onehot(self, a: Tensor, num: int) -> Tensor: def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor: return jnp.cumsum(a, axis) + def all(self, a: Tensor, axis: Optional[int] = None) -> Tensor: + return jnp.all(a, axis=axis) + + def equal(self, x1: Tensor, x2: Tensor) -> Tensor: + return jnp.equal(x1, x2) + def is_tensor(self, a: Any) -> bool: if not isinstance(a, jnp.ndarray): return False @@ -812,4 +826,23 @@ def wrapper( vvag = vectorized_value_and_grad + def meshgrid(self, *args: Any, **kwargs: Any) -> Any: + """ + Backend-agnostic meshgrid function. + """ + return jnp.meshgrid(*args, **kwargs) + optimizer = optax_optimizer + + def expand_dims(self, a: Tensor, axis: int) -> Tensor: + return jnp.expand_dims(a, axis) + + def where( + self, + condition: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + ) -> Tensor: + if x is None and y is None: + return jnp.where(condition) + return jnp.where(condition, x, y) diff --git a/tensorcircuit/backends/numpy_backend.py b/tensorcircuit/backends/numpy_backend.py index 633a8467..c0b35a1f 100644 --- a/tensorcircuit/backends/numpy_backend.py +++ b/tensorcircuit/backends/numpy_backend.py @@ -35,7 +35,7 @@ def _sum_numpy( # see https://github.com/google/TensorNetwork/issues/952 -def _convert_to_tensor_numpy(self: Any, a: Tensor) -> Tensor: +def _convert_to_tensor_numpy(self: Any, a: Tensor, **kwargs: Any) -> Tensor: if not isinstance(a, np.ndarray) and not np.isscalar(a): a = np.array(a) a = np.asarray(a) @@ -80,6 +80,9 @@ def copy(self, a: Tensor) -> Tensor: def expm(self, a: Tensor) -> Tensor: return expm(a) + def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: + return np.power(a, b) + def abs(self, a: Tensor) -> Tensor: return np.abs(a) @@ -132,6 +135,12 @@ def eigvalsh(self, a: Tensor) -> Tensor: def kron(self, a: Tensor, b: Tensor) -> Tensor: return np.kron(a, b) + def meshgrid(self, *args: Any, **kwargs: Any) -> Any: + """ + Backend-agnostic meshgrid function. + """ + return np.meshgrid(*args, **kwargs) + def dtype(self, a: Tensor) -> str: return a.dtype.__str__() # type: ignore @@ -151,6 +160,9 @@ def i(self, dtype: Any = None) -> Tensor: dtype = getattr(np, dtype) return np.array(1j, dtype=dtype) + def expand_dims(self, a: Tensor, axis: int) -> Tensor: + return np.expand_dims(a, axis) + def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return np.stack(a, axis=axis) @@ -173,6 +185,9 @@ def std( ) -> Tensor: return np.std(a, axis=axis, keepdims=keepdims) + def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: + return np.all(a, axis=axis) + def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: return np.unique(a, return_counts=True) # type: ignore @@ -188,6 +203,9 @@ def argmax(self, a: Tensor, axis: int = 0) -> Tensor: def argmin(self, a: Tensor, axis: int = 0) -> Tensor: return np.argmin(a, axis=axis) + def sort(self, a: Tensor, axis: int = -1) -> Tensor: + return np.sort(a, axis=axis) + def sigmoid(self, a: Tensor) -> Tensor: return expit(a) @@ -345,6 +363,19 @@ def to_dense(self, sp_a: Tensor) -> Tensor: def is_sparse(self, a: Tensor) -> bool: return issparse(a) # type: ignore + def where( + self, + condition: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + ) -> Tensor: + if x is None and y is None: + return np.where(condition) + return np.where(condition, x, y) + + def equal(self, x: Tensor, y: Tensor) -> Tensor: + return np.equal(x, y) + def cond( self, pred: bool, diff --git a/tensorcircuit/backends/pytorch_backend.py b/tensorcircuit/backends/pytorch_backend.py index 83176308..73b8c4d1 100644 --- a/tensorcircuit/backends/pytorch_backend.py +++ b/tensorcircuit/backends/pytorch_backend.py @@ -244,6 +244,9 @@ def expm(self, a: Tensor) -> Tensor: # it doesn't support complex numbers which is more severe issue. # see https://github.com/pytorch/pytorch/issues/9983 + def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: + return torchlib.pow(a, b) + def sin(self, a: Tensor) -> Tensor: return torchlib.sin(a) @@ -369,6 +372,17 @@ def argmax(self, a: Tensor, axis: int = 0) -> Tensor: def argmin(self, a: Tensor, axis: int = 0) -> Tensor: return torchlib.argmin(a, dim=axis) + def sort(self, a: Tensor, axis: int = -1) -> Tensor: + return torchlib.sort(a, dim=axis).values + + def all(self, tensor: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: + """ + Corresponds to torch.all. + """ + if axis is None: + return torchlib.all(tensor) + return torchlib.all(tensor, dim=axis) + def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: return torchlib.unique(a, return_counts=True) # type: ignore @@ -425,6 +439,21 @@ def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor: v = self.convert_to_tensor(v) return torchlib.searchsorted(a, v, side=side) + def where( + self, + condition: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + ) -> Tensor: + if x is None and y is None: + return torchlib.where(condition) + return torchlib.where(condition, x, y) + + def equal(self, x1: Tensor, x2: Any) -> Tensor: + if not self.is_tensor(x2): + x2 = torchlib.tensor(x2, device=x1.device, dtype=x1.dtype) + return torchlib.eq(x1, x2) + def reverse(self, a: Tensor) -> Tensor: return torchlib.flip(a, dims=(-1,)) @@ -706,6 +735,12 @@ def wrapper( return wrapper + def expand_dims(self, a: Tensor, axis: int) -> Tensor: + return torchlib.unsqueeze(a, dim=axis) + vvag = vectorized_value_and_grad + def meshgrid(self, *args: Any, **kws: Any) -> Tensor: + return torchlib.meshgrid(*args, **kws) + optimizer = torch_optimizer diff --git a/tensorcircuit/backends/tensorflow_backend.py b/tensorcircuit/backends/tensorflow_backend.py index beaf8b5f..29508d16 100644 --- a/tensorcircuit/backends/tensorflow_backend.py +++ b/tensorcircuit/backends/tensorflow_backend.py @@ -75,6 +75,7 @@ def update(self, grads: pytree, params: pytree) -> pytree: def _tensordot_tf( self: Any, a: Tensor, b: Tensor, axes: Union[int, Sequence[Sequence[int]]] ) -> Tensor: + b = tf.cast(b, a.dtype) return tf.tensordot(a, b, axes) @@ -444,6 +445,9 @@ def copy(self, a: Tensor) -> Tensor: def expm(self, a: Tensor) -> Tensor: return tf.linalg.expm(a) + def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: + return tf.math.pow(a, b) + def sin(self, a: Tensor) -> Tensor: return tf.math.sin(a) @@ -524,6 +528,23 @@ def min(self, a: Tensor, axis: Optional[int] = None) -> Tensor: def max(self, a: Tensor, axis: Optional[int] = None) -> Tensor: return tf.reduce_max(a, axis=axis) + def all(self, a: Tensor, axis: Optional[int] = None) -> Tensor: + return tf.reduce_all(tf.cast(a, tf.bool), axis=axis) + + def where( + self, + condition: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + ) -> Tensor: + if x is None and y is None: + # Return a tuple of tensors to be consistent with other backends + return tuple(tf.unstack(tf.where(condition), axis=1)) + return tf.where(condition, x, y) + + def equal(self, x1: Tensor, x2: Tensor) -> Tensor: + return tf.math.equal(x1, x2) + def argmax(self, a: Tensor, axis: int = 0) -> Tensor: return tf.math.argmax(a, axis=axis) @@ -533,6 +554,9 @@ def argmin(self, a: Tensor, axis: int = 0) -> Tensor: def argsort(self, a: Tensor, axis: int = -1) -> Tensor: return tf.argsort(a, axis=axis) + def sort(self, a: Tensor, axis: int = -1) -> Tensor: + return tf.sort(a, axis=axis) + def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: r = tf.unique_with_counts(a) order = tf.argsort(r.y) @@ -1058,4 +1082,13 @@ def wrapper( vvag = vectorized_value_and_grad + def meshgrid(self, *args: Any, **kwargs: Any) -> Any: + """ + Backend-agnostic meshgrid function. + """ + return tf.meshgrid(*args, **kwargs) + optimizer = keras_optimizer + + def expand_dims(self, a: Tensor, axis: int) -> Tensor: + return tf.expand_dims(a, axis) diff --git a/tensorcircuit/templates/hamiltonians.py b/tensorcircuit/templates/hamiltonians.py index 0382b11d..b7e0cc64 100644 --- a/tensorcircuit/templates/hamiltonians.py +++ b/tensorcircuit/templates/hamiltonians.py @@ -17,13 +17,14 @@ def _create_empty_sparse_matrix(shape: Tuple[int, int]) -> Any: def heisenberg_hamiltonian( lattice: AbstractLattice, j_coupling: Union[float, List[float], Tuple[float, ...]] = 1.0, + interaction_scope: str = "neighbors", ) -> Any: """ Generates the sparse matrix of the Heisenberg Hamiltonian for a given lattice. The Heisenberg Hamiltonian is defined as: - H = J * Σ_{} (X_i X_j + Y_i Y_j + Z_i Z_j) - where the sum is over all unique nearest-neighbor pairs . + H = J * Σ_{i,j} (X_i X_j + Y_i Y_j + Z_i Z_j) + where the sum is over a specified set of interacting pairs {i,j}. :param lattice: An instance of a class derived from AbstractLattice, which provides the geometric information of the system. @@ -32,11 +33,23 @@ def heisenberg_hamiltonian( isotropic model (Jx=Jy=Jz) or a list/tuple of 3 floats for an anisotropic model (Jx, Jy, Jz). Defaults to 1.0. :type j_coupling: Union[float, List[float], Tuple[float, ...]], optional + :param interaction_scope: Defines the range of interactions. + - "neighbors": Includes only nearest-neighbor pairs (default). + - "all": Includes all unique pairs of sites. + :type interaction_scope: str, optional :return: The Hamiltonian as a backend-agnostic sparse matrix. :rtype: Any """ num_sites = lattice.num_sites - neighbor_pairs = lattice.get_neighbor_pairs(k=1, unique=True) + if interaction_scope == "neighbors": + neighbor_pairs = lattice.get_neighbor_pairs(k=1, unique=True) + elif interaction_scope == "all": + neighbor_pairs = lattice.get_all_pairs() + else: + raise ValueError( + f"Invalid interaction_scope: '{interaction_scope}'. " + "Must be 'neighbors' or 'all'." + ) if isinstance(j_coupling, (float, int)): js = [float(j_coupling)] * 3 diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py index 52f152c9..e02c5342 100644 --- a/tensorcircuit/templates/lattice.py +++ b/tensorcircuit/templates/lattice.py @@ -19,10 +19,9 @@ ) logger = logging.getLogger(__name__) +import itertools import numpy as np - -from scipy.spatial import KDTree -from scipy.spatial.distance import pdist, squareform +from .. import backend # This block resolves a name resolution issue for the static type checker (mypy). @@ -41,9 +40,11 @@ import matplotlib.axes from mpl_toolkits.mplot3d import Axes3D +Tensor = Any SiteIndex = int SiteIdentifier = Hashable -Coordinates = np.ndarray[Any, Any] +Coordinates = Tensor + NeighborMap = Dict[SiteIndex, List[SiteIndex]] @@ -62,15 +63,32 @@ class AbstractLattice(abc.ABC): def __init__(self, dimensionality: int): """Initializes the base lattice class.""" + logger.debug( + f"[DEBUG-LATTICE] Initializing AbstractLattice with dimensionality: {dimensionality}" + ) self._dimensionality = dimensionality - # --- Internal Data Structures (to be populated by subclasses) --- - self._indices: List[SiteIndex] = [] - self._identifiers: List[SiteIdentifier] = [] - self._coordinates: List[Coordinates] = [] - self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = {} - self._neighbor_maps: Dict[int, NeighborMap] = {} - self._distance_matrix: Optional[Coordinates] = None + # Core data structures for storing site information. + self._indices: List[SiteIndex] = [] # List of integer indices [0, 1, ..., N-1] + self._identifiers: List[SiteIdentifier] = ( + [] + ) # List of unique, hashable site identifiers + self._coordinates: Optional[Coordinates] = ( + None # N x D array of site coordinates + ) + + # Mappings for efficient lookups. + self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = ( + {} + ) # Maps identifiers to indices + + # Cached properties, computed on demand. + self._neighbor_maps: Dict[int, NeighborMap] = ( + {} + ) # Caches neighbor info for different k + self._distance_matrix: Optional[Coordinates] = ( + None # Caches the full N x N distance matrix + ) @property def num_sites(self) -> int: @@ -95,12 +113,17 @@ def distance_matrix(self) -> Coordinates: subsequent calls. This computation can be expensive for large lattices. """ if self._distance_matrix is None: - logger.info("Distance matrix not cached. Computing now...") + logger.debug("Distance matrix not cached. Computing now...") + logger.debug("[DEBUG-LATTICE] Computing distance matrix...") self._distance_matrix = self._compute_distance_matrix() + logger.debug("[DEBUG-LATTICE] ...distance matrix computed.") return self._distance_matrix def _validate_index(self, index: SiteIndex) -> None: """A private helper to check if a site index is within the valid range.""" + logger.debug( + f"[DEBUG-LATTICE] Validating index: {index} against num_sites: {self.num_sites}" + ) if not (0 <= index < self.num_sites): raise IndexError( f"Site index {index} out of range (0-{self.num_sites - 1})" @@ -116,7 +139,10 @@ def get_coordinates(self, index: SiteIndex) -> Coordinates: :rtype: Coordinates """ self._validate_index(index) - return self._coordinates[index] + assert self._coordinates is not None + coords = self._coordinates[index] + logger.debug(f"[DEBUG-LATTICE] get_coordinates for index {index}: {coords}") + return coords def get_identifier(self, index: SiteIndex) -> SiteIdentifier: """Gets the abstract identifier of a site by its integer index. @@ -140,8 +166,14 @@ def get_index(self, identifier: SiteIdentifier) -> SiteIndex: :rtype: SiteIndex """ try: - return self._ident_to_idx[identifier] + logger.debug(f"[DEBUG-LATTICE] Getting index for identifier: {identifier}") + index = self._ident_to_idx[identifier] + logger.debug(f"[DEBUG-LATTICE] Found index: {index}") + return index except KeyError as e: + logger.debug( + f"[DEBUG-LATTICE] Identifier {identifier} not found in _ident_to_idx map." + ) raise ValueError( f"Identifier {identifier} not found in the lattice." ) from e @@ -166,13 +198,26 @@ def get_site_info( - The site's coordinates as a NumPy array. :rtype: Tuple[SiteIndex, SiteIdentifier, Coordinates] """ + logger.debug( + f"[DEBUG-LATTICE] get_site_info called with: {index_or_identifier} (type: {type(index_or_identifier)})" + ) + assert self._coordinates is not None if isinstance(index_or_identifier, int): # SiteIndex is an int idx = index_or_identifier self._validate_index(idx) + logger.debug( + f"[DEBUG-LATTICE] Identified as SiteIndex. Returning info for index {idx}." + ) return idx, self._identifiers[idx], self._coordinates[idx] - else: # Identifier + else: ident = index_or_identifier + logger.debug( + f"[DEBUG-LATTICE] Identified as SiteIdentifier. Looking up index for {ident}." + ) idx = self.get_index(ident) + logger.debug( + f"[DEBUG-LATTICE] Returning info for identifier {ident} (index {idx})." + ) return idx, ident, self._coordinates[idx] def sites(self) -> Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]: @@ -185,7 +230,9 @@ def sites(self) -> Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]: index, identifier, and coordinates. :rtype: Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]] """ + logger.debug("[DEBUG-LATTICE] Creating sites iterator.") for i in range(self.num_sites): + assert self._coordinates is not None yield i, self._identifiers[i], self._coordinates[i] def get_neighbors(self, index: SiteIndex, k: int = 1) -> List[SiteIndex]: @@ -202,10 +249,14 @@ def get_neighbors(self, index: SiteIndex, k: int = 1) -> List[SiteIndex]: pre-calculated or if the site has no such neighbors. :rtype: List[SiteIndex] """ + logger.debug(f"[DEBUG-LATTICE] Getting neighbors for index {index}, k={k}") if k not in self._neighbor_maps: logger.info( f"Neighbors for k={k} not pre-computed. Building now up to max_k={k}." ) + logger.debug( + f"[DEBUG-LATTICE] Neighbor map for k={k} not found. Triggering _build_neighbors(max_k={k})." + ) self._build_neighbors(max_k=k) if k not in self._neighbor_maps: @@ -231,13 +282,18 @@ def get_neighbor_pairs( :rtype: List[Tuple[SiteIndex, SiteIndex]] """ + logger.debug( + f"[DEBUG-LATTICE] Getting neighbor pairs for k={k}, unique={unique}" + ) if k not in self._neighbor_maps: logger.info( f"Neighbor pairs for k={k} not pre-computed. Building now up to max_k={k}." ) + logger.debug( + f"[DEBUG-LATTICE] Neighbor map for k={k} not found. Triggering _build_neighbors(max_k={k})." + ) self._build_neighbors(max_k=k) - # After attempting to build, check again. If still not found, return empty. if k not in self._neighbor_maps: return [] @@ -251,8 +307,29 @@ def get_neighbor_pairs( pairs.append((i, j)) return sorted(pairs) - # Sorting provides a deterministic output order - # --- Abstract Methods for Subclass Implementation --- + def get_all_pairs(self) -> List[Tuple[SiteIndex, SiteIndex]]: + """ + Returns a list of all unique pairs of site indices (i, j) where i < j. + + This method provides all-to-all connectivity, useful for Hamiltonians + where every site interacts with every other site. + + Note on Differentiability: + This method provides a static list of index pairs and is not differentiable + itself. However, it is designed to be used in combination with the fully + differentiable ``distance_matrix`` property. By using the pairs from this + method to index into the ``distance_matrix``, one can construct differentiable + objective functions based on all-pair interactions, effectively bypassing the + non-differentiable ``get_neighbor_pairs`` method for geometry optimization tasks. + + :return: A list of tuples, where each tuple is a unique pair of site indices. + :rtype: List[Tuple[SiteIndex, SiteIndex]] + """ + logger.debug("[DEBUG-LATTICE] Getting all unique pairs of sites.") + if self.num_sites < 2: + return [] + # Use itertools.combinations to efficiently generate all unique pairs (i, j) with i < j. + return sorted(list(itertools.combinations(range(self.num_sites), 2))) @abc.abstractmethod def _build_lattice(self, *args: Any, **kwargs: Any) -> None: @@ -325,6 +402,12 @@ def show( :param kwargs: Additional keyword arguments to be passed directly to the `matplotlib.pyplot.scatter` function for customizing site appearance. """ + logger.debug( + ( + f"[DEBUG-LATTICE] show() called with: show_indices={show_indices}, " + f"show_identifiers={show_identifiers}, show_bonds_k={show_bonds_k}" + ) + ) try: import matplotlib.pyplot as plt except ImportError: @@ -334,7 +417,8 @@ def show( ) return - # creat "fig_created_internally" as flag + # Flag to track if the Matplotlib figure was created by this method. + # This prevents calling plt.show() if the user provided their own Axes. fig_created_internally = False if self.num_sites == 0: @@ -347,7 +431,8 @@ def show( return if ax is None: - # when ax is none, make fig_created_internally true + # If no Axes object is provided, create a new figure and axes. + logger.debug("[DEBUG-LATTICE] `ax` is None, creating new figure.") fig_created_internally = True if self.dimensionality == 3: fig = plt.figure(figsize=(8, 8)) @@ -355,9 +440,11 @@ def show( else: fig, ax = plt.subplots(figsize=(8, 8)) else: + logger.debug("[DEBUG-LATTICE] Using provided `ax` object.") fig = ax.figure # type: ignore coords = np.array(self._coordinates) + # Prepare arguments for the scatter plot, allowing user overrides. scatter_args = {"s": 100, "zorder": 2} scatter_args.update(kwargs) if self.dimensionality == 1: @@ -369,14 +456,14 @@ def show( ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], **scatter_args) # type: ignore if show_indices or show_identifiers: + logger.debug("[DEBUG-LATTICE] Drawing site labels (indices/identifiers).") for i in range(self.num_sites): label = str(self._identifiers[i]) if show_identifiers else str(i) + # Calculate a small offset for placing text labels to avoid overlap with sites. offset = ( 0.02 * np.max(np.ptp(coords, axis=0)) if coords.size > 0 else 0.1 ) - # Robust Logic: Decide plotting strategy based on known dimensionality. - if self.dimensionality == 1: ax.text(coords[i, 0], offset, label, fontsize=9, ha="center") elif self.dimensionality == 2: @@ -398,9 +485,8 @@ def show( zorder=3, ) - # Note: No 'else' needed as we already check dimensionality at the start. - if show_bonds_k is not None: + logger.debug(f"[DEBUG-LATTICE] Drawing bonds for k={show_bonds_k}.") if show_bonds_k not in self._neighbor_maps: logger.warning( f"Cannot draw bonds. k={show_bonds_k} neighbors have not been calculated." @@ -420,6 +506,7 @@ def show( if self.dimensionality > 2: ax_3d = cast("Axes3D", ax) for i, j in bonds: + assert self._coordinates is not None p1, p2 = self._coordinates[i], self._coordinates[j] ax_3d.plot( [p1[0], p2[0]], @@ -429,11 +516,12 @@ def show( ) else: for i, j in bonds: + assert self._coordinates is not None p1, p2 = self._coordinates[i], self._coordinates[j] if self.dimensionality == 1: # type: ignore ax.plot([p1[0], p2[0]], [0, 0], **plot_bond_kwargs) # type: ignore - else: # dimensionality == 2 + else: ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **plot_bond_kwargs) # type: ignore except ValueError as e: @@ -449,7 +537,7 @@ def show( ax.set_zlabel("z") ax.grid(True) - # 3. whether plt.show() + # Display the plot only if the figure was created within this function. if fig_created_internally: plt.show() @@ -475,39 +563,51 @@ def _identify_distance_shells( :return: A sorted list of squared distances representing the shells. :rtype: List[float] """ + logger.debug( + f"[DEBUG-LATTICE] Identifying up to {max_k} distance shells with tolerance {tol}." + ) + # A small threshold to filter out zero distances (site to itself). ZERO_THRESHOLD_SQ = 1e-12 - all_distances_sq = np.asarray(all_distances_sq) + all_distances_sq = backend.convert_to_tensor(all_distances_sq) # Now, the .size call below is guaranteed to be safe. - if all_distances_sq.size == 0: + if backend.sizen(all_distances_sq) == 0: + logger.debug( + "[DEBUG-LATTICE] No non-zero distances found, returning empty shells." + ) return [] - sorted_dist = np.sort(all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ]) + # Filter out self-distances and sort the remaining squared distances. + sorted_dist = backend.sort( + all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ] + ) - if sorted_dist.size == 0: + if backend.sizen(sorted_dist) == 0: + logger.debug( + "[DEBUG-LATTICE] Sorted distances are empty, returning empty shells." + ) return [] - # Identify shells using the user-provided tolerance. dist_shells = [sorted_dist[0]] for d_sq in sorted_dist[1:]: if len(dist_shells) >= max_k: break - # If the current distance is notably larger than the last shell's distance - if d_sq > dist_shells[-1] + tol**2: + if backend.sqrt(d_sq) - backend.sqrt(dist_shells[-1]) > tol: dist_shells.append(d_sq) + logger.debug( + f"[DEBUG-LATTICE] Identified distance shells (squared): {dist_shells}" + ) return dist_shells def _build_neighbors_by_distance_matrix( self, max_k: int = 2, tol: float = 1e-6 ) -> None: """A generic, distance-based neighbor finding method. - This method calculates the full N x N distance matrix to find neighbor shells. It is computationally expensive for large N (O(N^2)) and is best suited for non-periodic or custom-defined lattices. - :param max_k: The maximum number of neighbor shells to calculate. Defaults to 2. :type max_k: int, optional @@ -515,29 +615,61 @@ def _build_neighbors_by_distance_matrix( comparisons. Defaults to 1e-6. :type tol: float, optional """ + logger.debug( + f"[DEBUG-LATTICE] Building neighbors via distance matrix up to max_k={max_k}." + ) if self.num_sites < 2: return - all_coords = np.array(self._coordinates) - dist_matrix_sq = np.sum( - (all_coords[:, np.newaxis, :] - all_coords[np.newaxis, :, :]) ** 2, axis=-1 + all_coords = self._coordinates + # Vectorized computation of the squared distance matrix: + # (N, 1, D) - (1, N, D) -> (N, N, D) -> (N, N) + displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims( + all_coords, 0 ) + dist_matrix_sq = backend.sum(backend.power(displacements, 2), axis=-1) - all_distances_sq = dist_matrix_sq.flatten() + # Flatten the matrix to a list of all squared distances to identify shells. + all_distances_sq = backend.reshape(dist_matrix_sq, [-1]) dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol) - self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)} + self._neighbor_maps = self._build_neighbor_map_from_distances( + dist_matrix_sq, dist_shells_sq, tol + ) + self._distance_matrix = backend.sqrt(dist_matrix_sq) + + def _build_neighbor_map_from_distances( + self, + dist_matrix_sq: Coordinates, + dist_shells_sq: List[float], + tol: float = 1e-6, + ) -> Dict[int, NeighborMap]: + """ + Builds a neighbor map from a squared distance matrix and identified shells. + This is a generic helper function to reduce code duplication. + """ + neighbor_maps: Dict[int, NeighborMap] = { + k: {} for k in range(1, len(dist_shells_sq) + 1) + } for k_idx, target_d_sq in enumerate(dist_shells_sq): k = k_idx + 1 current_k_map: Dict[int, List[int]] = {} - for i in range(self.num_sites): - neighbor_indices = np.where( - np.isclose(dist_matrix_sq[i], target_d_sq, rtol=0, atol=tol**2) - )[0] - if len(neighbor_indices) > 0: - current_k_map[i] = sorted(neighbor_indices.tolist()) - self._neighbor_maps[k] = current_k_map - self._distance_matrix = np.sqrt(dist_matrix_sq) + # For each shell, find all pairs of sites (i, j) with that distance. + is_close_matrix = backend.abs(dist_matrix_sq - target_d_sq) < tol + rows, cols = backend.where(is_close_matrix) + + for i, j in zip(backend.numpy(rows), backend.numpy(cols)): + if i == j: + continue + if i not in current_k_map: + current_k_map[i] = [] + current_k_map[i].append(j) + + for i in current_k_map: + current_k_map[i].sort() + + neighbor_maps[k] = current_k_map + return neighbor_maps class TILattice(AbstractLattice): @@ -587,19 +719,22 @@ def __init__( precompute_neighbors: Optional[int] = None, ): """Initializes the Translationally Invariant Lattice.""" + logger.debug(f"[DEBUG-LATTICE] Initializing TILattice: {size}, pbc={pbc}") super().__init__(dimensionality) - assert lattice_vectors.shape == ( + + self.lattice_vectors = backend.convert_to_tensor(lattice_vectors) + self.basis_coords = backend.convert_to_tensor(basis_coords) + + assert self.lattice_vectors.shape == ( dimensionality, dimensionality, ), "Lattice vectors shape mismatch" assert ( - basis_coords.shape[1] == dimensionality + self.basis_coords.shape[1] == dimensionality ), "Basis coordinates dimension mismatch" assert len(size) == dimensionality, "Size tuple length mismatch" - self.lattice_vectors = lattice_vectors - self.basis_coords = basis_coords - self.num_basis = basis_coords.shape[0] + self.num_basis = self.basis_coords.shape[0] self.size = size if isinstance(pbc, bool): self.pbc = tuple([pbc] * dimensionality) @@ -607,131 +742,209 @@ def __init__( assert len(pbc) == dimensionality, "PBC tuple length mismatch" self.pbc = tuple(pbc) - # Build the lattice sites and their neighbor relationships self._build_lattice() if precompute_neighbors is not None and precompute_neighbors > 0: logger.info(f"Pre-computing neighbors up to k={precompute_neighbors}...") self._build_neighbors(max_k=precompute_neighbors) def _build_lattice(self) -> None: - """Generates all site information for the periodic lattice. - - This method iterates through each unit cell defined by `self.size`, - and for each unit cell, it iterates through all basis sites. It then - calculates the real-space coordinates and creates a unique identifier - for each site, populating the internal lattice data structures. """ - current_index = 0 + Generates all site information for the periodic lattice in a vectorized manner. + """ + logger.debug("[DEBUG-LATTICE] Starting _build_lattice for TILattice.") + ranges = [backend.arange(s) for s in self.size] + + # Generate a grid of all integer unit cell coordinates. + grid = backend.meshgrid(*ranges, indexing="ij") + all_cell_coords = backend.reshape( + backend.stack(grid, axis=-1), (-1, self.dimensionality) + ) - # Iterate over all unit cell coordinates elegantly using np.ndindex - for cell_coord in np.ndindex(self.size): - cell_coord_arr = np.array(cell_coord) - # R = n1*a1 + n2*a2 + ... - cell_vector = np.dot(cell_coord_arr, self.lattice_vectors) + all_cell_coords = backend.cast(all_cell_coords, self.lattice_vectors.dtype) - # Iterate over the basis sites within the unit cell - for basis_index in range(self.num_basis): - basis_vec = self.basis_coords[basis_index] + cell_vectors = backend.tensordot( + all_cell_coords, self.lattice_vectors, axes=[[1], [0]] + ) + + cell_vectors = backend.cast(cell_vectors, self.basis_coords.dtype) + + # Combine cell vectors with basis coordinates to get all site positions + # via broadcasting: (num_cells, 1, D) + (1, num_basis, D) -> (num_cells, num_basis, D) + all_coords = backend.expand_dims(cell_vectors, 1) + backend.expand_dims( + self.basis_coords, 0 + ) - # Calculate the real-space coordinate - coord = cell_vector + basis_vec - # Create a structured identifier - identifier = cell_coord + (basis_index,) + self._coordinates = backend.reshape(all_coords, (-1, self.dimensionality)) - # Store site information + self._indices = [] + self._identifiers = [] + self._ident_to_idx = {} + current_index = 0 + + # Generate integer indices and tuple-based identifiers for all sites. + # e.g., identifier = (uc_x, uc_y, basis_idx) + size_ranges = [range(s) for s in self.size] + for cell_coord_tuple in itertools.product(*size_ranges): + for basis_index in range(self.num_basis): + identifier = cell_coord_tuple + (basis_index,) self._indices.append(current_index) self._identifiers.append(identifier) - self._coordinates.append(coord) self._ident_to_idx[identifier] = current_index current_index += 1 + logger.debug( + f"[DEBUG-LATTICE] Finished _build_lattice. Total sites: {self.num_sites}" + ) + def _get_distance_matrix_with_mic(self) -> Coordinates: """ - Computes the full N x N distance matrix, correctly applying the - Minimum Image Convention (MIC) for all periodic dimensions. + Computes the full N x N distance matrix using backend operations, + correctly applying the Minimum Image Convention (MIC) for all + periodic dimensions in a memory-efficient manner. """ - all_coords = np.array(self._coordinates) - size_arr = np.array(self.size) - system_vectors = self.lattice_vectors * size_arr[:, np.newaxis] + logger.debug( + "[DEBUG-LATTICE] Computing distance matrix with Minimum Image Convention." + ) + + size_arr = backend.convert_to_tensor(self.size) + size_arr = backend.cast(size_arr, self.lattice_vectors.dtype) + + # Calculate the full system vectors that span the entire finite lattice. + system_vectors = self.lattice_vectors * backend.expand_dims(size_arr, axis=1) - # Generate translation vectors ONLY for periodic dimensions pbc_dims = [d for d in range(self.dimensionality) if self.pbc[d]] - translations = [np.zeros(self.dimensionality)] - if pbc_dims: - num_pbc_dims = len(pbc_dims) - pbc_system_vectors = system_vectors[pbc_dims, :] - # Create all 3^k - 1 non-zero shifts for k periodic dimensions - shift_options = [np.array([-1, 0, 1])] * num_pbc_dims - shifts_grid = np.meshgrid(*shift_options, indexing="ij") - all_shifts = np.stack(shifts_grid, axis=-1).reshape(-1, num_pbc_dims) - all_shifts = all_shifts[np.any(all_shifts != 0, axis=1)] + if not pbc_dims: + # If no PBC, the only 'translation' is the zero vector. + translations_arr = backend.zeros( + [1, self.dimensionality], dtype=self.lattice_vectors.dtype + ) + else: + logger.debug( + f"[DEBUG-LATTICE] Applying MIC for periodic dimensions: {pbc_dims}" + ) + num_pbc_dims = len(pbc_dims) + pbc_system_vectors = backend.gather1d( + system_vectors, backend.convert_to_tensor(pbc_dims) + ) - pbc_translations = all_shifts @ pbc_system_vectors - translations.extend(pbc_translations) + # Generate all 3^d possible image shifts (-1, 0, 1) for periodic dimensions. + shift_options = [backend.convert_to_tensor([-1.0, 0.0, 1.0])] * num_pbc_dims + shifts_grid = backend.meshgrid(*shift_options, indexing="ij") + all_shifts = backend.reshape( + backend.stack(shifts_grid, axis=-1), (-1, num_pbc_dims) + ) - translations_arr = np.array(translations, dtype=float) + translations_arr = backend.tensordot( + all_shifts, pbc_system_vectors, axes=[[1], [0]] + ) - # Calculate the distance matrix applying MIC - dist_matrix_sq = np.full((self.num_sites, self.num_sites), np.inf, dtype=float) + dist_sq_rows = [] + # Iterate through each site `i` to compute its distance to all other sites `j`. + # This is done row-by-row to manage memory for very large lattices. + assert self._coordinates is not None for i in range(self.num_sites): - displacements = all_coords - all_coords[i] - image_displacements = ( - displacements[:, np.newaxis, :] - translations_arr[np.newaxis, :, :] - ) - image_d_sq = np.sum(image_displacements**2, axis=2) - dist_matrix_sq[i, :] = np.min(image_d_sq, axis=1) + # For each site `i`, calculate displacements to all other sites `j`. + displacements_i = self._coordinates - self._coordinates[i] # Shape: (N, D) + # Then, for each displacement `d_ij`, find the minimum distance among + # `d_ij` and all its periodic images. + image_displacements_i = backend.expand_dims( + displacements_i, 1 + ) - backend.expand_dims(translations_arr, 0) + image_d_sq_i = backend.sum(image_displacements_i**2, axis=2) + min_dist_sq_i = backend.min(image_d_sq_i, axis=1) + dist_sq_rows.append(min_dist_sq_i) + + dist_matrix_sq = backend.stack(dist_sq_rows, axis=0) + safe_dist_matrix_sq = backend.where(dist_matrix_sq > 0, dist_matrix_sq, 0.0) + return backend.sqrt(safe_dist_matrix_sq) + + def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates: + """ + Computes the full N x N distance matrix using a fully vectorized approach + to be compatible with JIT compilation (e.g., JAX). + """ + logger.debug("[DEBUG-LATTICE] Computing distance matrix with MIC (vectorized).") + size_arr = backend.cast( + backend.convert_to_tensor(self.size), self.lattice_vectors.dtype + ) + system_vectors = self.lattice_vectors * backend.expand_dims(size_arr, axis=1) + + pbc_mask = backend.convert_to_tensor(self.pbc) + + # Generate all 3^d possible image shifts (-1, 0, 1) for all dimensions + shift_options = [ + backend.convert_to_tensor([-1.0, 0.0, 1.0]) + ] * self.dimensionality + shifts_grid = backend.meshgrid(*shift_options, indexing="ij") + all_shifts = backend.reshape( + backend.stack(shifts_grid, axis=-1), (-1, self.dimensionality) + ) + + # Only apply shifts to periodic dimensions + masked_shifts = all_shifts * backend.cast(pbc_mask, all_shifts.dtype) + + # Calculate all translation vectors due to PBC + translations_arr = backend.tensordot( + masked_shifts, system_vectors, axes=[[1], [0]] + ) + + # Vectorized computation of all displacements between any two sites + # Shape: (N, 1, D) - (1, N, D) -> (N, N, D) + displacements = backend.expand_dims(self._coordinates, 1) - backend.expand_dims( + self._coordinates, 0 + ) + + # Consider all periodic images for each displacement + # Shape: (N, N, 1, D) - (1, 1, num_translations, D) -> (N, N, num_translations, D) + image_displacements = backend.expand_dims( + displacements, 2 + ) - backend.expand_dims(backend.expand_dims(translations_arr, 0), 0) - return cast(Coordinates, np.sqrt(dist_matrix_sq)) + # Sum of squares for distances + image_d_sq = backend.sum(backend.power(image_displacements, 2), axis=3) + + # Find the minimum distance among all images (Minimum Image Convention) + min_dist_sq = backend.min(image_d_sq, axis=2) + + safe_dist_matrix_sq = backend.where(min_dist_sq > 0, min_dist_sq, 0.0) + return backend.sqrt(safe_dist_matrix_sq) def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None: """Calculates neighbor relationships for the periodic lattice. - This method calculates neighbor relationships by computing the full N x N - distance matrix. It robustly handles all boundary conditions (fully - periodic, open, or mixed) by applying the Minimum Image Convention - (MIC) only to the periodic dimensions. - - From this distance matrix, it identifies unique neighbor shells up to - the specified `max_k` and populates the neighbor maps. The computed - distance matrix is then cached for future use. + This method computes neighbor information by first calculating the full + distance matrix using the Minimum Image Convention (MIC) to correctly + handle periodic boundary conditions. It then identifies unique distance + shells (e.g., nearest, next-nearest) and populates the neighbor maps + accordingly. This approach is general and works for any periodic lattice + geometry defined by the TILattice class. - :param max_k: The maximum number of neighbor shells to - calculate. Defaults to 2. + :param max_k: The maximum order of neighbors to compute (e.g., k=1 for + nearest neighbors, k=2 for next-nearest, etc.). Defaults to 2. :type max_k: int, optional - :param tol: The numerical tolerance for distance - comparisons. Defaults to 1e-6. - :type tol: float, optional + :param \**kwargs: Additional keyword arguments. May include: + - ``tol`` (float): The numerical tolerance used to determine if two + distances are equal when identifying shells. Defaults to 1e-6. """ + logger.debug( + f"[DEBUG-LATTICE] Building neighbors for TILattice up to max_k={max_k}." + ) tol = kwargs.get("tol", 1e-6) - dist_matrix = self._get_distance_matrix_with_mic() + dist_matrix = self._get_distance_matrix_with_mic_vectorized() dist_matrix_sq = dist_matrix**2 self._distance_matrix = dist_matrix - all_distances_sq = dist_matrix_sq.flatten() + all_distances_sq = backend.reshape(dist_matrix_sq, [-1]) dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol) - - self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)} - for k_idx, target_d_sq in enumerate(dist_shells_sq): - k = k_idx + 1 - current_k_map: Dict[int, List[int]] = {} - match_indices = np.where( - np.isclose(dist_matrix_sq, target_d_sq, rtol=0, atol=tol**2) - ) - for i, j in zip(*match_indices): - if i == j: - continue - if i not in current_k_map: - current_k_map[i] = [] - current_k_map[i].append(j) - - for i in current_k_map: - current_k_map[i].sort() - - self._neighbor_maps[k] = current_k_map + self._neighbor_maps = self._build_neighbor_map_from_distances( + dist_matrix_sq, dist_shells_sq, tol + ) def _compute_distance_matrix(self) -> Coordinates: """Computes the distance matrix using the Minimum Image Convention.""" - return self._get_distance_matrix_with_mic() + if self.num_sites == 0: + return backend.zeros((0, 0)) + return self._get_distance_matrix_with_mic_vectorized() class SquareLattice(TILattice): @@ -759,20 +972,20 @@ class SquareLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): """Initializes the SquareLattice.""" dimensionality = 2 - # Define lattice vectors for a square lattice - lattice_vectors = np.array([[lattice_constant, 0.0], [0.0, lattice_constant]]) - - # A square lattice has a single site in its basis - basis_coords = np.array([[0.0, 0.0]]) + # Define orthogonal lattice vectors for a square. + lattice_vectors = backend.convert_to_tensor( + [[lattice_constant, 0.0], [0.0, lattice_constant]] + ) + # A square lattice is a Bravais lattice, so it has a single-site basis. + basis_coords = backend.convert_to_tensor([[0.0, 0.0]]) - # Call the parent TILattice constructor with these parameters super().__init__( dimensionality=dimensionality, lattice_vectors=lattice_vectors, @@ -808,7 +1021,7 @@ class HoneycombLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): @@ -816,11 +1029,13 @@ def __init__( dimensionality = 2 a = lattice_constant - # Define the primitive lattice vectors for the underlying triangular lattice - lattice_vectors = a * np.array([[1.5, np.sqrt(3) / 2], [1.5, -np.sqrt(3) / 2]]) - - # Define the coordinates of the two basis sites (A and B) - basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0]]) # Site A # Site B + # Define the two primitive lattice vectors for the underlying triangular Bravais lattice. + lattice_vectors = [ + [a * 1.5, a * backend.sqrt(backend.convert_to_tensor(3.0)) / 2], + [a * 1.5, -a * backend.sqrt(backend.convert_to_tensor(3.0)) / 2], + ] + # Define the two basis sites (A and B) within the unit cell. + basis_coords = [[0.0, 0.0], [a * 1.0, 0.0]] super().__init__( dimensionality=dimensionality, @@ -855,7 +1070,7 @@ class TriangularLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): @@ -863,11 +1078,13 @@ def __init__( dimensionality = 2 a = lattice_constant - # Define the primitive lattice vectors for a triangular lattice - lattice_vectors = a * np.array([[1.0, 0.0], [0.5, np.sqrt(3) / 2]]) - - # A triangular lattice is a Bravais lattice, with a single site in its basis - basis_coords = np.array([[0.0, 0.0]]) + # Define the primitive lattice vectors for a triangular lattice. + lattice_vectors = [ + [a * 1.0, 0.0], + [a * 0.5, a * backend.sqrt(backend.convert_to_tensor(3.0)) / 2], + ] + # A triangular lattice is a Bravais lattice with a single-site basis. + basis_coords = [[0.0, 0.0]] super().__init__( dimensionality=dimensionality, @@ -896,13 +1113,16 @@ class ChainLattice(TILattice): def __init__( self, size: Tuple[int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: bool = True, precompute_neighbors: Optional[int] = None, ): dimensionality = 1 - lattice_vectors = np.array([[lattice_constant]]) - basis_coords = np.array([[0.0]]) + # The lattice vector is just the lattice constant along one dimension. + lattice_vectors = [[lattice_constant]] + # A simple chain is a Bravais lattice with a single-site basis. + basis_coords = [[0.0]] + super().__init__( dimensionality=dimensionality, lattice_vectors=lattice_vectors, @@ -934,15 +1154,15 @@ class DimerizedChainLattice(TILattice): def __init__( self, size: Tuple[int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: bool = True, precompute_neighbors: Optional[int] = None, ): dimensionality = 1 - # The unit cell vector connects two A sites, spanning length 2*a - lattice_vectors = np.array([[2 * lattice_constant]]) - # Basis has site A at origin, site B at distance 'a' - basis_coords = np.array([[0.0], [lattice_constant]]) + # The unit cell is twice the bond length, as it contains two sites. + lattice_vectors = [[2 * lattice_constant]] + # Two basis sites (A and B) separated by the bond length. + basis_coords = [[0.0], [lattice_constant]] super().__init__( dimensionality=dimensionality, @@ -975,14 +1195,16 @@ class RectangularLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constants: Tuple[float, float] = (1.0, 1.0), + lattice_constants: Union[Tuple[float, float], Any] = (1.0, 1.0), pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): dimensionality = 2 ax, ay = lattice_constants - lattice_vectors = np.array([[ax, 0.0], [0.0, ay]]) - basis_coords = np.array([[0.0, 0.0]]) + # Orthogonal lattice vectors with potentially different lengths. + lattice_vectors = [[ax, 0.0], [0.0, ay]] + # A rectangular lattice is a Bravais lattice with a single-site basis. + basis_coords = [[0.0, 0.0]] super().__init__( dimensionality=dimensionality, @@ -1013,16 +1235,17 @@ class CheckerboardLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): dimensionality = 2 a = lattice_constant - # Primitive vectors for a square lattice rotated by 45 degrees. - lattice_vectors = a * np.array([[1.0, 1.0], [1.0, -1.0]]) - # Two-site basis - basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0]]) + # The unit cell is a square rotated by 45 degrees. + lattice_vectors = [[a * 1.0, a * 1.0], [a * 1.0, a * -1.0]] + # Two basis sites (A and B) within the unit cell. + basis_coords = [[a * 0.0, a * 0.0], [a * 1.0, a * 0.0]] + super().__init__( dimensionality=dimensionality, lattice_vectors=lattice_vectors, @@ -1052,16 +1275,24 @@ class KagomeLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): dimensionality = 2 a = lattice_constant - # Using a rectangular unit cell definition for simplicity - lattice_vectors = a * np.array([[2.0, 0.0], [1.0, np.sqrt(3)]]) - # Three-site basis - basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0], [0.5, np.sqrt(3) / 2.0]]) + # The Kagome lattice is based on a triangular Bravais lattice. + lattice_vectors = [ + [a * 2.0, a * 0.0], + [a * 1.0, a * backend.sqrt(backend.convert_to_tensor(3.0))], + ] + # It has a three-site basis, forming the corners of the triangles. + basis_coords = [ + [a * 0.0, a * 0.0], + [a * 1.0, a * 0.0], + [a * 0.5, a * backend.sqrt(backend.convert_to_tensor(3.0)) / 2.0], + ] + super().__init__( dimensionality=dimensionality, lattice_vectors=lattice_vectors, @@ -1092,31 +1323,23 @@ class LiebLattice(TILattice): def __init__( self, size: Tuple[int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): """Initializes the LiebLattice.""" dimensionality = 2 - # Use a more descriptive name for clarity. In a Lieb lattice, - # the lattice_constant is the bond length between nearest neighbors. bond_length = lattice_constant - # The unit cell of a Lieb lattice is a square with side length - # equal to twice the bond length. unit_cell_side = 2 * bond_length - lattice_vectors = np.array([[unit_cell_side, 0.0], [0.0, unit_cell_side]]) - - # The three-site basis consists of a corner site, a site on the - # center of the horizontal edge, and a site on the center of the vertical edge. - # Their coordinates are defined directly in terms of the physical bond length. - basis_coords = np.array( - [ - [0.0, 0.0], # Corner site - [bond_length, 0.0], # Horizontal edge center - [0.0, bond_length], # Vertical edge center - ] - ) + # The Lieb lattice is based on a square Bravais lattice. + lattice_vectors = [[unit_cell_side, 0.0], [0.0, unit_cell_side]] + # It has a three-site basis: one corner and two edge-centers. + basis_coords = [ + [0.0, 0.0], # Corner site + [bond_length, 0.0], # x-edge center + [0.0, bond_length], # y-edge center + ] super().__init__( dimensionality=dimensionality, @@ -1147,14 +1370,16 @@ class CubicLattice(TILattice): def __init__( self, size: Tuple[int, int, int], - lattice_constant: float = 1.0, + lattice_constant: Union[float, Any] = 1.0, pbc: Union[bool, Tuple[bool, bool, bool]] = True, precompute_neighbors: Optional[int] = None, ): dimensionality = 3 a = lattice_constant - lattice_vectors = np.array([[a, 0, 0], [0, a, 0], [0, 0, a]]) - basis_coords = np.array([[0.0, 0.0, 0.0]]) + # Orthogonal lattice vectors of equal length in 3D. + lattice_vectors = [[a, 0, 0], [0, a, 0], [0, 0, a]] + # A simple cubic lattice is a Bravais lattice with a single-site basis. + basis_coords = [[0.0, 0.0, 0.0]] super().__init__( dimensionality=dimensionality, lattice_vectors=lattice_vectors, @@ -1194,29 +1419,40 @@ def __init__( self, dimensionality: int, identifiers: List[SiteIdentifier], - coordinates: List[Union[List[float], Coordinates]], + coordinates: Any, precompute_neighbors: Optional[int] = None, ): """Initializes the CustomizeLattice.""" + logger.debug( + f"[DEBUG-LATTICE] Initializing CustomizeLattice with {len(identifiers)} sites." + ) super().__init__(dimensionality) - if len(identifiers) != len(coordinates): + + self._coordinates = backend.convert_to_tensor(coordinates) + if len(identifiers) == 0: + self._coordinates = backend.reshape( + self._coordinates, (0, self.dimensionality) + ) + + if len(identifiers) != backend.shape_tuple(self._coordinates)[0]: raise ValueError( - "Identifiers and coordinates lists must have the same length." + "The number of identifiers must match the number of coordinates. " + f"Got {len(identifiers)} identifiers and " + f"{backend.shape_tuple(self._coordinates)[0]} coordinates." ) - # The _build_lattice logic is simple enough to be in __init__ self._identifiers = list(identifiers) - self._coordinates = [np.array(c) for c in coordinates] self._indices = list(range(len(identifiers))) self._ident_to_idx = {ident: idx for idx, ident in enumerate(identifiers)} - # Validate coordinate dimensions - for i, coord in enumerate(self._coordinates): - if coord.shape != (dimensionality,): - raise ValueError( - f"Coordinate at index {i} has shape {coord.shape}, " - f"expected ({dimensionality},)" - ) + if ( + self.num_sites > 0 + and backend.shape_tuple(self._coordinates)[1] != dimensionality + ): + raise ValueError( + f"Coordinates tensor has dimension {backend.shape_tuple(self._coordinates)[1]}, " + f"but expected dimensionality is {dimensionality}." + ) logger.info(f"CustomizeLattice with {self.num_sites} sites created.") @@ -1228,98 +1464,66 @@ def _build_lattice(self, *args: Any, **kwargs: Any) -> None: pass def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: - """Calculates neighbors using a KDTree for efficiency. - - This method uses a memory-efficient approach to identify neighbors without - initially computing the full N x N distance matrix. It leverages - `scipy.spatial.distance.pdist` to find unique distance shells and then - a `scipy.spatial.KDTree` for fast radius queries. This approach is - significantly more memory-efficient during the neighbor identification phase. - - After the neighbors are identified, the full distance matrix is computed - from the pairwise distances and cached for potential future use. + """ + Calculates neighbor relationships using a distance matrix. - :param max_k: The maximum number of neighbor shells to - calculate. Defaults to 1. - :type max_k: int, optional - :param tol: The numerical tolerance for distance - comparisons. Defaults to 1e-6. - :type tol: float, optional + This method leverages the generic `_build_neighbors_by_distance_matrix` + to ensure differentiability, avoiding non-differentiable libraries + like SciPy's KDTree. """ + logger.debug( + f"[DEBUG-LATTICE] Building neighbors for CustomizeLattice up to max_k={max_k}." + ) tol = kwargs.get("tol", 1e-6) - logger.info(f"Building neighbors for CustomizeLattice up to k={max_k}...") if self.num_sites < 2: return - all_coords = np.array(self._coordinates) + # For CustomizeLattice, we must use the distance matrix method. + dist_matrix = self._compute_distance_matrix() + dist_matrix_sq = dist_matrix**2 + self._distance_matrix = dist_matrix - # 1. Use pdist for memory-efficient calculation of pairwise distances - # to robustly identify the distance shells. - all_distances_sq = pdist(all_coords, metric="sqeuclidean") + all_distances_sq = backend.reshape(dist_matrix_sq, [-1]) dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol) - if not dist_shells_sq: - logger.info("No distinct neighbor shells found.") - return - - # 2. Build the KDTree for efficient querying. - tree = KDTree(all_coords) - self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)} - - # 3. Find neighbors by isolating shells using inclusion-exclusion. - # `found_indices` will store all neighbors within a given radius. - found_indices: List[set[int]] = [] - for k_idx, target_d_sq in enumerate(dist_shells_sq): - radius = np.sqrt(target_d_sq) + tol - # Query for all points within the new, larger radius. - current_shell_indices = tree.query_ball_point( - all_coords, r=radius, return_sorted=True - ) - - # Now, isolate the neighbors for the current shell k - k = k_idx + 1 - current_k_map: Dict[int, List[int]] = {} - for i in range(self.num_sites): - - if k_idx == 0: - co_located_indices = tree.query_ball_point(all_coords[i], r=1e-12) - prev_found = set(co_located_indices) - else: - prev_found = found_indices[i] - - # The new neighbors are those in the current radius shell, - # excluding those already found in smaller shells. - new_neighbors = set(current_shell_indices[i]) - prev_found - - if new_neighbors: - current_k_map[i] = sorted(list(new_neighbors)) - - self._neighbor_maps[k] = current_k_map - found_indices = [ - set(l) for l in current_shell_indices - ] # Update for next iteration - self._distance_matrix = np.sqrt(squareform(all_distances_sq)) + self._neighbor_maps = self._build_neighbor_map_from_distances( + dist_matrix_sq, dist_shells_sq, tol + ) - logger.info("Neighbor building complete using KDTree.") + logger.info(f"Neighbor building complete for CustomizeLattice up to k={max_k}.") def _compute_distance_matrix(self) -> Coordinates: - """Computes the distance matrix from the stored coordinates. - - This implementation uses scipy.pdist for a memory-efficient - calculation of pairwise distances, which is then converted to a - full square matrix. """ + Computes the full N x N distance matrix using backend operations. + This implementation is fully differentiable. + """ + logger.debug("[DEBUG-LATTICE] Computing distance matrix for CustomizeLattice.") + if self.num_sites == 0: + return backend.zeros((0, 0)) if self.num_sites < 2: - return cast(Coordinates, np.empty((self.num_sites, self.num_sites))) + assert self._coordinates is not None + return backend.zeros( + (self.num_sites, self.num_sites), dtype=self._coordinates.dtype + ) - all_coords = np.array(self._coordinates) - # Use pdist for memory-efficiency, then build the full matrix. - all_distances_sq = pdist(all_coords, metric="sqeuclidean") - dist_matrix_sq = squareform(all_distances_sq) - return cast(Coordinates, np.sqrt(dist_matrix_sq)) + # Vectorized computation of displacements: (N, 1, D) - (1, N, D) -> (N, N, D) + displacements = backend.expand_dims(self._coordinates, 1) - backend.expand_dims( + self._coordinates, 0 + ) + + dist_matrix_sq = backend.sum(backend.power(displacements, 2), axis=-1) + + return backend.where( + backend.equal(dist_matrix_sq, 0), + 0, + backend.sqrt(dist_matrix_sq), + ) def _reset_computations(self) -> None: """Resets all cached data that depends on the lattice structure.""" + logger.debug( + "[DEBUG-LATTICE] Resetting cached computations (_neighbor_maps, _distance_matrix)." + ) self._neighbor_maps = {} self._distance_matrix = None @@ -1336,9 +1540,15 @@ def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice": :return: A new CustomizeLattice instance with the same sites. :rtype: CustomizeLattice """ + logger.debug( + f"[DEBUG-LATTICE] Creating CustomizeLattice from existing lattice: {type(lattice).__name__}" + ) all_sites_info = list(lattice.sites()) if not all_sites_info: + logger.debug( + "[DEBUG-LATTICE] Source lattice is empty, creating an empty CustomizeLattice." + ) return cls( dimensionality=lattice.dimensionality, identifiers=[], coordinates=[] ) @@ -1355,7 +1565,7 @@ def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice": def add_sites( self, identifiers: List[SiteIdentifier], - coordinates: List[Union[List[float], Coordinates]], + coordinates: Any, ) -> None: """Adds new sites to the lattice. @@ -1363,21 +1573,32 @@ def add_sites( previously computed neighbor information is cleared and must be recalculated. - :param identifiers: A list of unique, hashable identifiers for the new sites. + :param identifiers: A list of unique identifiers for the new sites. :type identifiers: List[SiteIdentifier] - :param coordinates: A list of coordinates for the new sites. - :type coordinates: List[Union[List[float], np.ndarray]] - :raises ValueError: If input lists have mismatched lengths, or if any new - identifier already exists in the lattice. + :param coordinates: The coordinates for the new sites. Can be a list of lists, + a NumPy array, or a backend-compatible tensor (e.g., jax.numpy.ndarray). + :type coordinates: Any """ - if len(identifiers) != len(coordinates): + if not identifiers: + logger.debug( + "[DEBUG-LATTICE] add_sites called with empty identifiers list. No action taken." + ) + return + + new_coords_tensor = backend.convert_to_tensor(coordinates) + + if len(identifiers) != backend.shape_tuple(new_coords_tensor)[0]: raise ValueError( "Identifiers and coordinates lists must have the same length." ) - if not identifiers: - return # Nothing to add - # Check for duplicate identifiers before making any changes + if backend.shape_tuple(new_coords_tensor)[1] != self.dimensionality: + raise ValueError( + f"New coordinate tensor has dimension {backend.shape_tuple(new_coords_tensor)[1]}, " + f"but expected dimensionality is {self.dimensionality}." + ) + + # Ensure that the new identifiers are unique and do not already exist. existing_ids = set(self._identifiers) new_ids = set(identifiers) if not new_ids.isdisjoint(existing_ids): @@ -1385,21 +1606,14 @@ def add_sites( f"Duplicate identifiers found: {new_ids.intersection(existing_ids)}" ) - for i, coord in enumerate(coordinates): - coord_arr = np.asarray(coord) - if coord_arr.shape != (self.dimensionality,): - raise ValueError( - f"New coordinate at index {i} has shape {coord_arr.shape}, " - f"expected ({self.dimensionality},)" - ) - self._coordinates.append(coord_arr) - self._identifiers.append(identifiers[i]) + self._coordinates = backend.concat( + [self._coordinates, new_coords_tensor], axis=0 + ) + self._identifiers.extend(identifiers) - # Rebuild index mappings from scratch self._indices = list(range(len(self._identifiers))) self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)} - # Invalidate any previously computed neighbors or distance matrices self._reset_computations() logger.info( f"{len(identifiers)} sites added. Lattice now has {self.num_sites} sites." @@ -1414,10 +1628,12 @@ def remove_sites(self, identifiers: List[SiteIdentifier]) -> None: :param identifiers: A list of identifiers for the sites to be removed. :type identifiers: List[SiteIdentifier] - :raises ValueError: If any of the specified identifiers do not exist. """ if not identifiers: - return # Nothing to remove + logger.debug( + "[DEBUG-LATTICE] remove_sites called with empty identifiers list. No action taken." + ) + return ids_to_remove = set(identifiers) current_ids = set(self._identifiers) @@ -1426,23 +1642,25 @@ def remove_sites(self, identifiers: List[SiteIdentifier]) -> None: f"Non-existent identifiers provided for removal: {ids_to_remove - current_ids}" ) - # Create new lists containing only the sites to keep - new_identifiers: List[SiteIdentifier] = [] - new_coordinates: List[Coordinates] = [] - for ident, coord in zip(self._identifiers, self._coordinates): - if ident not in ids_to_remove: - new_identifiers.append(ident) - new_coordinates.append(coord) + # Find the indices of the sites that we want to keep. + indices_to_keep = [ + idx + for idx, ident in enumerate(self._identifiers) + if ident not in ids_to_remove + ] + + new_identifiers = [self._identifiers[i] for i in indices_to_keep] + + self._coordinates = backend.gather1d( + self._coordinates, + backend.cast(backend.convert_to_tensor(indices_to_keep), "int32"), + ) - # Replace old data with the new, filtered data self._identifiers = new_identifiers - self._coordinates = new_coordinates - # Rebuild index mappings self._indices = list(range(len(self._identifiers))) self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)} - # Invalidate caches self._reset_computations() logger.info( f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites." @@ -1477,24 +1695,31 @@ def get_compatible_layers(bonds: List[Tuple[int, int]]) -> List[List[Tuple[int, tuple represents a bond. All bonds within a layer are non-overlapping. :rtype: List[List[Tuple[int, int]]] """ - uncolored_edges: Set[Tuple[int, int]] = {(min(bond), max(bond)) for bond in bonds} + logger.debug(f"[DEBUG-LATTICE] Getting compatible layers for {len(bonds)} bonds.") + # Ensure all bonds are in a canonical form (i, j) with i < j and remove duplicates. + sorted_edges = sorted(list({(min(bond), max(bond)) for bond in bonds})) layers: List[List[Tuple[int, int]]] = [] + unassigned_edges = set(sorted_edges) - while uncolored_edges: + # Greedily build layers until all edges have been assigned. + while unassigned_edges: current_layer: List[Tuple[int, int]] = [] qubits_in_this_layer: Set[int] = set() - edges_to_process = sorted(list(uncolored_edges)) + sorted_unassigned = sorted(list(unassigned_edges)) - for edge in edges_to_process: + # Iterate through remaining edges and add an edge to the current layer + # if it doesn't conflict with (share a qubit with) edges already in the layer. + for edge in sorted_unassigned: i, j = edge if i not in qubits_in_this_layer and j not in qubits_in_this_layer: current_layer.append(edge) qubits_in_this_layer.add(i) qubits_in_this_layer.add(j) - uncolored_edges -= set(current_layer) - layers.append(sorted(current_layer)) + unassigned_edges -= set(current_layer) + layers.append(current_layer) + logger.debug(f"[DEBUG-LATTICE] Found {len(layers)} compatible layers.") return layers diff --git a/tests/test_hamiltonians.py b/tests/test_hamiltonians.py index 7ecd0112..8be5771d 100644 --- a/tests/test_hamiltonians.py +++ b/tests/test_hamiltonians.py @@ -157,3 +157,35 @@ def test_anisotropic_heisenberg(self, backend): h_generated_dense = tc.backend.to_dense(h_generated) assert h_generated_dense.shape == (4, 4) assert np.allclose(h_generated_dense, h_expected) + + def test_heisenberg_hamiltonian_all_interactions(self): + """ + Test the Heisenberg Hamiltonian with 'all' interaction scope. + For a 3-site chain, this should include interactions (0,1), (0,2), and (1,2). + """ + lattice = ChainLattice(size=(3,), pbc=False) + j_coupling = 1.0 + h_generated = heisenberg_hamiltonian( + lattice, j_coupling=j_coupling, interaction_scope="all" + ) + + # Manually construct the expected Hamiltonian for all-to-all interaction + # H = J * [ (X0X1 + Y0Y1 + Z0Z1) + (X0X2 + Y0Y2 + Z0Z2) + (X1X2 + Y1Y2 + Z1Z2) ] + xx_01 = np.kron(PAULI_X, np.kron(PAULI_X, PAULI_I)) + yy_01 = np.kron(PAULI_Y, np.kron(PAULI_Y, PAULI_I)) + zz_01 = np.kron(PAULI_Z, np.kron(PAULI_Z, PAULI_I)) + + xx_02 = np.kron(PAULI_X, np.kron(PAULI_I, PAULI_X)) + yy_02 = np.kron(PAULI_Y, np.kron(PAULI_I, PAULI_Y)) + zz_02 = np.kron(PAULI_Z, np.kron(PAULI_I, PAULI_Z)) + + xx_12 = np.kron(PAULI_I, np.kron(PAULI_X, PAULI_X)) + yy_12 = np.kron(PAULI_I, np.kron(PAULI_Y, PAULI_Y)) + zz_12 = np.kron(PAULI_I, np.kron(PAULI_Z, PAULI_Z)) + + h_expected = j_coupling * ( + (xx_01 + yy_01 + zz_01) + (xx_02 + yy_02 + zz_02) + (xx_12 + yy_12 + zz_12) + ) + + assert h_generated.shape == (8, 8) + assert np.allclose(tc.backend.to_dense(h_generated), h_expected) diff --git a/tests/test_lattice.py b/tests/test_lattice.py index 12354f13..39f064bd 100644 --- a/tests/test_lattice.py +++ b/tests/test_lattice.py @@ -1,10 +1,43 @@ +from __future__ import annotations from unittest.mock import patch import logging +import sys -# import time +# Configure logging for debugging purposes +logging.basicConfig( + level=logging.DEBUG, + stream=sys.stdout, + format="[%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s", +) +logger = logging.getLogger(__name__) +from unittest.mock import patch +from typing import TYPE_CHECKING, Any + +import time import matplotlib + +try: + import jax.numpy as jnp +except ImportError: + jnp = None +try: + import tensorflow as tf +except ImportError: + tf = None +try: + import torch +except ImportError: + torch = None + +if TYPE_CHECKING: + import jax.numpy as jnp + from jax import Array + import tensorflow as tf + import torch + + matplotlib.use("Agg") @@ -26,6 +59,7 @@ AbstractLattice, get_compatible_layers, ) +import tensorcircuit as tc @pytest.fixture @@ -107,7 +141,7 @@ def test_input_validation_mismatched_lengths(self): # the specified exception is raised within the 'with' block. with pytest.raises( ValueError, - match="Identifiers and coordinates lists must have the same length.", + match="The number of identifiers must match the number of coordinates.", ): CustomizeLattice(dimensionality=2, identifiers=ids, coordinates=coords) @@ -122,9 +156,8 @@ def test_input_validation_wrong_dimension(self): # Act & Assert: Check for the specific error message. The 'r' before the string # indicates a raw string, which is good practice for regex patterns. - with pytest.raises( - ValueError, match=r"Coordinate at index 1 has shape \(3,\), expected \(2,\)" - ): + with pytest.raises(ValueError): + CustomizeLattice( dimensionality=2, identifiers=ids_ok, coordinates=coords_wrong_dim ) @@ -139,15 +172,51 @@ def test_neighbor_finding(self, simple_square_lattice): # --- Assertions for k=1 (Nearest Neighbors) --- # We use set() for comparison to ignore the order of neighbors. - assert set(lattice.get_neighbors(0, k=1)) == {1, 2} - assert set(lattice.get_neighbors(1, k=1)) == {0, 3} - assert set(lattice.get_neighbors(2, k=1)) == {0, 3} - assert set(lattice.get_neighbors(3, k=1)) == {1, 2} + neighbors = set(lattice.get_neighbors(0, k=1)) + expected = {1, 2} + assert neighbors == expected, ( + f"Failed neighbor check for site 0 (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) + neighbors = set(lattice.get_neighbors(1, k=1)) + expected = {0, 3} + assert neighbors == expected, ( + f"Failed neighbor check for site 1 (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) + neighbors = set(lattice.get_neighbors(2, k=1)) + expected = {0, 3} + assert neighbors == expected, ( + f"Failed neighbor check for site 2 (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) + neighbors = set(lattice.get_neighbors(3, k=1)) + expected = {1, 2} + assert neighbors == expected, ( + f"Failed neighbor check for site 3 (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) # --- Assertions for k=2 (Next-Nearest Neighbors) --- # These should be the diagonal sites. - assert set(lattice.get_neighbors(0, k=2)) == {3} - assert set(lattice.get_neighbors(1, k=2)) == {2} + neighbors = set(lattice.get_neighbors(0, k=2)) + expected = {3} + assert neighbors == expected, ( + f"Failed neighbor check for site 0 (k=2). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) + neighbors = set(lattice.get_neighbors(1, k=2)) + expected = {2} + assert neighbors == expected, ( + f"Failed neighbor check for site 1 (k=2). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) def test_neighbor_pairs(self, simple_square_lattice): """ @@ -477,6 +546,30 @@ def test_customizelattice_max_k_precomputation_and_ondemand(self): f"but found {computed_shells_after}." ) + def test_precompute_neighbors_on_init_custom(self): + """ + Tests that the `precompute_neighbors` argument correctly populates + the neighbor map upon initialization for CustomizeLattice. + """ + coords = [[float(i), float(i)] for i in range(10)] + ids = list(range(10)) + k_to_precompute = 3 + + # Initialize the lattice with the precompute_neighbors argument + lattice = CustomizeLattice( + dimensionality=2, + identifiers=ids, + coordinates=coords, + precompute_neighbors=k_to_precompute, + ) + + # Assert that the internal neighbor map is populated correctly + assert lattice._neighbor_maps is not None + # Check that all shells up to k_to_precompute are present + computed_shells = sorted(list(lattice._neighbor_maps.keys())) + expected_shells = list(range(1, k_to_precompute + 1)) + assert computed_shells == expected_shells + @pytest.fixture def obc_square_lattice() -> SquareLattice: @@ -537,11 +630,29 @@ def test_neighbors_with_open_boundaries(self, obc_square_lattice): edge_idx = 3 # (1, 0, 0) # Assert center site (4) has neighbors 1, 3, 5, 7 - assert set(lattice.get_neighbors(center_idx, k=1)) == {1, 3, 5, 7} + neighbors = set(lattice.get_neighbors(center_idx, k=1)) + expected = {1, 3, 5, 7} + assert neighbors == expected, ( + f"Failed neighbor check for site {center_idx} (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) # Assert corner site (0) has neighbors 1, 3 - assert set(lattice.get_neighbors(corner_idx, k=1)) == {1, 3} + neighbors = set(lattice.get_neighbors(corner_idx, k=1)) + expected = {1, 3} + assert neighbors == expected, ( + f"Failed neighbor check for site {corner_idx} (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) # Assert edge site (3) has neighbors 0, 4, 6 - assert set(lattice.get_neighbors(edge_idx, k=1)) == {0, 4, 6} + neighbors = set(lattice.get_neighbors(edge_idx, k=1)) + expected = {0, 4, 6} + assert neighbors == expected, ( + f"Failed neighbor check for site {edge_idx} (k=1). " + f"Expected {expected}, but got {neighbors}. " + f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." + ) def test_neighbors_with_periodic_boundaries(self, pbc_square_lattice): """ @@ -566,8 +677,8 @@ def test_neighbors_with_periodic_boundaries(self, pbc_square_lattice): @pytest.fixture def pbc_honeycomb_lattice() -> HoneycombLattice: - """Provides a 2x2 HoneycombLattice with Periodic Boundary Conditions.""" - return HoneycombLattice(size=(2, 2), pbc=True) + """Provides a 3x3 HoneycombLattice with Periodic Boundary Conditions.""" + return HoneycombLattice(size=(3, 3), pbc=True) class TestHoneycombLattice: @@ -580,7 +691,7 @@ def test_initialization_and_properties(self, pbc_honeycomb_lattice): Tests that the total number of sites is correct for a composite lattice. """ lattice = pbc_honeycomb_lattice - assert lattice.num_sites == 8 + assert lattice.num_sites == 18 assert lattice.num_basis == 2 def test_honeycomb_neighbors(self, pbc_honeycomb_lattice): @@ -594,6 +705,20 @@ def test_honeycomb_neighbors(self, pbc_honeycomb_lattice): site_b_idx = lattice.get_index((0, 0, 1)) assert len(lattice.get_neighbors(site_b_idx, k=1)) == 3 + def test_honeycomb_next_nearest_neighbors(self, pbc_honeycomb_lattice): + """ + Tests that every site in a honeycomb lattice has 6 next-nearest neighbors + under periodic boundary conditions. + """ + lattice = pbc_honeycomb_lattice + # In a PBC honeycomb lattice, every site is equivalent. + # We can just test one site from each basis. + site_a_idx = lattice.get_index((0, 0, 0)) + assert len(lattice.get_neighbors(site_a_idx, k=2)) == 6 + + site_b_idx = lattice.get_index((0, 0, 1)) + assert len(lattice.get_neighbors(site_b_idx, k=2)) == 6 + # --- Tests for TriangularLattice --- @@ -1022,6 +1147,16 @@ def test_init_with_mismatched_shapes_raises_error(self): with pytest.raises(AssertionError, match="Size tuple length mismatch"): SquareLattice(size=(2, 2, 2)) + def test_init_with_mismatched_pbc_raises_error(self): + """ + Tests that TILattice raises AssertionError if the 'pbc' tuple's + length does not match the dimensionality. + This addresses a gap identified in the code review. + """ + with pytest.raises(AssertionError, match="PBC tuple length mismatch"): + # A 2D lattice requires a pbc tuple of length 2, but we provide one of length 1. + SquareLattice(size=(2, 2), pbc=(True,)) + def test_init_with_tuple_pbc(self): """ Tests that TILattice correctly handles a tuple input for the 'pbc' @@ -1078,6 +1213,24 @@ def test_tilattice_max_k_precomputation_and_ondemand( f"but found {computed_shells_after}." ) + def test_precompute_neighbors_on_init(self): + """ + Tests that the `precompute_neighbors` argument correctly populates + the neighbor map upon initialization for a TILattice subclass. + """ + k_to_precompute = 2 + # Initialize a SquareLattice with the precompute_neighbors argument + lattice = SquareLattice( + size=(4, 4), pbc=True, precompute_neighbors=k_to_precompute + ) + + # Assert that the internal neighbor map is populated correctly + assert lattice._neighbor_maps is not None + # Check that all shells up to k_to_precompute are present + computed_shells = sorted(list(lattice._neighbor_maps.keys())) + expected_shells = list(range(1, k_to_precompute + 1)) + assert computed_shells == expected_shells + class TestLongRangeNeighborFinding: """ @@ -1241,6 +1394,31 @@ def test_mixed_boundary_conditions(self): edge_neighbors == expected_edge_indices ), "Failed for edge site with mixed BC." + def test_mixed_boundary_conditions_on_honeycomb(self): + """ + Tests neighbor finding on a HoneycombLattice with mixed PBC. + This ensures the logic correctly handles composite lattices with + anisotropic boundary conditions. + """ + lattice = HoneycombLattice(size=(3, 3), pbc=(True, False)) + + test_site_idx = lattice.get_index((1, 0, 0)) + + neighbors = lattice.get_neighbors(test_site_idx, k=1) + + assert ( + len(neighbors) == 2 + ), "Site on the open boundary has incorrect number of neighbors." + + expected_neighbor_idents = { + (1, 0, 1), + (0, 0, 1), + } + + actual_neighbor_idents = {lattice.get_identifier(i) for i in neighbors} + + assert actual_neighbor_idents == expected_neighbor_idents + class TestAllTILattices: """ @@ -1480,6 +1658,77 @@ def test_neighbor_finding_returns_sorted_list(self, simple_square_lattice): 3, ], "The neighbor list should be sorted in ascending order." + def test_from_lattice_from_empty_lattice(self): + """Tests creating a CustomizeLattice from an empty TILattice.""" + # Arrange: Create an empty TILattice instance. + empty_sq = SquareLattice(size=(0, 0)) + + # Act: Convert it to a CustomizeLattice. + custom_from_empty = CustomizeLattice.from_lattice(empty_sq) + + # Assert: The resulting lattice should also be empty and have the correct properties. + assert isinstance(custom_from_empty, CustomizeLattice) + assert custom_from_empty.num_sites == 0 + assert custom_from_empty.dimensionality == 2 + + def test_add_sites_to_empty_lattice(self): + """Tests adding sites to a previously empty CustomizeLattice.""" + # Arrange: Create an empty CustomizeLattice. + empty_lat = CustomizeLattice(dimensionality=2, identifiers=[], coordinates=[]) + assert empty_lat.num_sites == 0 + + # Act: Add new sites to it. + empty_lat.add_sites( + identifiers=["X", "Y"], coordinates=[[1.0, 1.0], [2.0, 2.0]] + ) + + # Assert: The lattice should now contain the new sites. + assert empty_lat.num_sites == 2 + assert empty_lat.get_identifier(0) == "X" + np.testing.assert_array_equal( + empty_lat.get_coordinates(1), np.array([2.0, 2.0]) + ) + + def test_add_and_remove_empty_list_of_sites(self, initial_lattice): + """ + Tests that calling add_sites and remove_sites with empty lists + is a no-op and doesn't change the lattice state. + """ + # Arrange + lat = initial_lattice + original_num_sites = lat.num_sites + # In backends like JAX, tensors are immutable. We can check object identity. + original_coords_id = id(lat._coordinates) + + # Act 1: Add an empty list of sites. + lat.add_sites(identifiers=[], coordinates=[]) + + # Assert 1: Nothing should have changed. + assert lat.num_sites == original_num_sites + assert id(lat._coordinates) == original_coords_id + + # Act 2: Remove an empty list of sites. + lat.remove_sites(identifiers=[]) + + # Assert 2: Still no changes. + assert lat.num_sites == original_num_sites + assert id(lat._coordinates) == original_coords_id + + def test_remove_all_sites(self, initial_lattice): + """Tests removing all sites from a lattice, resulting in an empty lattice.""" + # Arrange + lat = initial_lattice + # Get all identifiers before removal. + all_ids = list(lat._identifiers) + + # Act + lat.remove_sites(identifiers=all_ids) + + # Assert: The lattice should now be empty. + assert lat.num_sites == 0 + assert len(lat._ident_to_idx) == 0 + assert lat._coordinates.shape[0] == 0 + class TestDistanceMatrix: @@ -1629,43 +1878,66 @@ def test_distance_matrix_invariants_for_all_lattice_types(self, lattice): matrix[off_diagonal_mask] > 1e-9 ), f"Found non-positive off-diagonal elements in distance matrix for {type(lattice).__name__}." + def test_distance_matrix_caching_is_effective(self): + """ + Tests that the distance_matrix property is cached after the first access. + """ + # Arrange: Create a lattice instance. + lattice = CustomizeLattice( + dimensionality=2, + identifiers=["A", "B", "C"], + coordinates=[[0, 0], [1, 0], [0, 1]], + ) + + # Act & Assert + # We patch the internal _compute_distance_matrix method to spy on it. + with patch.object( + lattice, "_compute_distance_matrix", wraps=lattice._compute_distance_matrix + ) as spy_compute: + # Access the property twice. + _ = lattice.distance_matrix + _ = lattice.distance_matrix + + # The computation method should have been called only on the first access. + spy_compute.assert_called_once() + + +@pytest.mark.slow +class TestPerformance: + def test_pbc_implementation_is_not_significantly_slower_than_obc(self): + """ + A performance regression test. + It ensures that the specialized implementation for fully periodic + lattices (pbc=True) is not substantially slower than the general + implementation used for open boundaries (pbc=False). + This test will FAIL with the current code, exposing the performance bug. + """ + # Arrange: Use a large-enough lattice to make performance differences apparent + size = (40, 40) + k = 1 + + # Act 1: Measure the execution time of the general (OBC) implementation + start_time_obc = time.time() + _ = SquareLattice(size=size, pbc=False, precompute_neighbors=k) + duration_obc = time.time() - start_time_obc + + # Act 2: Measure the execution time of the specialized (PBC) implementation + start_time_pbc = time.time() + _ = SquareLattice(size=size, pbc=True, precompute_neighbors=k) + duration_pbc = time.time() - start_time_pbc + + print( + f"\n[Performance] OBC ({size}): {duration_obc:.4f}s | PBC ({size}): {duration_pbc:.4f}s" + ) -# @pytest.mark.slow -# class TestPerformance: -# def test_pbc_implementation_is_not_significantly_slower_than_obc(self): -# """ -# A performance regression test. -# It ensures that the specialized implementation for fully periodic -# lattices (pbc=True) is not substantially slower than the general -# implementation used for open boundaries (pbc=False). -# This test will FAIL with the current code, exposing the performance bug. -# """ -# # Arrange: Use a large-enough lattice to make performance differences apparent -# size = (30, 30) -# k = 1 - -# # Act 1: Measure the execution time of the general (OBC) implementation -# start_time_obc = time.time() -# _ = SquareLattice(size=size, pbc=False, precompute_neighbors=k) -# duration_obc = time.time() - start_time_obc - -# # Act 2: Measure the execution time of the specialized (PBC) implementation -# start_time_pbc = time.time() -# _ = SquareLattice(size=size, pbc=True, precompute_neighbors=k) -# duration_pbc = time.time() - start_time_pbc - -# print( -# f"\n[Performance] OBC ({size}): {duration_obc:.4f}s | PBC ({size}): {duration_pbc:.4f}s" -# ) - -# # Assert: The PBC implementation should not be drastically slower. -# # We allow it to be up to 3 times slower to account for minor overheads, -# # but this will catch the current 10x+ regression. -# # THIS ASSERTION WILL FAIL with the current buggy code. -# assert duration_pbc < duration_obc * 5, ( -# "The specialized PBC implementation is significantly slower " -# "than the general-purpose implementation." -# ) + # Assert: The PBC implementation should not be drastically slower. + # We allow it to be up to 3 times slower to account for minor overheads, + # but this will catch the current 10x+ regression. + # THIS ASSERTION WILL FAIL with the current buggy code. + assert duration_pbc < duration_obc * 5, ( + "The specialized PBC implementation is significantly slower " + "than the general-purpose implementation." + ) def _validate_layers(bonds, layers) -> None: @@ -1748,3 +2020,545 @@ def test_layering_on_edge_cases(): layers_single = get_compatible_layers(single_edge) assert layers_single == [[(0, 1)]] _validate_layers(single_edge, layers_single) + + +def test_layering_on_disconnected_graph(): + """ + Tests that the layering algorithm correctly handles a graph consisting + of multiple disconnected components. + """ + disconnected_bonds = [ + (0, 1), + (1, 2), + (2, 0), + (3, 4), + (4, 5), + (5, 3), + ] + + layers = get_compatible_layers(disconnected_bonds) + + # Assert + _validate_layers(disconnected_bonds, layers) + + assert len(layers) == 3, "Two disconnected triangles should be 3-edge-colorable." + + layer_with_01 = next( + layer for layer in layers if (0, 1) in layer or (1, 0) in layer + ) + assert (3, 4) in layer_with_01 or (4, 3) in layer_with_01 + + +# A map from backend name to the expected tensor type. +BACKEND_TENSOR_MAP = { + "numpy": np.ndarray, + "jax": jnp.ndarray if jnp else None, + "tensorflow": tf.Tensor if tf else None, + "pytorch": torch.Tensor if torch else None, +} + + +@pytest.fixture(scope="function") +def jax_backend_fixture(): + """ + Pytest fixture to set the backend to 'jax' for a test and restore it afterward. + This ensures that differentiability tests run in the correct environment + without interfering with other tests. + """ + original_backend = tc.backend.name + try: + tc.set_backend("jax") + yield + finally: + tc.set_backend(original_backend) + + +class TestBackendIntegration: + """ + Tests to ensure lattice functionalities are consistent and correct + across different computation backends. + """ + + # Define the test cases for parameterization + lattice_test_cases = [ + ( + SquareLattice, + {"size": (3, 3), "pbc": False}, + (0, 4, np.sqrt(2.0)), + "SquareLattice", + ), + ( + CustomizeLattice, + { + "dimensionality": 2, + "identifiers": [0, 1], + "coordinates": [[0.0, 0.0], [1.0, 1.0]], + }, + (0, 1, np.sqrt(2.0)), + "CustomizeLattice", + ), + ( + HoneycombLattice, + {"size": (2, 2), "pbc": True}, + (0, 1, 1.0), + "HoneycombLattice", + ), + ( + RectangularLattice, + {"size": (2, 3), "lattice_constants": (1.0, 2.0), "pbc": False}, + (0, 1, 2.0), + "RectangularLattice", + ), + ] + + @pytest.mark.parametrize("backend_name", ["numpy", "jax", "tensorflow", "pytorch"]) + @pytest.mark.parametrize( + "LatticeClass, init_args, expected_distance_check, name", + lattice_test_cases, + ids=[case[3] for case in lattice_test_cases], + ) + def test_lattice_creation_and_properties_across_backends( + self, + backend_name: str, + LatticeClass: "AbstractLattice", + init_args: dict[str, Any], + expected_distance_check: tuple[int, int, float], + name: str, + ) -> None: + """ + Tests that various lattices can be created with each backend and that their + core properties (_coordinates, distance_matrix) have the correct + tensor types and values. + """ + expected_tensor_type = BACKEND_TENSOR_MAP[backend_name] + if expected_tensor_type is None: + pytest.skip(f"Backend '{backend_name}' not installed.") + + tc.set_backend(backend_name) + + # Create the lattice instance inside the test function + lat = LatticeClass(**init_args) + + # Assert that the internal coordinate and public distance matrix tensors + # have the correct type for the active backend. + assert isinstance( + lat._coordinates, expected_tensor_type + ), f"Failed for {type(lat).__name__} on backend {backend_name}" + assert isinstance( + lat.distance_matrix, expected_tensor_type + ), f"Failed for {type(lat).__name__} on backend {backend_name}" + + # Unpack the distance check information + idx1, idx2, expected_distance = expected_distance_check + + # Extract the actual distance from the matrix and convert to a numpy float + # for a consistent comparison across all backends. + actual_distance = tc.backend.numpy(lat.distance_matrix)[idx1, idx2] + + # Assert that the computed distance matches the expected value. + np.testing.assert_allclose( + actual_distance, + expected_distance, + err_msg=f"Distance check failed for {type(lat).__name__} on backend {backend_name}", + ) + + @pytest.mark.usefixtures("jax_backend_fixture") + @pytest.mark.parametrize( + "lattice_class, init_params, differentiable_arg_name, test_value", + [ + ( + SquareLattice, + {"size": (2, 2)}, + "lattice_constant", + 1.0, + ), + ( + RectangularLattice, + {"size": (2, 2)}, + "lattice_constants", + (1.0, 1.5), + ), + ( + DimerizedChainLattice, + {"size": (3,)}, + "lattice_constant", + 0.8, + ), + ], + ids=["SquareLattice", "RectangularLattice", "DimerizedChainLattice"], + ) + def test_tilattice_differentiability( + self, + lattice_class: type[AbstractLattice], + init_params: dict[str, Any], + differentiable_arg_name: str, + test_value: Any, + ) -> None: + """ + Tests that the distance_matrix of various TILattices is differentiable + with respect to their geometric parameters. This test has been expanded + based on code review feedback to cover more lattice types. + """ + if not jnp: + pytest.skip("JAX backend is required for this differentiability test.") + + def get_total_distance(param: Any) -> Array: + """A scalar-in, scalar-out function for jax.grad.""" + # Dynamically create the lattice with the parameter being differentiated + lat = lattice_class(**init_params, **{differentiable_arg_name: param}) + return tc.backend.sum(lat.distance_matrix) + + # Compute the gradient. The `argnums` parameter is used for tuple inputs. + grad_fn = ( + tc.backend.grad(get_total_distance, argnums=0) + if isinstance(test_value, tuple) + else tc.backend.grad(get_total_distance) + ) + grad_val = grad_fn(test_value) + + assert grad_val is not None + + # For tuple gradients, check that at least one element is non-zero. + if isinstance(grad_val, tuple): + assert any( + not np.isclose(float(g), 0.0) for g in grad_val + ), f"Gradient for {lattice_class.__name__} was all zeros." + else: + assert not np.isclose( + float(grad_val), 0.0 + ), f"Gradient for {lattice_class.__name__} was zero." + + @pytest.mark.usefixtures("jax_backend_fixture") + def test_customizelattice_differentiability(self) -> None: + """ + Tests that the distance_matrix of a CustomizeLattice is differentiable + with respect to its input coordinates. + """ + # This test requires the JAX backend for its grad function. + if not jnp: + pytest.skip("JAX backend is required for this differentiability test.") + + initial_coords = jnp.array([[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]]) + + def get_total_distance_custom(coords: Array) -> Array: + """ + A helper function that takes coordinates, creates a CustomizeLattice, + and returns a scalar value (the sum of its distance matrix). + """ + lat = CustomizeLattice( + dimensionality=2, identifiers=[0, 1, 2], coordinates=coords + ) + return tc.backend.sum(lat.distance_matrix) + + # Compute the gradient of the total distance with respect to the initial coordinates. + grad_tensor = tc.backend.grad(get_total_distance_custom)(initial_coords) + + # Assert that the gradient tensor is not None and is not all zeros. + # A non-zero gradient confirms that the output is indeed differentiable + # with respect to the input coordinates. + assert grad_tensor is not None + assert not np.all(np.isclose(grad_tensor, 0.0)) + + @pytest.mark.usefixtures("jax_backend_fixture") + def test_tilattice_gradient_value_correctness(self) -> None: + """ + Tests that the AD gradient for a TILattice parameter matches the + analytically calculated, correct gradient value. This is a stronger + test than just checking for non-zero gradients. + """ + if not jnp: + pytest.skip("JAX backend is required for this gradient value test.") + + # 1. Define a simple objective function + def get_energy(a: float) -> Array: + """ + A simple energy function for a 2-site chain. + Energy = (distance between site 0 and 1)^2 = a^2 + """ + # Using a 2-site chain, the simplest possible TILattice + lat = ChainLattice(size=(2,), pbc=False, lattice_constant=a) + # The distance matrix for a 2-site chain is [[0, a], [a, 0]] + dist_matrix = lat.distance_matrix + # Sum of squared distances for all unique pairs. There is only one pair (0,1). + return tc.backend.sum(dist_matrix**2) / 2.0 + + # 2. Define the analytical (manually calculated) gradient + def analytical_gradient(a: float) -> float: + """ + The analytical derivative of the energy function E(a) = a^2. + dE/da = 2a + """ + return 2 * a + + # 3. Set up the test + test_lattice_constant = 1.5 + # Compute the gradient using automatic differentiation + ad_grad = tc.backend.grad(get_energy)(test_lattice_constant) + # Compute the expected gradient using our analytical formula + expected_grad = analytical_gradient(test_lattice_constant) + + # 4. Assert that the two gradients are numerically very close + np.testing.assert_allclose( + ad_grad, + expected_grad, + rtol=1e-6, + err_msg="The automatically differentiated gradient does not match the analytical gradient.", + ) + + @pytest.mark.parametrize("backend_name", ["numpy", "jax", "tensorflow", "pytorch"]) + def test_dynamic_modification_across_backends(self, backend_name: str) -> None: + """ + Tests that the dynamic modification methods (add_sites, remove_sites) + of CustomizeLattice work correctly across all supported backends, + specifically checking tensor shapes. + """ + # Arrange: Set up the backend and skip if not installed + expected_tensor_type = BACKEND_TENSOR_MAP[backend_name] + if expected_tensor_type is None: + pytest.skip(f"Backend '{backend_name}' not installed.") + + tc.set_backend(backend_name) + + # --- Initial State --- + # Create a simple lattice with 3 sites + initial_coords = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]] + initial_ids = ["A", "B", "C"] + lattice = CustomizeLattice( + dimensionality=2, identifiers=initial_ids, coordinates=initial_coords + ) + assert lattice.num_sites == 3 + assert lattice._coordinates.shape[0] == 3 + + # --- Test add_sites --- + # Act: Add 2 new sites + lattice.add_sites(identifiers=["D", "E"], coordinates=[[1.0, 1.0], [2.0, 0.0]]) + + # Assert: Check new size and tensor shape + assert lattice.num_sites == 5 + assert ( + lattice._coordinates.shape[0] == 5 + ), f"Tensor shape incorrect after add_sites on {backend_name} backend." + + # --- Test remove_sites --- + # Act: Remove 1 site from the modified lattice + lattice.remove_sites(identifiers=["A"]) + + # Assert: Check final size and tensor shape + assert lattice.num_sites == 4 + assert ( + lattice._coordinates.shape[0] == 4 + ), f"Tensor shape incorrect after remove_sites on {backend_name} backend." + + +@pytest.mark.parametrize("backend_name", ["numpy", "jax", "tensorflow", "pytorch"]) +def test_dtype_consistency_across_backends(backend_name: str) -> None: + """ + Tests that the dtype of user-provided coordinate data is preserved + in internal calculations across all backends. + """ + # Arrange: Set up the backend and skip if not installed + if BACKEND_TENSOR_MAP[backend_name] is None: + pytest.skip(f"Backend '{backend_name}' not installed.") + + tc.set_backend(backend_name) + + # A map from backend name to its corresponding float32 dtype object + # Prepare input data with a specific, non-default dtype + coords_float32 = np.array([[0.0, 0.0], [1.0, 2.0]], dtype=np.float32) + + # Act: Create a lattice with this data + lattice = CustomizeLattice( + dimensionality=2, identifiers=[0, 1], coordinates=coords_float32 + ) + + # Assert: Check that the internal tensors have the correct dtype + assert tc.backend.dtype(lattice._coordinates) == "float32", ( + f"Mismatch in coordinate dtype for backend {backend_name}. " + f"Expected 'float32', got {tc.backend.dtype(lattice._coordinates)}." + ) + assert tc.backend.dtype(lattice.distance_matrix) == "float32", ( + f"Mismatch in distance matrix dtype for backend {backend_name}. " + f"Expected 'float32', got {tc.backend.dtype(lattice.distance_matrix)}." + ) + + +class TestPrivateHelpers: + """ + A dedicated test class for private helper methods of AbstractLattice. + This allows for focused testing of internal logic that is critical for + the public-facing API but not directly exposed. + """ + + @pytest.fixture + def simple_lattice_for_helpers(self) -> CustomizeLattice: + """ + Provides a very simple lattice instance, primarily to gain access to + the private helper methods for testing. The geometry itself is trivial. + """ + # A simple 3-site lattice is sufficient to call the helper methods. + return CustomizeLattice( + dimensionality=1, identifiers=[0, 1, 2], coordinates=[[0.0], [1.0], [2.0]] + ) + + def test_identify_distance_shells_basic(self, simple_lattice_for_helpers): + """ + Tests the basic functionality of _identify_distance_shells with a + clear separation between distance shells. + """ + # Arrange + lattice = simple_lattice_for_helpers + # A set of squared distances with clear gaps between them. + all_distances_sq = np.array([0, 1.0, 1.0, 4.0, 4.0, 9.0]) + + # Act + # Call the private helper method directly. + shells = lattice._identify_distance_shells(all_distances_sq, max_k=10) + + # Assert + # The method should identify the unique, non-zero distances. + np.testing.assert_allclose(shells, [1.0, 4.0, 9.0]) + + def test_identify_distance_shells_with_max_k_limit( + self, simple_lattice_for_helpers + ): + """ + Tests that _identify_distance_shells respects the max_k parameter, + limiting the number of returned shells. + """ + # Arrange + lattice = simple_lattice_for_helpers + all_distances_sq = np.array([0, 1.0, 4.0, 9.0, 16.0, 25.0]) + max_k = 3 # We only want the first 3 shells. + + # Act + shells = lattice._identify_distance_shells(all_distances_sq, max_k=max_k) + + # Assert + # The number of shells should be limited by max_k. + assert len(shells) == max_k + # The returned shells should be the first `max_k` smallest distances. + np.testing.assert_allclose(shells, [1.0, 4.0, 9.0]) + + def test_identify_distance_shells_with_tolerance_merging( + self, simple_lattice_for_helpers + ): + """ + Tests that distances that are very close together are merged into a + single shell when the tolerance `tol` is large enough. + """ + # Arrange + lattice = simple_lattice_for_helpers + # Two distances are very close: 1.0 and 1.000001 + all_distances_sq = np.array([0, 1.0, 1.000001, 4.0]) + # A tolerance larger than the difference between the two close distances. + tol = 1e-5 + + # Act + shells = lattice._identify_distance_shells(all_distances_sq, tol=tol, max_k=10) + + # Assert + # Because of the tolerance, 1.0 and 1.000001 should be considered the same shell. + # The method should return the first distance of the merged group (1.0). + np.testing.assert_allclose(shells, [1.0, 4.0]) + + def test_identify_distance_shells_with_tolerance_separation( + self, simple_lattice_for_helpers + ): + """ + Tests that very close distances are correctly identified as separate + shells when the tolerance `tol` is small enough. + """ + # Arrange + lattice = simple_lattice_for_helpers + all_distances_sq = np.array([0, 1.0, 1.000001, 4.0]) + # A tolerance smaller than the difference. + tol = 1e-7 + + # Act + shells = lattice._identify_distance_shells(all_distances_sq, tol=tol, max_k=10) + + # Assert + # With a small tolerance, the two close distances should be treated as distinct shells. + np.testing.assert_allclose(shells, [1.0, 1.000001, 4.0]) + + def test_identify_distance_shells_with_empty_and_zero_input( + self, simple_lattice_for_helpers + ): + """ + Tests that the method handles edge cases like empty arrays or arrays + containing only zeros, returning an empty list. + """ + # Arrange + lattice = simple_lattice_for_helpers + + # Act & Assert for empty input + shells_empty = lattice._identify_distance_shells(np.array([]), max_k=10) + assert ( + shells_empty == [] + ), "Should return an empty list for an empty distance array." + + # Act & Assert for zero-only input + shells_zero = lattice._identify_distance_shells(np.array([0, 0, 0]), max_k=10) + assert ( + shells_zero == [] + ), "Should return an empty list for a distance array with only zeros." + + def test_get_distance_matrix_with_mic(self): + """ + Tests the internal _get_distance_matrix_with_mic method for TILattice + to ensure it correctly applies the Minimum Image Convention. + """ + # --- Test Case 1: Fully Periodic Boundary Conditions (PBC) --- + lattice_pbc = SquareLattice(size=(3, 3), pbc=True, lattice_constant=1.0) + # We need to use the numpy backend for direct comparison + dist_matrix_pbc = tc.backend.numpy(lattice_pbc._get_distance_matrix_with_mic()) + + # For a 3x3 PBC lattice, the distance between opposite edges should be 1. + # Example: site (0,0) and site (2,0) + idx1 = lattice_pbc.get_index((0, 0, 0)) + idx2 = lattice_pbc.get_index((2, 0, 0)) + np.testing.assert_allclose( + dist_matrix_pbc[idx1, idx2], + 1.0, + err_msg="Distance check failed for x-direction in PBC case.", + ) + + # Example: site (0,0) and site (0,2) + idx3 = lattice_pbc.get_index((0, 2, 0)) + np.testing.assert_allclose( + dist_matrix_pbc[idx1, idx3], + 1.0, + err_msg="Distance check failed for y-direction in PBC case.", + ) + + # Example: corner-to-corner distance, e.g., (0,0) to (2,2) + # dx = min(|0-2|, 3-|0-2|) = 1; dy = min(|0-2|, 3-|0-2|) = 1 + # distance = sqrt(1^2 + 1^2) = sqrt(2) + idx4 = lattice_pbc.get_index((2, 2, 0)) + np.testing.assert_allclose( + dist_matrix_pbc[idx1, idx4], + np.sqrt(2.0), + err_msg="Diagonal distance check failed in PBC case.", + ) + + # --- Test Case 2: Mixed Boundary Conditions --- + lattice_mixed = SquareLattice( + size=(3, 3), pbc=(True, False), lattice_constant=1.0 + ) + dist_matrix_mixed = tc.backend.numpy( + lattice_mixed._get_distance_matrix_with_mic() + ) + + # In the periodic x-direction, distance should be 1. + np.testing.assert_allclose( + dist_matrix_mixed[idx1, idx2], + 1.0, + err_msg="Distance check failed for periodic x-direction in mixed BC case.", + ) + + # In the open y-direction, distance should be 2. + np.testing.assert_allclose( + dist_matrix_mixed[idx1, idx3], + 2.0, + err_msg="Distance check failed for open y-direction in mixed BC case.", + ) From 9d2238488d6a4c00d8f8effbfa929d500d4d5a39 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Mon, 11 Aug 2025 12:05:36 +0800 Subject: [PATCH 02/16] fix mypy errors --- tensorcircuit/backends/jax_backend.py | 2 +- tensorcircuit/backends/numpy_backend.py | 1 + tensorcircuit/backends/tensorflow_backend.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorcircuit/backends/jax_backend.py b/tensorcircuit/backends/jax_backend.py index eeaf1d63..33e99188 100644 --- a/tensorcircuit/backends/jax_backend.py +++ b/tensorcircuit/backends/jax_backend.py @@ -418,7 +418,7 @@ def onehot(self, a: Tensor, num: int) -> Tensor: def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor: return jnp.cumsum(a, axis) - def all(self, a: Tensor, axis: Optional[int] = None) -> Tensor: + def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: return jnp.all(a, axis=axis) def equal(self, x1: Tensor, x2: Tensor) -> Tensor: diff --git a/tensorcircuit/backends/numpy_backend.py b/tensorcircuit/backends/numpy_backend.py index c0b35a1f..a10dffcb 100644 --- a/tensorcircuit/backends/numpy_backend.py +++ b/tensorcircuit/backends/numpy_backend.py @@ -371,6 +371,7 @@ def where( ) -> Tensor: if x is None and y is None: return np.where(condition) + assert x is not None and y is not None return np.where(condition, x, y) def equal(self, x: Tensor, y: Tensor) -> Tensor: diff --git a/tensorcircuit/backends/tensorflow_backend.py b/tensorcircuit/backends/tensorflow_backend.py index 29508d16..e6518419 100644 --- a/tensorcircuit/backends/tensorflow_backend.py +++ b/tensorcircuit/backends/tensorflow_backend.py @@ -528,7 +528,7 @@ def min(self, a: Tensor, axis: Optional[int] = None) -> Tensor: def max(self, a: Tensor, axis: Optional[int] = None) -> Tensor: return tf.reduce_max(a, axis=axis) - def all(self, a: Tensor, axis: Optional[int] = None) -> Tensor: + def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: return tf.reduce_all(tf.cast(a, tf.bool), axis=axis) def where( From d71d4a1c2f3c4d331af8e5b59323bad2eefcb418 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Tue, 12 Aug 2025 14:33:06 +0800 Subject: [PATCH 03/16] delete all the debug log --- tensorcircuit/templates/lattice.py | 104 ----------------------------- 1 file changed, 104 deletions(-) diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py index e02c5342..010e5b15 100644 --- a/tensorcircuit/templates/lattice.py +++ b/tensorcircuit/templates/lattice.py @@ -63,9 +63,6 @@ class AbstractLattice(abc.ABC): def __init__(self, dimensionality: int): """Initializes the base lattice class.""" - logger.debug( - f"[DEBUG-LATTICE] Initializing AbstractLattice with dimensionality: {dimensionality}" - ) self._dimensionality = dimensionality # Core data structures for storing site information. @@ -113,17 +110,11 @@ def distance_matrix(self) -> Coordinates: subsequent calls. This computation can be expensive for large lattices. """ if self._distance_matrix is None: - logger.debug("Distance matrix not cached. Computing now...") - logger.debug("[DEBUG-LATTICE] Computing distance matrix...") self._distance_matrix = self._compute_distance_matrix() - logger.debug("[DEBUG-LATTICE] ...distance matrix computed.") return self._distance_matrix def _validate_index(self, index: SiteIndex) -> None: """A private helper to check if a site index is within the valid range.""" - logger.debug( - f"[DEBUG-LATTICE] Validating index: {index} against num_sites: {self.num_sites}" - ) if not (0 <= index < self.num_sites): raise IndexError( f"Site index {index} out of range (0-{self.num_sites - 1})" @@ -141,7 +132,6 @@ def get_coordinates(self, index: SiteIndex) -> Coordinates: self._validate_index(index) assert self._coordinates is not None coords = self._coordinates[index] - logger.debug(f"[DEBUG-LATTICE] get_coordinates for index {index}: {coords}") return coords def get_identifier(self, index: SiteIndex) -> SiteIdentifier: @@ -166,14 +156,9 @@ def get_index(self, identifier: SiteIdentifier) -> SiteIndex: :rtype: SiteIndex """ try: - logger.debug(f"[DEBUG-LATTICE] Getting index for identifier: {identifier}") index = self._ident_to_idx[identifier] - logger.debug(f"[DEBUG-LATTICE] Found index: {index}") return index except KeyError as e: - logger.debug( - f"[DEBUG-LATTICE] Identifier {identifier} not found in _ident_to_idx map." - ) raise ValueError( f"Identifier {identifier} not found in the lattice." ) from e @@ -198,26 +183,14 @@ def get_site_info( - The site's coordinates as a NumPy array. :rtype: Tuple[SiteIndex, SiteIdentifier, Coordinates] """ - logger.debug( - f"[DEBUG-LATTICE] get_site_info called with: {index_or_identifier} (type: {type(index_or_identifier)})" - ) assert self._coordinates is not None if isinstance(index_or_identifier, int): # SiteIndex is an int idx = index_or_identifier self._validate_index(idx) - logger.debug( - f"[DEBUG-LATTICE] Identified as SiteIndex. Returning info for index {idx}." - ) return idx, self._identifiers[idx], self._coordinates[idx] else: ident = index_or_identifier - logger.debug( - f"[DEBUG-LATTICE] Identified as SiteIdentifier. Looking up index for {ident}." - ) idx = self.get_index(ident) - logger.debug( - f"[DEBUG-LATTICE] Returning info for identifier {ident} (index {idx})." - ) return idx, ident, self._coordinates[idx] def sites(self) -> Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]: @@ -230,7 +203,6 @@ def sites(self) -> Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]: index, identifier, and coordinates. :rtype: Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]] """ - logger.debug("[DEBUG-LATTICE] Creating sites iterator.") for i in range(self.num_sites): assert self._coordinates is not None yield i, self._identifiers[i], self._coordinates[i] @@ -249,14 +221,10 @@ def get_neighbors(self, index: SiteIndex, k: int = 1) -> List[SiteIndex]: pre-calculated or if the site has no such neighbors. :rtype: List[SiteIndex] """ - logger.debug(f"[DEBUG-LATTICE] Getting neighbors for index {index}, k={k}") if k not in self._neighbor_maps: logger.info( f"Neighbors for k={k} not pre-computed. Building now up to max_k={k}." ) - logger.debug( - f"[DEBUG-LATTICE] Neighbor map for k={k} not found. Triggering _build_neighbors(max_k={k})." - ) self._build_neighbors(max_k=k) if k not in self._neighbor_maps: @@ -282,16 +250,10 @@ def get_neighbor_pairs( :rtype: List[Tuple[SiteIndex, SiteIndex]] """ - logger.debug( - f"[DEBUG-LATTICE] Getting neighbor pairs for k={k}, unique={unique}" - ) if k not in self._neighbor_maps: logger.info( f"Neighbor pairs for k={k} not pre-computed. Building now up to max_k={k}." ) - logger.debug( - f"[DEBUG-LATTICE] Neighbor map for k={k} not found. Triggering _build_neighbors(max_k={k})." - ) self._build_neighbors(max_k=k) if k not in self._neighbor_maps: @@ -325,7 +287,6 @@ def get_all_pairs(self) -> List[Tuple[SiteIndex, SiteIndex]]: :return: A list of tuples, where each tuple is a unique pair of site indices. :rtype: List[Tuple[SiteIndex, SiteIndex]] """ - logger.debug("[DEBUG-LATTICE] Getting all unique pairs of sites.") if self.num_sites < 2: return [] # Use itertools.combinations to efficiently generate all unique pairs (i, j) with i < j. @@ -402,12 +363,6 @@ def show( :param kwargs: Additional keyword arguments to be passed directly to the `matplotlib.pyplot.scatter` function for customizing site appearance. """ - logger.debug( - ( - f"[DEBUG-LATTICE] show() called with: show_indices={show_indices}, " - f"show_identifiers={show_identifiers}, show_bonds_k={show_bonds_k}" - ) - ) try: import matplotlib.pyplot as plt except ImportError: @@ -432,7 +387,6 @@ def show( if ax is None: # If no Axes object is provided, create a new figure and axes. - logger.debug("[DEBUG-LATTICE] `ax` is None, creating new figure.") fig_created_internally = True if self.dimensionality == 3: fig = plt.figure(figsize=(8, 8)) @@ -440,7 +394,6 @@ def show( else: fig, ax = plt.subplots(figsize=(8, 8)) else: - logger.debug("[DEBUG-LATTICE] Using provided `ax` object.") fig = ax.figure # type: ignore coords = np.array(self._coordinates) @@ -456,7 +409,6 @@ def show( ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], **scatter_args) # type: ignore if show_indices or show_identifiers: - logger.debug("[DEBUG-LATTICE] Drawing site labels (indices/identifiers).") for i in range(self.num_sites): label = str(self._identifiers[i]) if show_identifiers else str(i) # Calculate a small offset for placing text labels to avoid overlap with sites. @@ -486,7 +438,6 @@ def show( ) if show_bonds_k is not None: - logger.debug(f"[DEBUG-LATTICE] Drawing bonds for k={show_bonds_k}.") if show_bonds_k not in self._neighbor_maps: logger.warning( f"Cannot draw bonds. k={show_bonds_k} neighbors have not been calculated." @@ -563,18 +514,12 @@ def _identify_distance_shells( :return: A sorted list of squared distances representing the shells. :rtype: List[float] """ - logger.debug( - f"[DEBUG-LATTICE] Identifying up to {max_k} distance shells with tolerance {tol}." - ) # A small threshold to filter out zero distances (site to itself). ZERO_THRESHOLD_SQ = 1e-12 all_distances_sq = backend.convert_to_tensor(all_distances_sq) # Now, the .size call below is guaranteed to be safe. if backend.sizen(all_distances_sq) == 0: - logger.debug( - "[DEBUG-LATTICE] No non-zero distances found, returning empty shells." - ) return [] # Filter out self-distances and sort the remaining squared distances. @@ -583,9 +528,6 @@ def _identify_distance_shells( ) if backend.sizen(sorted_dist) == 0: - logger.debug( - "[DEBUG-LATTICE] Sorted distances are empty, returning empty shells." - ) return [] dist_shells = [sorted_dist[0]] @@ -596,9 +538,6 @@ def _identify_distance_shells( if backend.sqrt(d_sq) - backend.sqrt(dist_shells[-1]) > tol: dist_shells.append(d_sq) - logger.debug( - f"[DEBUG-LATTICE] Identified distance shells (squared): {dist_shells}" - ) return dist_shells def _build_neighbors_by_distance_matrix( @@ -615,9 +554,6 @@ def _build_neighbors_by_distance_matrix( comparisons. Defaults to 1e-6. :type tol: float, optional """ - logger.debug( - f"[DEBUG-LATTICE] Building neighbors via distance matrix up to max_k={max_k}." - ) if self.num_sites < 2: return @@ -719,7 +655,6 @@ def __init__( precompute_neighbors: Optional[int] = None, ): """Initializes the Translationally Invariant Lattice.""" - logger.debug(f"[DEBUG-LATTICE] Initializing TILattice: {size}, pbc={pbc}") super().__init__(dimensionality) self.lattice_vectors = backend.convert_to_tensor(lattice_vectors) @@ -751,7 +686,6 @@ def _build_lattice(self) -> None: """ Generates all site information for the periodic lattice in a vectorized manner. """ - logger.debug("[DEBUG-LATTICE] Starting _build_lattice for TILattice.") ranges = [backend.arange(s) for s in self.size] # Generate a grid of all integer unit cell coordinates. @@ -792,19 +726,12 @@ def _build_lattice(self) -> None: self._ident_to_idx[identifier] = current_index current_index += 1 - logger.debug( - f"[DEBUG-LATTICE] Finished _build_lattice. Total sites: {self.num_sites}" - ) - def _get_distance_matrix_with_mic(self) -> Coordinates: """ Computes the full N x N distance matrix using backend operations, correctly applying the Minimum Image Convention (MIC) for all periodic dimensions in a memory-efficient manner. """ - logger.debug( - "[DEBUG-LATTICE] Computing distance matrix with Minimum Image Convention." - ) size_arr = backend.convert_to_tensor(self.size) size_arr = backend.cast(size_arr, self.lattice_vectors.dtype) @@ -820,9 +747,6 @@ def _get_distance_matrix_with_mic(self) -> Coordinates: [1, self.dimensionality], dtype=self.lattice_vectors.dtype ) else: - logger.debug( - f"[DEBUG-LATTICE] Applying MIC for periodic dimensions: {pbc_dims}" - ) num_pbc_dims = len(pbc_dims) pbc_system_vectors = backend.gather1d( system_vectors, backend.convert_to_tensor(pbc_dims) @@ -864,7 +788,6 @@ def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates: Computes the full N x N distance matrix using a fully vectorized approach to be compatible with JIT compilation (e.g., JAX). """ - logger.debug("[DEBUG-LATTICE] Computing distance matrix with MIC (vectorized).") size_arr = backend.cast( backend.convert_to_tensor(self.size), self.lattice_vectors.dtype ) @@ -927,9 +850,6 @@ def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None: - ``tol`` (float): The numerical tolerance used to determine if two distances are equal when identifying shells. Defaults to 1e-6. """ - logger.debug( - f"[DEBUG-LATTICE] Building neighbors for TILattice up to max_k={max_k}." - ) tol = kwargs.get("tol", 1e-6) dist_matrix = self._get_distance_matrix_with_mic_vectorized() dist_matrix_sq = dist_matrix**2 @@ -1423,9 +1343,6 @@ def __init__( precompute_neighbors: Optional[int] = None, ): """Initializes the CustomizeLattice.""" - logger.debug( - f"[DEBUG-LATTICE] Initializing CustomizeLattice with {len(identifiers)} sites." - ) super().__init__(dimensionality) self._coordinates = backend.convert_to_tensor(coordinates) @@ -1471,9 +1388,6 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: to ensure differentiability, avoiding non-differentiable libraries like SciPy's KDTree. """ - logger.debug( - f"[DEBUG-LATTICE] Building neighbors for CustomizeLattice up to max_k={max_k}." - ) tol = kwargs.get("tol", 1e-6) if self.num_sites < 2: return @@ -1497,7 +1411,6 @@ def _compute_distance_matrix(self) -> Coordinates: Computes the full N x N distance matrix using backend operations. This implementation is fully differentiable. """ - logger.debug("[DEBUG-LATTICE] Computing distance matrix for CustomizeLattice.") if self.num_sites == 0: return backend.zeros((0, 0)) if self.num_sites < 2: @@ -1521,9 +1434,6 @@ def _compute_distance_matrix(self) -> Coordinates: def _reset_computations(self) -> None: """Resets all cached data that depends on the lattice structure.""" - logger.debug( - "[DEBUG-LATTICE] Resetting cached computations (_neighbor_maps, _distance_matrix)." - ) self._neighbor_maps = {} self._distance_matrix = None @@ -1540,15 +1450,9 @@ def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice": :return: A new CustomizeLattice instance with the same sites. :rtype: CustomizeLattice """ - logger.debug( - f"[DEBUG-LATTICE] Creating CustomizeLattice from existing lattice: {type(lattice).__name__}" - ) all_sites_info = list(lattice.sites()) if not all_sites_info: - logger.debug( - "[DEBUG-LATTICE] Source lattice is empty, creating an empty CustomizeLattice." - ) return cls( dimensionality=lattice.dimensionality, identifiers=[], coordinates=[] ) @@ -1580,9 +1484,6 @@ def add_sites( :type coordinates: Any """ if not identifiers: - logger.debug( - "[DEBUG-LATTICE] add_sites called with empty identifiers list. No action taken." - ) return new_coords_tensor = backend.convert_to_tensor(coordinates) @@ -1630,9 +1531,6 @@ def remove_sites(self, identifiers: List[SiteIdentifier]) -> None: :type identifiers: List[SiteIdentifier] """ if not identifiers: - logger.debug( - "[DEBUG-LATTICE] remove_sites called with empty identifiers list. No action taken." - ) return ids_to_remove = set(identifiers) @@ -1695,7 +1593,6 @@ def get_compatible_layers(bonds: List[Tuple[int, int]]) -> List[List[Tuple[int, tuple represents a bond. All bonds within a layer are non-overlapping. :rtype: List[List[Tuple[int, int]]] """ - logger.debug(f"[DEBUG-LATTICE] Getting compatible layers for {len(bonds)} bonds.") # Ensure all bonds are in a canonical form (i, j) with i < j and remove duplicates. sorted_edges = sorted(list({(min(bond), max(bond)) for bond in bonds})) @@ -1721,5 +1618,4 @@ def get_compatible_layers(bonds: List[Tuple[int, int]]) -> List[List[Tuple[int, unassigned_edges -= set(current_layer) layers.append(current_layer) - logger.debug(f"[DEBUG-LATTICE] Found {len(layers)} compatible layers.") return layers From bb6559218150386386ed4e278a45838c34fc7dde Mon Sep 17 00:00:00 2001 From: Stellogic Date: Tue, 12 Aug 2025 20:22:24 +0800 Subject: [PATCH 04/16] fix according to the review --- examples/lennard_jones_optimization.py | 84 ++++--- tensorcircuit/backends/abstract_backend.py | 33 +-- tensorcircuit/backends/cupy_backend.py | 4 +- tensorcircuit/backends/jax_backend.py | 18 +- tensorcircuit/backends/numpy_backend.py | 10 +- tensorcircuit/backends/pytorch_backend.py | 17 +- tensorcircuit/backends/tensorflow_backend.py | 19 +- tensorcircuit/templates/lattice.py | 251 ++++++++++++------- tests/test_backends.py | 63 +++++ tests/test_lattice.py | 115 +++------ 10 files changed, 354 insertions(+), 260 deletions(-) diff --git a/examples/lennard_jones_optimization.py b/examples/lennard_jones_optimization.py index 400d1420..9edc63ca 100644 --- a/examples/lennard_jones_optimization.py +++ b/examples/lennard_jones_optimization.py @@ -1,49 +1,68 @@ +""" +Lennard-Jones Potential Optimization Example + +This script demonstrates how to use TensorCircuit's differentiable lattice geometries +to optimize crystal structure. It finds the equilibrium lattice constant that minimizes +the total Lennard-Jones potential energy of a 2D square lattice. + +The optimization showcases the key Task 3 capability: making lattice parameters +differentiable for variational material design. +""" import optax import numpy as np -import jax.numpy as jnp import matplotlib.pyplot as plt -import jax -import tensorcircuit as tc + +# Try to enable JAX 64-bit precision if available (safe fallback) +import jax # noqa: E402 +try: # pragma: no cover - optional optimization + from jax import config as jax_config # type: ignore + + jax_config.update("jax_enable_x64", True) +except Exception: # broad: environment may not have config attribute + pass +import jax.numpy as jnp # noqa: E402 +import tensorcircuit as tc # noqa: E402 -jax.config.update("jax_enable_x64", True) +tc.set_dtype("float64") # Use tc for universal control K = tc.set_backend("jax") -def calculate_potential(log_a, base_distance_matrix, epsilon=0.5, sigma=1.0): +def calculate_potential(log_a, epsilon=0.5, sigma=1.0): """ Calculate the total Lennard-Jones potential energy for a given logarithm of the lattice constant (log_a). + This version creates the lattice inside the function to demonstrate truly differentiable geometry. """ - lattice_constant = jnp.exp(log_a) - d = base_distance_matrix * lattice_constant - d_safe = jnp.where(d > 1e-9, d, 1e-9) - - term12 = (sigma / d_safe) ** 12 - term6 = (sigma / d_safe) ** 6 + lattice_constant = K.exp(log_a) + + # Create lattice with the differentiable parameter + size = (4, 4) # Smaller size for demonstration + lattice = tc.templates.lattice.SquareLattice(size, lattice_constant=lattice_constant, pbc=True) + d = lattice.distance_matrix + + d_safe = K.where(d > 1e-9, d, K.convert_to_tensor(1e-9)) + + term12 = K.power(sigma / d_safe, 12) + term6 = K.power(sigma / d_safe, 6) potential_matrix = 4 * epsilon * (term12 - term6) - num_sites = d.shape[0] - potential_matrix = potential_matrix * ( - 1 - K.eye(num_sites, dtype=potential_matrix.dtype) - ) + num_sites = lattice.num_sites + # Zero out self-interactions (diagonal elements) + eye_mask = K.eye(num_sites, dtype=potential_matrix.dtype) + potential_matrix = potential_matrix * (1 - eye_mask) potential_energy = K.sum(potential_matrix) / 2.0 return potential_energy -# Pre-calculate the base distance matrix (for lattice_constant=1.0) -size = (10, 10) -lat_base = tc.templates.lattice.SquareLattice(size, lattice_constant=1.0, pbc=True) -base_distance_matrix = lat_base.distance_matrix - -# Create a lambda function to pass the base distance matrix to the potential function -potential_fun_for_grad = lambda log_a: calculate_potential(log_a, base_distance_matrix) +# Create a lambda function for optimization +potential_fun_for_grad = lambda log_a: calculate_potential(log_a) value_and_grad_fun = K.jit(K.value_and_grad(potential_fun_for_grad)) optimizer = optax.adam(learning_rate=0.01) -log_a = K.convert_to_tensor(jnp.log(1.1)) +log_a = K.convert_to_tensor(K.log(K.convert_to_tensor(1.1))) opt_state = optimizer.init(log_a) @@ -53,10 +72,11 @@ def calculate_potential(log_a, base_distance_matrix, epsilon=0.5, sigma=1.0): for i in range(200): energy, grad = value_and_grad_fun(log_a) - history["a"].append(jnp.exp(log_a)) + history["a"].append(K.exp(log_a)) history["energy"].append(energy) - if jnp.isnan(grad): + # Check for NaN gradients using TensorCircuit's backend-agnostic approach + if K.sum(tc.num_to_tensor(np.isnan(K.numpy(grad)))) > 0: print(f"Gradient became NaN at iteration {i+1}. Stopping optimization.") print(f"Current energy: {energy}, Current log_a: {log_a}") break @@ -65,26 +85,26 @@ def calculate_potential(log_a, base_distance_matrix, epsilon=0.5, sigma=1.0): log_a = optax.apply_updates(log_a, updates) if (i + 1) % 20 == 0: - current_a = jnp.exp(log_a) + current_a = K.exp(log_a) print( f"Iteration {i+1}/200: Total Energy = {energy:.4f}, Lattice Constant = {current_a:.4f}" ) -final_a = jnp.exp(log_a) -final_energy = calculate_potential(K.convert_to_tensor(log_a), base_distance_matrix) +final_a = K.exp(log_a) +final_energy = calculate_potential(log_a) -if not jnp.isnan(final_energy): +if not np.isnan(K.numpy(final_energy)): print("\nOptimization finished!") print(f"Final optimized lattice constant: {final_a:.6f}") print(f"Corresponding minimum total energy: {final_energy:.6f}") # Vectorized calculation for the potential curve a_vals = np.linspace(0.8, 1.5, 200) - log_a_vals = np.log(a_vals) + log_a_vals = K.log(K.convert_to_tensor(a_vals)) # Use vmap to create a vectorized version of the potential function - vmap_potential = jax.vmap(lambda la: calculate_potential(la, base_distance_matrix)) - potential_curve = vmap_potential(K.convert_to_tensor(log_a_vals)) + vmap_potential = K.vmap(lambda la: calculate_potential(la)) + potential_curve = vmap_potential(log_a_vals) plt.figure(figsize=(10, 6)) plt.plot(a_vals, potential_curve, label="Lennard-Jones Potential", color="blue") diff --git a/tensorcircuit/backends/abstract_backend.py b/tensorcircuit/backends/abstract_backend.py index 9e6de0fe..30580af1 100644 --- a/tensorcircuit/backends/abstract_backend.py +++ b/tensorcircuit/backends/abstract_backend.py @@ -633,7 +633,8 @@ def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any: :param args: coordinate vectors :type args: Any - :param kwargs: keyword arguments for meshgrid + :param kwargs: keyword arguments for meshgrid, typically includes 'indexing' + which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing) :type kwargs: Any :return: list of coordinate matrices :rtype: Any @@ -659,21 +660,6 @@ def expand_dims(self: Any, a: Tensor, axis: int) -> Tensor: "Backend '{}' has not implemented `expand_dims`.".format(self.name) ) - def power(self: Any, a: Tensor, b: Union[Tensor, float]) -> Tensor: - """ - First array elements raised to powers from second array, element-wise. - - :param a: The bases - :type a: Tensor - :param b: The exponents - :type b: Union[Tensor, float] - :return: The bases in `a` raised to the powers in `b`. - :rtype: Tensor - """ - raise NotImplementedError( - "Backend '{}' has not implemented `power`.".format(self.name) - ) - def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]: """ Find the unique elements and their corresponding counts of the given tensor ``a``. @@ -1504,21 +1490,6 @@ def where( "Backend '{}' has not implemented `where`.".format(self.name) ) - def equal(self: Any, x1: Tensor, x2: Tensor) -> Tensor: - """ - Return the truth value of (x1 == x2) element-wise. - - :param x1: Input tensor. - :type x1: Tensor - :param x2: Input tensor. - :type x2: Tensor - :return: Output tensor, element-wise truth value of (x1 == x2). - :rtype: Tensor (bool) - """ - raise NotImplementedError( - "Backend '{}' has not implemented `equal`.".format(self.name) - ) - def switch( self: Any, index: Tensor, branches: Sequence[Callable[[], Tensor]] ) -> Tensor: diff --git a/tensorcircuit/backends/cupy_backend.py b/tensorcircuit/backends/cupy_backend.py index 4525cf08..7ee08923 100644 --- a/tensorcircuit/backends/cupy_backend.py +++ b/tensorcircuit/backends/cupy_backend.py @@ -56,10 +56,12 @@ def __init__(self) -> None: cpx = cupyx self.name = "cupy" - def convert_to_tensor(self, a: Tensor) -> Tensor: + def convert_to_tensor(self, a: Tensor, dtype: Optional[str] = None) -> Tensor: if not isinstance(a, cp.ndarray) and not cp.isscalar(a): a = cp.array(a) a = cp.asarray(a) + if dtype is not None: + a = self.cast(a, dtype) return a def sum( diff --git a/tensorcircuit/backends/jax_backend.py b/tensorcircuit/backends/jax_backend.py index 33e99188..c64b9828 100644 --- a/tensorcircuit/backends/jax_backend.py +++ b/tensorcircuit/backends/jax_backend.py @@ -50,12 +50,15 @@ def update(self, grads: pytree, params: pytree) -> pytree: return params -def _convert_to_tensor_jax(self: Any, tensor: Tensor) -> Tensor: +def _convert_to_tensor_jax(self: Any, tensor: Tensor, dtype: Optional[str] = None) -> Tensor: if not isinstance(tensor, (np.ndarray, jnp.ndarray)) and not jnp.isscalar(tensor): raise TypeError( ("Expected a `jnp.array`, `np.array` or scalar. " f"Got {type(tensor)}") ) result = jnp.asarray(tensor) + if dtype is not None: + # Use the backend's cast method to handle dtype conversion + result = self.cast(result, dtype) return result @@ -243,10 +246,10 @@ def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor: def copy(self, tensor: Tensor) -> Tensor: return jnp.array(tensor, copy=True) - def convert_to_tensor(self, tensor: Tensor, **kwargs: Any) -> Tensor: + def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor: result = jnp.asarray(tensor) - if "dtype" in kwargs and kwargs["dtype"] is not None: - result = self.cast(result, kwargs["dtype"]) + if dtype is not None: + result = self.cast(result, dtype) return result def abs(self, a: Tensor) -> Tensor: @@ -354,10 +357,6 @@ def expm(self, a: Tensor) -> Tensor: return jsp.linalg.expm(a) # currently expm in jax doesn't support AD, it will raise an AssertError, # see https://github.com/google/jax/issues/2645 - - def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: - return jnp.power(a, b) - def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return jnp.stack(a, axis=axis) @@ -421,9 +420,6 @@ def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor: def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: return jnp.all(a, axis=axis) - def equal(self, x1: Tensor, x2: Tensor) -> Tensor: - return jnp.equal(x1, x2) - def is_tensor(self, a: Any) -> bool: if not isinstance(a, jnp.ndarray): return False diff --git a/tensorcircuit/backends/numpy_backend.py b/tensorcircuit/backends/numpy_backend.py index a10dffcb..ae8d4e14 100644 --- a/tensorcircuit/backends/numpy_backend.py +++ b/tensorcircuit/backends/numpy_backend.py @@ -35,10 +35,12 @@ def _sum_numpy( # see https://github.com/google/TensorNetwork/issues/952 -def _convert_to_tensor_numpy(self: Any, a: Tensor, **kwargs: Any) -> Tensor: +def _convert_to_tensor_numpy(self: Any, a: Tensor, dtype: Optional[str] = None) -> Tensor: if not isinstance(a, np.ndarray) and not np.isscalar(a): a = np.array(a) a = np.asarray(a) + if dtype is not None: + a = a.astype(getattr(np, dtype)) return a @@ -80,9 +82,6 @@ def copy(self, a: Tensor) -> Tensor: def expm(self, a: Tensor) -> Tensor: return expm(a) - def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: - return np.power(a, b) - def abs(self, a: Tensor) -> Tensor: return np.abs(a) @@ -374,9 +373,6 @@ def where( assert x is not None and y is not None return np.where(condition, x, y) - def equal(self, x: Tensor, y: Tensor) -> Tensor: - return np.equal(x, y) - def cond( self, pred: bool, diff --git a/tensorcircuit/backends/pytorch_backend.py b/tensorcircuit/backends/pytorch_backend.py index 73b8c4d1..31654046 100644 --- a/tensorcircuit/backends/pytorch_backend.py +++ b/tensorcircuit/backends/pytorch_backend.py @@ -238,14 +238,22 @@ def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor: def copy(self, a: Tensor) -> Tensor: return a.clone() + def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor: + if self.is_tensor(tensor): + result = tensor + else: + result = torchlib.tensor(tensor) + if dtype is not None: + result = self.cast(result, dtype) + return result + def expm(self, a: Tensor) -> Tensor: raise NotImplementedError("pytorch backend doesn't support expm") # in 2020, torch has no expm, hmmm. but that's ok, # it doesn't support complex numbers which is more severe issue. # see https://github.com/pytorch/pytorch/issues/9983 - def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: - return torchlib.pow(a, b) + # see https://github.com/pytorch/pytorch/issues/9983 def sin(self, a: Tensor) -> Tensor: return torchlib.sin(a) @@ -449,11 +457,6 @@ def where( return torchlib.where(condition) return torchlib.where(condition, x, y) - def equal(self, x1: Tensor, x2: Any) -> Tensor: - if not self.is_tensor(x2): - x2 = torchlib.tensor(x2, device=x1.device, dtype=x1.dtype) - return torchlib.eq(x1, x2) - def reverse(self, a: Tensor) -> Tensor: return torchlib.flip(a, dims=(-1,)) diff --git a/tensorcircuit/backends/tensorflow_backend.py b/tensorcircuit/backends/tensorflow_backend.py index e6518419..9d33a4fd 100644 --- a/tensorcircuit/backends/tensorflow_backend.py +++ b/tensorcircuit/backends/tensorflow_backend.py @@ -75,7 +75,12 @@ def update(self, grads: pytree, params: pytree) -> pytree: def _tensordot_tf( self: Any, a: Tensor, b: Tensor, axes: Union[int, Sequence[Sequence[int]]] ) -> Tensor: - b = tf.cast(b, a.dtype) + # Use TensorFlow's dtype promotion rules by converting both to a common dtype + if a.dtype != b.dtype: + # Find the result dtype by performing a dummy operation + common_dtype = (tf.constant(0, dtype=a.dtype) + tf.constant(0, dtype=b.dtype)).dtype + a = tf.cast(a, common_dtype) + b = tf.cast(b, common_dtype) return tf.tensordot(a, b, axes) @@ -442,12 +447,15 @@ def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor: def copy(self, a: Tensor) -> Tensor: return tf.identity(a) + def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tensor: + result = tf.convert_to_tensor(tensor) + if dtype is not None: + result = self.cast(result, dtype) + return result + def expm(self, a: Tensor) -> Tensor: return tf.linalg.expm(a) - def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: - return tf.math.pow(a, b) - def sin(self, a: Tensor) -> Tensor: return tf.math.sin(a) @@ -542,9 +550,6 @@ def where( return tuple(tf.unstack(tf.where(condition), axis=1)) return tf.where(condition, x, y) - def equal(self, x1: Tensor, x2: Tensor) -> Tensor: - return tf.math.equal(x1, x2) - def argmax(self, a: Tensor, axis: int = 0) -> Tensor: return tf.math.argmax(a, axis=axis) diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py index 010e5b15..1f72f569 100644 --- a/tensorcircuit/templates/lattice.py +++ b/tensorcircuit/templates/lattice.py @@ -21,6 +21,7 @@ logger = logging.getLogger(__name__) import itertools import numpy as np +from scipy.spatial import cKDTree from .. import backend @@ -365,6 +366,7 @@ def show( """ try: import matplotlib.pyplot as plt + import numpy as np except ImportError: logger.error( "Matplotlib is required for visualization. " @@ -563,7 +565,7 @@ def _build_neighbors_by_distance_matrix( displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims( all_coords, 0 ) - dist_matrix_sq = backend.sum(backend.power(displacements, 2), axis=-1) + dist_matrix_sq = backend.sum(displacements**2, axis=-1) # Flatten the matrix to a list of all squared distances to identify shells. all_distances_sq = backend.reshape(dist_matrix_sq, [-1]) @@ -726,67 +728,28 @@ def _build_lattice(self) -> None: self._ident_to_idx[identifier] = current_index current_index += 1 - def _get_distance_matrix_with_mic(self) -> Coordinates: - """ - Computes the full N x N distance matrix using backend operations, - correctly applying the Minimum Image Convention (MIC) for all - periodic dimensions in a memory-efficient manner. - """ - - size_arr = backend.convert_to_tensor(self.size) - size_arr = backend.cast(size_arr, self.lattice_vectors.dtype) - - # Calculate the full system vectors that span the entire finite lattice. - system_vectors = self.lattice_vectors * backend.expand_dims(size_arr, axis=1) - - pbc_dims = [d for d in range(self.dimensionality) if self.pbc[d]] - - if not pbc_dims: - # If no PBC, the only 'translation' is the zero vector. - translations_arr = backend.zeros( - [1, self.dimensionality], dtype=self.lattice_vectors.dtype - ) - else: - num_pbc_dims = len(pbc_dims) - pbc_system_vectors = backend.gather1d( - system_vectors, backend.convert_to_tensor(pbc_dims) - ) - - # Generate all 3^d possible image shifts (-1, 0, 1) for periodic dimensions. - shift_options = [backend.convert_to_tensor([-1.0, 0.0, 1.0])] * num_pbc_dims - shifts_grid = backend.meshgrid(*shift_options, indexing="ij") - all_shifts = backend.reshape( - backend.stack(shifts_grid, axis=-1), (-1, num_pbc_dims) - ) - - translations_arr = backend.tensordot( - all_shifts, pbc_system_vectors, axes=[[1], [0]] - ) - - dist_sq_rows = [] - # Iterate through each site `i` to compute its distance to all other sites `j`. - # This is done row-by-row to manage memory for very large lattices. - assert self._coordinates is not None - for i in range(self.num_sites): - # For each site `i`, calculate displacements to all other sites `j`. - displacements_i = self._coordinates - self._coordinates[i] # Shape: (N, D) - # Then, for each displacement `d_ij`, find the minimum distance among - # `d_ij` and all its periodic images. - image_displacements_i = backend.expand_dims( - displacements_i, 1 - ) - backend.expand_dims(translations_arr, 0) - image_d_sq_i = backend.sum(image_displacements_i**2, axis=2) - min_dist_sq_i = backend.min(image_d_sq_i, axis=1) - dist_sq_rows.append(min_dist_sq_i) - - dist_matrix_sq = backend.stack(dist_sq_rows, axis=0) - safe_dist_matrix_sq = backend.where(dist_matrix_sq > 0, dist_matrix_sq, 0.0) - return backend.sqrt(safe_dist_matrix_sq) - def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates: """ Computes the full N x N distance matrix using a fully vectorized approach - to be compatible with JIT compilation (e.g., JAX). + that correctly applies the Minimum Image Convention (MIC) for periodic + boundary conditions. + + This method uses full vectorization for optimal performance and compatibility + with JIT compilation frameworks like JAX. The implementation processes all + site pairs simultaneously rather than iterating row-by-row, which provides: + + - Better performance through vectorized operations + - Full compatibility with automatic differentiation + - JIT compilation support (e.g., JAX, TensorFlow) + - Consistent tensor operations throughout + + The trade-off is higher memory usage compared to iterative approaches, + as it computes all pairwise distances simultaneously. For very large + lattices (N > 10^4 sites), memory usage scales as O(N^2). + + :return: Distance matrix with shape (N, N) where entry (i,j) is the + minimum distance between sites i and j under periodic boundary conditions. + :rtype: Coordinates """ size_arr = backend.cast( backend.convert_to_tensor(self.size), self.lattice_vectors.dtype @@ -825,7 +788,7 @@ def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates: ) - backend.expand_dims(backend.expand_dims(translations_arr, 0), 0) # Sum of squares for distances - image_d_sq = backend.sum(backend.power(image_displacements, 2), axis=3) + image_d_sq = backend.sum(image_displacements**2, axis=3) # Find the minimum distance among all images (Minimum Image Convention) min_dist_sq = backend.min(image_d_sq, axis=2) @@ -1382,30 +1345,147 @@ def _build_lattice(self, *args: Any, **kwargs: Any) -> None: def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: """ - Calculates neighbor relationships using a distance matrix. + Calculates neighbor relationships using either KDTree or distance matrix methods. - This method leverages the generic `_build_neighbors_by_distance_matrix` - to ensure differentiability, avoiding non-differentiable libraries - like SciPy's KDTree. + This method supports two modes: + 1. KDTree mode (use_kdtree=True): Fast, O(N log N) performance for large lattices + but breaks differentiability due to scipy dependency + 2. Distance matrix mode (use_kdtree=False): Slower O(N²) but fully differentiable + and backend-agnostic + + :param max_k: Maximum number of neighbor shells to compute + :type max_k: int + :param kwargs: Additional arguments including: + - use_kdtree (bool): Whether to use KDTree optimization. Defaults to True. + - tol (float): Distance tolerance for neighbor identification. Defaults to 1e-6. + - force_differentiable (bool): If True, forces distance matrix method even when + KDTree is available. Defaults to False. """ tol = kwargs.get("tol", 1e-6) + use_kdtree = kwargs.get("use_kdtree", True) + force_differentiable = kwargs.get("force_differentiable", False) + if self.num_sites < 2: return - # For CustomizeLattice, we must use the distance matrix method. - dist_matrix = self._compute_distance_matrix() - dist_matrix_sq = dist_matrix**2 - self._distance_matrix = dist_matrix - - all_distances_sq = backend.reshape(dist_matrix_sq, [-1]) - dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol) - - self._neighbor_maps = self._build_neighbor_map_from_distances( - dist_matrix_sq, dist_shells_sq, tol - ) + # Override KDTree if differentiability is explicitly required + if force_differentiable: + use_kdtree = False + logger.info("Using differentiable distance matrix method (forced)") + + # Choose algorithm based on user preference + if use_kdtree and not force_differentiable: + logger.info(f"Using KDTree method for {self.num_sites} sites up to k={max_k}") + self._build_neighbors_kdtree(max_k, tol) + else: + logger.info(f"Using differentiable distance matrix method for {self.num_sites} sites up to k={max_k}") + + # Use the existing distance matrix method + self._build_neighbors_by_distance_matrix(max_k, tol) logger.info(f"Neighbor building complete for CustomizeLattice up to k={max_k}.") + def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: + """ + Build neighbors using KDTree for optimal performance. + + This method provides O(N log N) performance for neighbor finding but breaks + differentiability due to scipy dependency. Use this method when: + - Performance is critical + - Differentiability is not required + - Large lattices (N > 1000) + + Note: This method uses numpy arrays directly and may not be compatible + with all backend types (JAX, TensorFlow, etc.). + """ + # Convert coordinates to numpy for KDTree + coords_np = backend.numpy(self._coordinates) + + # Build KDTree + logger.info("Building KDTree...") + tree = cKDTree(coords_np) + + # For small lattices or cases with potential duplicate coordinates, + # fall back to distance matrix method for robustness + if self.num_sites < 1000: + logger.info("Small lattice detected, falling back to distance matrix method for robustness") + self._build_neighbors_by_distance_matrix(max_k, tol) + return + + # Find all distances for shell identification - use comprehensive sampling + logger.info("Identifying distance shells...") + distances_for_shells = [] + + # For robust shell identification, query all pairwise distances for smaller lattices + # or use dense sampling for larger ones + if self.num_sites <= 100: + # For small lattices, compute all pairwise distances for accuracy + for i in range(self.num_sites): + query_k = min(self.num_sites - 1, max_k * 20) + if query_k > 0: + dists, _ = tree.query(coords_np[i], k=query_k + 1) # +1 to exclude self + distances_for_shells.extend(dists[1:]) # Skip distance to self + else: + # For larger lattices, use adaptive sampling but ensure we capture all shells + sample_size = min(1000, self.num_sites // 2) # More conservative sampling + for i in range(0, self.num_sites, max(1, self.num_sites // sample_size)): + query_k = min(max_k * 20 + 50, self.num_sites - 1) + if query_k > 0: + dists, _ = tree.query(coords_np[i], k=query_k + 1) # +1 to exclude self + distances_for_shells.extend(dists[1:]) # Skip distance to self + + # Filter out zero distances (duplicate coordinates) before shell identification + ZERO_THRESHOLD = 1e-12 + distances_for_shells = [d for d in distances_for_shells if d > ZERO_THRESHOLD] + + if not distances_for_shells: + logger.warning("No valid distances found for shell identification") + self._neighbor_maps = {} + return + + # Use the same shell identification logic as distance matrix method + distances_for_shells_sq = [d*d for d in distances_for_shells] + dist_shells_sq = self._identify_distance_shells(distances_for_shells_sq, max_k, tol) + dist_shells = [np.sqrt(d_sq) for d_sq in dist_shells_sq] + + logger.info(f"Found {len(dist_shells)} distance shells: {dist_shells[:5]}...") + + # Initialize neighbor maps + self._neighbor_maps = {k: {} for k in range(1, len(dist_shells) + 1)} + + # Build neighbor lists for each site + logger.info("Building neighbor lists...") + for i in range(self.num_sites): + # Query enough neighbors to capture all shells + query_k = min(max_k * 20 + 50, self.num_sites - 1) + if query_k > 0: + distances, indices = tree.query(coords_np[i], k=query_k + 1) # +1 for self + + # Skip the first entry (distance to self) + distances = distances[1:] + indices = indices[1:] + + # Filter out zero distances (duplicate coordinates) + valid_pairs = [(d, idx) for d, idx in zip(distances, indices) if d > ZERO_THRESHOLD] + + # Assign neighbors to shells + for shell_idx, shell_dist in enumerate(dist_shells): + k = shell_idx + 1 + shell_neighbors = [] + + for dist, neighbor_idx in valid_pairs: + if abs(dist - shell_dist) <= tol: + shell_neighbors.append(int(neighbor_idx)) + elif dist > shell_dist + tol: + break # Distances are sorted, no more matches + + if shell_neighbors: + self._neighbor_maps[k][i] = sorted(shell_neighbors) + + # Set distance matrix to None - will compute on demand + self._distance_matrix = None + logger.info("KDTree neighbor building completed") + def _compute_distance_matrix(self) -> Coordinates: """ Computes the full N x N distance matrix using backend operations. @@ -1424,10 +1504,10 @@ def _compute_distance_matrix(self) -> Coordinates: self._coordinates, 0 ) - dist_matrix_sq = backend.sum(backend.power(displacements, 2), axis=-1) + dist_matrix_sq = backend.sum(displacements**2, axis=-1) return backend.where( - backend.equal(dist_matrix_sq, 0), + dist_matrix_sq == 0, 0, backend.sqrt(dist_matrix_sq), ) @@ -1593,29 +1673,24 @@ def get_compatible_layers(bonds: List[Tuple[int, int]]) -> List[List[Tuple[int, tuple represents a bond. All bonds within a layer are non-overlapping. :rtype: List[List[Tuple[int, int]]] """ - # Ensure all bonds are in a canonical form (i, j) with i < j and remove duplicates. - sorted_edges = sorted(list({(min(bond), max(bond)) for bond in bonds})) + uncolored_edges: Set[Tuple[int, int]] = {(min(bond), max(bond)) for bond in bonds} layers: List[List[Tuple[int, int]]] = [] - unassigned_edges = set(sorted_edges) - # Greedily build layers until all edges have been assigned. - while unassigned_edges: + while uncolored_edges: current_layer: List[Tuple[int, int]] = [] qubits_in_this_layer: Set[int] = set() - sorted_unassigned = sorted(list(unassigned_edges)) + edges_to_process = sorted(list(uncolored_edges)) - # Iterate through remaining edges and add an edge to the current layer - # if it doesn't conflict with (share a qubit with) edges already in the layer. - for edge in sorted_unassigned: + for edge in edges_to_process: i, j = edge if i not in qubits_in_this_layer and j not in qubits_in_this_layer: current_layer.append(edge) qubits_in_this_layer.add(i) qubits_in_this_layer.add(j) - unassigned_edges -= set(current_layer) - layers.append(current_layer) + uncolored_edges -= set(current_layer) + layers.append(sorted(current_layer)) return layers diff --git a/tests/test_backends.py b/tests/test_backends.py index a681a184..6aeb94d8 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1174,3 +1174,66 @@ def fb2(*args): np.testing.assert_allclose(ya, tc.backend.transpose(yb, [1, 0, 2]), atol=1e-5) np.testing.assert_allclose(ya, yajit, atol=1e-5) np.testing.assert_allclose(yajit, tc.backend.transpose(ybjit, [1, 0, 2]), atol=1e-5) + + +# Test new backend methods added for differentiable lattice support + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_sort(backend): + """Test sort method.""" + a = tc.backend.convert_to_tensor([3, 1, 4, 1, 5]) + result = tc.backend.sort(a) + expected = tc.backend.convert_to_tensor([1, 1, 3, 4, 5]) + np.testing.assert_allclose(result, expected) + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_all(backend): + """Test all method.""" + # Test all True + a = tc.backend.convert_to_tensor([True, True, True]) + result = tc.backend.all(a) + assert result == True + + # Test with False + b = tc.backend.convert_to_tensor([True, False, True]) + result = tc.backend.all(b) + assert result == False + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_meshgrid(backend): + """Test meshgrid method.""" + x = tc.backend.convert_to_tensor([1, 2]) + y = tc.backend.convert_to_tensor([3, 4]) + xx, yy = tc.backend.meshgrid(x, y, indexing='ij') + + expected_xx = tc.backend.convert_to_tensor([[1, 1], [2, 2]]) + expected_yy = tc.backend.convert_to_tensor([[3, 4], [3, 4]]) + + np.testing.assert_allclose(xx, expected_xx) + np.testing.assert_allclose(yy, expected_yy) + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_expand_dims(backend): + """Test expand_dims method.""" + a = tc.backend.convert_to_tensor([[1, 2], [3, 4]]) + result = tc.backend.expand_dims(a, axis=0) + assert result.shape == (1, 2, 2) + + result = tc.backend.expand_dims(a, axis=1) + assert result.shape == (2, 1, 2) + + + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_backend_where(backend): + """Test where method.""" + condition = tc.backend.convert_to_tensor([True, False, True]) + x = tc.backend.convert_to_tensor([1, 2, 3]) + y = tc.backend.convert_to_tensor([4, 5, 6]) + result = tc.backend.where(condition, x, y) + expected = tc.backend.convert_to_tensor([1, 5, 3]) + np.testing.assert_allclose(result, expected) diff --git a/tests/test_lattice.py b/tests/test_lattice.py index 39f064bd..05accbb3 100644 --- a/tests/test_lattice.py +++ b/tests/test_lattice.py @@ -1,23 +1,9 @@ -from __future__ import annotations from unittest.mock import patch -import logging -import sys - -# Configure logging for debugging purposes -logging.basicConfig( - level=logging.DEBUG, - stream=sys.stdout, - format="[%(levelname)s] %(asctime)s - %(name)s:%(lineno)d - %(message)s", -) -logger = logging.getLogger(__name__) -from unittest.mock import patch -from typing import TYPE_CHECKING, Any - import time +import logging import matplotlib - try: import jax.numpy as jnp except ImportError: @@ -31,16 +17,8 @@ except ImportError: torch = None -if TYPE_CHECKING: - import jax.numpy as jnp - from jax import Array - import tensorflow as tf - import torch - - matplotlib.use("Agg") - import pytest import numpy as np @@ -56,14 +34,13 @@ RectangularLattice, SquareLattice, TriangularLattice, - AbstractLattice, get_compatible_layers, ) import tensorcircuit as tc @pytest.fixture -def simple_square_lattice() -> CustomizeLattice: +def simple_square_lattice(): """ Provides a simple 2x2 square CustomizeLattice instance for neighbor tests. The sites are indexed as follows: @@ -80,7 +57,7 @@ def simple_square_lattice() -> CustomizeLattice: @pytest.fixture -def kagome_lattice_fragment() -> CustomizeLattice: +def kagome_lattice_fragment(): """ Pytest fixture to provide a standard CustomizeLattice instance. This represents the Kagome fragment from the project requirements, @@ -572,13 +549,13 @@ def test_precompute_neighbors_on_init_custom(self): @pytest.fixture -def obc_square_lattice() -> SquareLattice: +def obc_square_lattice(): """Provides a 3x3 SquareLattice with Open Boundary Conditions.""" return SquareLattice(size=(3, 3), pbc=False) @pytest.fixture -def pbc_square_lattice() -> SquareLattice: +def pbc_square_lattice(): """Provides a 3x3 SquareLattice with Periodic Boundary Conditions.""" return SquareLattice(size=(3, 3), pbc=True) @@ -676,7 +653,7 @@ def test_neighbors_with_periodic_boundaries(self, pbc_square_lattice): @pytest.fixture -def pbc_honeycomb_lattice() -> HoneycombLattice: +def pbc_honeycomb_lattice(): """Provides a 3x3 HoneycombLattice with Periodic Boundary Conditions.""" return HoneycombLattice(size=(3, 3), pbc=True) @@ -724,7 +701,7 @@ def test_honeycomb_next_nearest_neighbors(self, pbc_honeycomb_lattice): @pytest.fixture -def pbc_triangular_lattice() -> TriangularLattice: +def pbc_triangular_lattice(): """ Provides a 3x3 TriangularLattice with Periodic Boundary Conditions. A 3x3 size is used to ensure all 6 nearest neighbors are unique sites. @@ -782,7 +759,7 @@ class TestTILatticeEdgeCases: """ @pytest.fixture - def obc_1d_chain(self) -> ChainLattice: + def obc_1d_chain(self): """ Provides a 5-site 1D chain with Open Boundary Conditions. """ @@ -808,7 +785,7 @@ def test_1d_chain_properties_and_neighbors(self, obc_1d_chain): assert set(lattice.get_neighbors(middle_idx, k=1)) == {1, 3} @pytest.fixture - def nonsquare_lattice(self) -> SquareLattice: + def nonsquare_lattice(self): """Provides a non-square 2x3 lattice to test indexing.""" return SquareLattice(size=(2, 3), pbc=False) @@ -1239,7 +1216,7 @@ class TestLongRangeNeighborFinding: """ @pytest.fixture(scope="class") - def large_pbc_square_lattice(self) -> SquareLattice: + def large_pbc_square_lattice(self): """ Provides a single 6x8 SquareLattice with PBC for all tests in this class. Using scope="class" makes it more efficient as it's created only once. @@ -1508,7 +1485,7 @@ class TestCustomizeLatticeDynamic: """Tests the dynamic modification capabilities of CustomizeLattice.""" @pytest.fixture - def initial_lattice(self) -> CustomizeLattice: + def initial_lattice(self): """Provides a basic 3-site lattice for modification tests.""" return CustomizeLattice( dimensionality=2, @@ -1940,7 +1917,7 @@ def test_pbc_implementation_is_not_significantly_slower_than_obc(self): ) -def _validate_layers(bonds, layers) -> None: +def _validate_layers(bonds, layers): """ A helper function to scientifically validate the output of get_compatible_layers. """ @@ -1980,7 +1957,7 @@ def _validate_layers(bonds, layers) -> None: "HoneycombLattice_2x2_OBC", ], ) -def test_layering_on_various_lattices(lattice_instance: AbstractLattice): +def test_layering_on_various_lattices(lattice_instance): """Tests gate layering for various standard lattice types.""" bonds = lattice_instance.get_neighbor_pairs(k=1, unique=True) layers = get_compatible_layers(bonds) @@ -2058,19 +2035,7 @@ def test_layering_on_disconnected_graph(): } -@pytest.fixture(scope="function") -def jax_backend_fixture(): - """ - Pytest fixture to set the backend to 'jax' for a test and restore it afterward. - This ensures that differentiability tests run in the correct environment - without interfering with other tests. - """ - original_backend = tc.backend.name - try: - tc.set_backend("jax") - yield - finally: - tc.set_backend(original_backend) + class TestBackendIntegration: @@ -2119,12 +2084,12 @@ class TestBackendIntegration: ) def test_lattice_creation_and_properties_across_backends( self, - backend_name: str, - LatticeClass: "AbstractLattice", - init_args: dict[str, Any], - expected_distance_check: tuple[int, int, float], - name: str, - ) -> None: + backend_name, + LatticeClass, + init_args, + expected_distance_check, + name, + ): """ Tests that various lattices can be created with each backend and that their core properties (_coordinates, distance_matrix) have the correct @@ -2162,7 +2127,6 @@ def test_lattice_creation_and_properties_across_backends( err_msg=f"Distance check failed for {type(lat).__name__} on backend {backend_name}", ) - @pytest.mark.usefixtures("jax_backend_fixture") @pytest.mark.parametrize( "lattice_class, init_params, differentiable_arg_name, test_value", [ @@ -2189,11 +2153,12 @@ def test_lattice_creation_and_properties_across_backends( ) def test_tilattice_differentiability( self, - lattice_class: type[AbstractLattice], - init_params: dict[str, Any], - differentiable_arg_name: str, - test_value: Any, - ) -> None: + jaxb, + lattice_class, + init_params, + differentiable_arg_name, + test_value, + ): """ Tests that the distance_matrix of various TILattices is differentiable with respect to their geometric parameters. This test has been expanded @@ -2202,7 +2167,7 @@ def test_tilattice_differentiability( if not jnp: pytest.skip("JAX backend is required for this differentiability test.") - def get_total_distance(param: Any) -> Array: + def get_total_distance(param): """A scalar-in, scalar-out function for jax.grad.""" # Dynamically create the lattice with the parameter being differentiated lat = lattice_class(**init_params, **{differentiable_arg_name: param}) @@ -2228,8 +2193,7 @@ def get_total_distance(param: Any) -> Array: float(grad_val), 0.0 ), f"Gradient for {lattice_class.__name__} was zero." - @pytest.mark.usefixtures("jax_backend_fixture") - def test_customizelattice_differentiability(self) -> None: + def test_customizelattice_differentiability(self, jaxb): """ Tests that the distance_matrix of a CustomizeLattice is differentiable with respect to its input coordinates. @@ -2240,7 +2204,7 @@ def test_customizelattice_differentiability(self) -> None: initial_coords = jnp.array([[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]]) - def get_total_distance_custom(coords: Array) -> Array: + def get_total_distance_custom(coords): """ A helper function that takes coordinates, creates a CustomizeLattice, and returns a scalar value (the sum of its distance matrix). @@ -2259,8 +2223,7 @@ def get_total_distance_custom(coords: Array) -> Array: assert grad_tensor is not None assert not np.all(np.isclose(grad_tensor, 0.0)) - @pytest.mark.usefixtures("jax_backend_fixture") - def test_tilattice_gradient_value_correctness(self) -> None: + def test_tilattice_gradient_value_correctness(self, jaxb): """ Tests that the AD gradient for a TILattice parameter matches the analytically calculated, correct gradient value. This is a stronger @@ -2270,7 +2233,7 @@ def test_tilattice_gradient_value_correctness(self) -> None: pytest.skip("JAX backend is required for this gradient value test.") # 1. Define a simple objective function - def get_energy(a: float) -> Array: + def get_energy(a): """ A simple energy function for a 2-site chain. Energy = (distance between site 0 and 1)^2 = a^2 @@ -2283,7 +2246,7 @@ def get_energy(a: float) -> Array: return tc.backend.sum(dist_matrix**2) / 2.0 # 2. Define the analytical (manually calculated) gradient - def analytical_gradient(a: float) -> float: + def analytical_gradient(a): """ The analytical derivative of the energy function E(a) = a^2. dE/da = 2a @@ -2306,7 +2269,7 @@ def analytical_gradient(a: float) -> float: ) @pytest.mark.parametrize("backend_name", ["numpy", "jax", "tensorflow", "pytorch"]) - def test_dynamic_modification_across_backends(self, backend_name: str) -> None: + def test_dynamic_modification_across_backends(self, backend_name): """ Tests that the dynamic modification methods (add_sites, remove_sites) of CustomizeLattice work correctly across all supported backends, @@ -2351,7 +2314,7 @@ def test_dynamic_modification_across_backends(self, backend_name: str) -> None: @pytest.mark.parametrize("backend_name", ["numpy", "jax", "tensorflow", "pytorch"]) -def test_dtype_consistency_across_backends(backend_name: str) -> None: +def test_dtype_consistency_across_backends(backend_name): """ Tests that the dtype of user-provided coordinate data is preserved in internal calculations across all backends. @@ -2390,7 +2353,7 @@ class TestPrivateHelpers: """ @pytest.fixture - def simple_lattice_for_helpers(self) -> CustomizeLattice: + def simple_lattice_for_helpers(self): """ Provides a very simple lattice instance, primarily to gain access to the private helper methods for testing. The geometry itself is trivial. @@ -2503,15 +2466,15 @@ def test_identify_distance_shells_with_empty_and_zero_input( shells_zero == [] ), "Should return an empty list for a distance array with only zeros." - def test_get_distance_matrix_with_mic(self): + def test_get_distance_matrix_with_mic_vectorized(self): """ - Tests the internal _get_distance_matrix_with_mic method for TILattice + Tests the internal _get_distance_matrix_with_mic_vectorized method for TILattice to ensure it correctly applies the Minimum Image Convention. """ # --- Test Case 1: Fully Periodic Boundary Conditions (PBC) --- lattice_pbc = SquareLattice(size=(3, 3), pbc=True, lattice_constant=1.0) # We need to use the numpy backend for direct comparison - dist_matrix_pbc = tc.backend.numpy(lattice_pbc._get_distance_matrix_with_mic()) + dist_matrix_pbc = tc.backend.numpy(lattice_pbc._get_distance_matrix_with_mic_vectorized()) # For a 3x3 PBC lattice, the distance between opposite edges should be 1. # Example: site (0,0) and site (2,0) @@ -2546,7 +2509,7 @@ def test_get_distance_matrix_with_mic(self): size=(3, 3), pbc=(True, False), lattice_constant=1.0 ) dist_matrix_mixed = tc.backend.numpy( - lattice_mixed._get_distance_matrix_with_mic() + lattice_mixed._get_distance_matrix_with_mic_vectorized() ) # In the periodic x-direction, distance should be 1. From 0ad707c4d93a48e89b6675c846190b4b7a110011 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Tue, 12 Aug 2025 20:24:22 +0800 Subject: [PATCH 05/16] fix black --- examples/lennard_jones_optimization.py | 10 ++- tensorcircuit/backends/jax_backend.py | 5 +- tensorcircuit/backends/numpy_backend.py | 4 +- tensorcircuit/backends/tensorflow_backend.py | 4 +- tensorcircuit/templates/lattice.py | 88 ++++++++++++-------- tests/test_backends.py | 13 ++- tests/test_lattice.py | 7 +- 7 files changed, 78 insertions(+), 53 deletions(-) diff --git a/examples/lennard_jones_optimization.py b/examples/lennard_jones_optimization.py index 9edc63ca..068fafe7 100644 --- a/examples/lennard_jones_optimization.py +++ b/examples/lennard_jones_optimization.py @@ -8,12 +8,14 @@ The optimization showcases the key Task 3 capability: making lattice parameters differentiable for variational material design. """ + import optax import numpy as np import matplotlib.pyplot as plt # Try to enable JAX 64-bit precision if available (safe fallback) import jax # noqa: E402 + try: # pragma: no cover - optional optimization from jax import config as jax_config # type: ignore @@ -34,12 +36,14 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0): This version creates the lattice inside the function to demonstrate truly differentiable geometry. """ lattice_constant = K.exp(log_a) - + # Create lattice with the differentiable parameter size = (4, 4) # Smaller size for demonstration - lattice = tc.templates.lattice.SquareLattice(size, lattice_constant=lattice_constant, pbc=True) + lattice = tc.templates.lattice.SquareLattice( + size, lattice_constant=lattice_constant, pbc=True + ) d = lattice.distance_matrix - + d_safe = K.where(d > 1e-9, d, K.convert_to_tensor(1e-9)) term12 = K.power(sigma / d_safe, 12) diff --git a/tensorcircuit/backends/jax_backend.py b/tensorcircuit/backends/jax_backend.py index c64b9828..d678b42d 100644 --- a/tensorcircuit/backends/jax_backend.py +++ b/tensorcircuit/backends/jax_backend.py @@ -50,7 +50,9 @@ def update(self, grads: pytree, params: pytree) -> pytree: return params -def _convert_to_tensor_jax(self: Any, tensor: Tensor, dtype: Optional[str] = None) -> Tensor: +def _convert_to_tensor_jax( + self: Any, tensor: Tensor, dtype: Optional[str] = None +) -> Tensor: if not isinstance(tensor, (np.ndarray, jnp.ndarray)) and not jnp.isscalar(tensor): raise TypeError( ("Expected a `jnp.array`, `np.array` or scalar. " f"Got {type(tensor)}") @@ -357,6 +359,7 @@ def expm(self, a: Tensor) -> Tensor: return jsp.linalg.expm(a) # currently expm in jax doesn't support AD, it will raise an AssertError, # see https://github.com/google/jax/issues/2645 + def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor: return jnp.stack(a, axis=axis) diff --git a/tensorcircuit/backends/numpy_backend.py b/tensorcircuit/backends/numpy_backend.py index ae8d4e14..5dd7f625 100644 --- a/tensorcircuit/backends/numpy_backend.py +++ b/tensorcircuit/backends/numpy_backend.py @@ -35,7 +35,9 @@ def _sum_numpy( # see https://github.com/google/TensorNetwork/issues/952 -def _convert_to_tensor_numpy(self: Any, a: Tensor, dtype: Optional[str] = None) -> Tensor: +def _convert_to_tensor_numpy( + self: Any, a: Tensor, dtype: Optional[str] = None +) -> Tensor: if not isinstance(a, np.ndarray) and not np.isscalar(a): a = np.array(a) a = np.asarray(a) diff --git a/tensorcircuit/backends/tensorflow_backend.py b/tensorcircuit/backends/tensorflow_backend.py index 9d33a4fd..4113bf2a 100644 --- a/tensorcircuit/backends/tensorflow_backend.py +++ b/tensorcircuit/backends/tensorflow_backend.py @@ -78,7 +78,9 @@ def _tensordot_tf( # Use TensorFlow's dtype promotion rules by converting both to a common dtype if a.dtype != b.dtype: # Find the result dtype by performing a dummy operation - common_dtype = (tf.constant(0, dtype=a.dtype) + tf.constant(0, dtype=b.dtype)).dtype + common_dtype = ( + tf.constant(0, dtype=a.dtype) + tf.constant(0, dtype=b.dtype) + ).dtype a = tf.cast(a, common_dtype) b = tf.cast(b, common_dtype) return tf.tensordot(a, b, axes) diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py index 1f72f569..985b5af0 100644 --- a/tensorcircuit/templates/lattice.py +++ b/tensorcircuit/templates/lattice.py @@ -733,20 +733,20 @@ def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates: Computes the full N x N distance matrix using a fully vectorized approach that correctly applies the Minimum Image Convention (MIC) for periodic boundary conditions. - + This method uses full vectorization for optimal performance and compatibility with JIT compilation frameworks like JAX. The implementation processes all site pairs simultaneously rather than iterating row-by-row, which provides: - + - Better performance through vectorized operations - Full compatibility with automatic differentiation - JIT compilation support (e.g., JAX, TensorFlow) - Consistent tensor operations throughout - + The trade-off is higher memory usage compared to iterative approaches, - as it computes all pairwise distances simultaneously. For very large + as it computes all pairwise distances simultaneously. For very large lattices (N > 10^4 sites), memory usage scales as O(N^2). - + :return: Distance matrix with shape (N, N) where entry (i,j) is the minimum distance between sites i and j under periodic boundary conditions. :rtype: Coordinates @@ -1364,7 +1364,7 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: tol = kwargs.get("tol", 1e-6) use_kdtree = kwargs.get("use_kdtree", True) force_differentiable = kwargs.get("force_differentiable", False) - + if self.num_sites < 2: return @@ -1372,14 +1372,18 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: if force_differentiable: use_kdtree = False logger.info("Using differentiable distance matrix method (forced)") - + # Choose algorithm based on user preference if use_kdtree and not force_differentiable: - logger.info(f"Using KDTree method for {self.num_sites} sites up to k={max_k}") + logger.info( + f"Using KDTree method for {self.num_sites} sites up to k={max_k}" + ) self._build_neighbors_kdtree(max_k, tol) else: - logger.info(f"Using differentiable distance matrix method for {self.num_sites} sites up to k={max_k}") - + logger.info( + f"Using differentiable distance matrix method for {self.num_sites} sites up to k={max_k}" + ) + # Use the existing distance matrix method self._build_neighbors_by_distance_matrix(max_k, tol) @@ -1388,34 +1392,36 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: """ Build neighbors using KDTree for optimal performance. - + This method provides O(N log N) performance for neighbor finding but breaks differentiability due to scipy dependency. Use this method when: - Performance is critical - Differentiability is not required - Large lattices (N > 1000) - + Note: This method uses numpy arrays directly and may not be compatible with all backend types (JAX, TensorFlow, etc.). - """ + """ # Convert coordinates to numpy for KDTree coords_np = backend.numpy(self._coordinates) - + # Build KDTree logger.info("Building KDTree...") tree = cKDTree(coords_np) - + # For small lattices or cases with potential duplicate coordinates, # fall back to distance matrix method for robustness if self.num_sites < 1000: - logger.info("Small lattice detected, falling back to distance matrix method for robustness") + logger.info( + "Small lattice detected, falling back to distance matrix method for robustness" + ) self._build_neighbors_by_distance_matrix(max_k, tol) return - + # Find all distances for shell identification - use comprehensive sampling logger.info("Identifying distance shells...") distances_for_shells = [] - + # For robust shell identification, query all pairwise distances for smaller lattices # or use dense sampling for larger ones if self.num_sites <= 100: @@ -1423,7 +1429,9 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: for i in range(self.num_sites): query_k = min(self.num_sites - 1, max_k * 20) if query_k > 0: - dists, _ = tree.query(coords_np[i], k=query_k + 1) # +1 to exclude self + dists, _ = tree.query( + coords_np[i], k=query_k + 1 + ) # +1 to exclude self distances_for_shells.extend(dists[1:]) # Skip distance to self else: # For larger lattices, use adaptive sampling but ensure we capture all shells @@ -1431,57 +1439,65 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: for i in range(0, self.num_sites, max(1, self.num_sites // sample_size)): query_k = min(max_k * 20 + 50, self.num_sites - 1) if query_k > 0: - dists, _ = tree.query(coords_np[i], k=query_k + 1) # +1 to exclude self + dists, _ = tree.query( + coords_np[i], k=query_k + 1 + ) # +1 to exclude self distances_for_shells.extend(dists[1:]) # Skip distance to self - + # Filter out zero distances (duplicate coordinates) before shell identification ZERO_THRESHOLD = 1e-12 distances_for_shells = [d for d in distances_for_shells if d > ZERO_THRESHOLD] - + if not distances_for_shells: logger.warning("No valid distances found for shell identification") self._neighbor_maps = {} return - + # Use the same shell identification logic as distance matrix method - distances_for_shells_sq = [d*d for d in distances_for_shells] - dist_shells_sq = self._identify_distance_shells(distances_for_shells_sq, max_k, tol) + distances_for_shells_sq = [d * d for d in distances_for_shells] + dist_shells_sq = self._identify_distance_shells( + distances_for_shells_sq, max_k, tol + ) dist_shells = [np.sqrt(d_sq) for d_sq in dist_shells_sq] - + logger.info(f"Found {len(dist_shells)} distance shells: {dist_shells[:5]}...") - + # Initialize neighbor maps self._neighbor_maps = {k: {} for k in range(1, len(dist_shells) + 1)} - + # Build neighbor lists for each site logger.info("Building neighbor lists...") for i in range(self.num_sites): # Query enough neighbors to capture all shells query_k = min(max_k * 20 + 50, self.num_sites - 1) if query_k > 0: - distances, indices = tree.query(coords_np[i], k=query_k + 1) # +1 for self - + distances, indices = tree.query( + coords_np[i], k=query_k + 1 + ) # +1 for self + # Skip the first entry (distance to self) distances = distances[1:] indices = indices[1:] - + # Filter out zero distances (duplicate coordinates) - valid_pairs = [(d, idx) for d, idx in zip(distances, indices) if d > ZERO_THRESHOLD] - + valid_pairs = [ + (d, idx) for d, idx in zip(distances, indices) if d > ZERO_THRESHOLD + ] + # Assign neighbors to shells for shell_idx, shell_dist in enumerate(dist_shells): k = shell_idx + 1 shell_neighbors = [] - + for dist, neighbor_idx in valid_pairs: if abs(dist - shell_dist) <= tol: shell_neighbors.append(int(neighbor_idx)) elif dist > shell_dist + tol: break # Distances are sorted, no more matches - + if shell_neighbors: self._neighbor_maps[k][i] = sorted(shell_neighbors) - + # Set distance matrix to None - will compute on demand self._distance_matrix = None logger.info("KDTree neighbor building completed") diff --git a/tests/test_backends.py b/tests/test_backends.py index 6aeb94d8..e113a8f9 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1178,6 +1178,7 @@ def fb2(*args): # Test new backend methods added for differentiable lattice support + @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) def test_backend_sort(backend): """Test sort method.""" @@ -1194,7 +1195,7 @@ def test_backend_all(backend): a = tc.backend.convert_to_tensor([True, True, True]) result = tc.backend.all(a) assert result == True - + # Test with False b = tc.backend.convert_to_tensor([True, False, True]) result = tc.backend.all(b) @@ -1206,11 +1207,11 @@ def test_backend_meshgrid(backend): """Test meshgrid method.""" x = tc.backend.convert_to_tensor([1, 2]) y = tc.backend.convert_to_tensor([3, 4]) - xx, yy = tc.backend.meshgrid(x, y, indexing='ij') - + xx, yy = tc.backend.meshgrid(x, y, indexing="ij") + expected_xx = tc.backend.convert_to_tensor([[1, 1], [2, 2]]) expected_yy = tc.backend.convert_to_tensor([[3, 4], [3, 4]]) - + np.testing.assert_allclose(xx, expected_xx) np.testing.assert_allclose(yy, expected_yy) @@ -1221,13 +1222,11 @@ def test_backend_expand_dims(backend): a = tc.backend.convert_to_tensor([[1, 2], [3, 4]]) result = tc.backend.expand_dims(a, axis=0) assert result.shape == (1, 2, 2) - + result = tc.backend.expand_dims(a, axis=1) assert result.shape == (2, 1, 2) - - @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) def test_backend_where(backend): """Test where method.""" diff --git a/tests/test_lattice.py b/tests/test_lattice.py index 05accbb3..689b7b23 100644 --- a/tests/test_lattice.py +++ b/tests/test_lattice.py @@ -2035,9 +2035,6 @@ def test_layering_on_disconnected_graph(): } - - - class TestBackendIntegration: """ Tests to ensure lattice functionalities are consistent and correct @@ -2474,7 +2471,9 @@ def test_get_distance_matrix_with_mic_vectorized(self): # --- Test Case 1: Fully Periodic Boundary Conditions (PBC) --- lattice_pbc = SquareLattice(size=(3, 3), pbc=True, lattice_constant=1.0) # We need to use the numpy backend for direct comparison - dist_matrix_pbc = tc.backend.numpy(lattice_pbc._get_distance_matrix_with_mic_vectorized()) + dist_matrix_pbc = tc.backend.numpy( + lattice_pbc._get_distance_matrix_with_mic_vectorized() + ) # For a 3x3 PBC lattice, the distance between opposite edges should be 1. # Example: site (0,0) and site (2,0) From 92bc8e46a1e7e7587e8a3852e8628133a28fca77 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Tue, 12 Aug 2025 20:32:57 +0800 Subject: [PATCH 06/16] fix black --- examples/lennard_jones_optimization.py | 2 -- tests/test_backends.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/lennard_jones_optimization.py b/examples/lennard_jones_optimization.py index 068fafe7..6c784fa8 100644 --- a/examples/lennard_jones_optimization.py +++ b/examples/lennard_jones_optimization.py @@ -14,7 +14,6 @@ import matplotlib.pyplot as plt # Try to enable JAX 64-bit precision if available (safe fallback) -import jax # noqa: E402 try: # pragma: no cover - optional optimization from jax import config as jax_config # type: ignore @@ -22,7 +21,6 @@ jax_config.update("jax_enable_x64", True) except Exception: # broad: environment may not have config attribute pass -import jax.numpy as jnp # noqa: E402 import tensorcircuit as tc # noqa: E402 diff --git a/tests/test_backends.py b/tests/test_backends.py index e113a8f9..8bad022d 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1194,12 +1194,12 @@ def test_backend_all(backend): # Test all True a = tc.backend.convert_to_tensor([True, True, True]) result = tc.backend.all(a) - assert result == True + assert result is True # Test with False b = tc.backend.convert_to_tensor([True, False, True]) result = tc.backend.all(b) - assert result == False + assert result is False @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) From efaee050aacb82a3487fe2c51feeca4aedee7102 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Tue, 12 Aug 2025 20:45:55 +0800 Subject: [PATCH 07/16] fix mypy errors --- tensorcircuit/templates/lattice.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py index 985b5af0..1fc6982f 100644 --- a/tensorcircuit/templates/lattice.py +++ b/tensorcircuit/templates/lattice.py @@ -1420,7 +1420,7 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: # Find all distances for shell identification - use comprehensive sampling logger.info("Identifying distance shells...") - distances_for_shells = [] + distances_for_shells: List[float] = [] # For robust shell identification, query all pairwise distances for smaller lattices # or use dense sampling for larger ones @@ -1432,7 +1432,10 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: dists, _ = tree.query( coords_np[i], k=query_k + 1 ) # +1 to exclude self - distances_for_shells.extend(dists[1:]) # Skip distance to self + if isinstance(dists, np.ndarray): + distances_for_shells.extend(dists[1:]) # Skip distance to self + else: + distances_for_shells.append(dists) # Single distance else: # For larger lattices, use adaptive sampling but ensure we capture all shells sample_size = min(1000, self.num_sites // 2) # More conservative sampling @@ -1442,7 +1445,10 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: dists, _ = tree.query( coords_np[i], k=query_k + 1 ) # +1 to exclude self - distances_for_shells.extend(dists[1:]) # Skip distance to self + if isinstance(dists, np.ndarray): + distances_for_shells.extend(dists[1:]) # Skip distance to self + else: + distances_for_shells.append(dists) # Single distance # Filter out zero distances (duplicate coordinates) before shell identification ZERO_THRESHOLD = 1e-12 @@ -1476,12 +1482,18 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: ) # +1 for self # Skip the first entry (distance to self) - distances = distances[1:] - indices = indices[1:] + # Handle both single value and array cases + if isinstance(distances, np.ndarray) and len(distances) > 1: + distances_slice = distances[1:] + indices_slice = indices[1:] if isinstance(indices, np.ndarray) else np.array([], dtype=int) + else: + # Single value or empty case - no neighbors to process + distances_slice = np.array([]) + indices_slice = np.array([], dtype=int) # Filter out zero distances (duplicate coordinates) valid_pairs = [ - (d, idx) for d, idx in zip(distances, indices) if d > ZERO_THRESHOLD + (d, idx) for d, idx in zip(distances_slice, indices_slice) if d > ZERO_THRESHOLD ] # Assign neighbors to shells From 7063c6fae4dfba1835229dc986016bef94e97bef Mon Sep 17 00:00:00 2001 From: Stellogic Date: Tue, 12 Aug 2025 21:11:02 +0800 Subject: [PATCH 08/16] fix test_backends.py --- tests/test_backends.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_backends.py b/tests/test_backends.py index 8bad022d..3788d953 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1194,12 +1194,12 @@ def test_backend_all(backend): # Test all True a = tc.backend.convert_to_tensor([True, True, True]) result = tc.backend.all(a) - assert result is True + assert tc.backend.numpy(result).item() is True # Test with False b = tc.backend.convert_to_tensor([True, False, True]) result = tc.backend.all(b) - assert result is False + assert tc.backend.numpy(result).item() is False @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) From 589763ef1bddd0245b3d87556e06395196d35590 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Tue, 12 Aug 2025 21:38:23 +0800 Subject: [PATCH 09/16] fix black --- tensorcircuit/templates/lattice.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py index 1fc6982f..369c514d 100644 --- a/tensorcircuit/templates/lattice.py +++ b/tensorcircuit/templates/lattice.py @@ -1485,7 +1485,11 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: # Handle both single value and array cases if isinstance(distances, np.ndarray) and len(distances) > 1: distances_slice = distances[1:] - indices_slice = indices[1:] if isinstance(indices, np.ndarray) else np.array([], dtype=int) + indices_slice = ( + indices[1:] + if isinstance(indices, np.ndarray) + else np.array([], dtype=int) + ) else: # Single value or empty case - no neighbors to process distances_slice = np.array([]) @@ -1493,7 +1497,9 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: # Filter out zero distances (duplicate coordinates) valid_pairs = [ - (d, idx) for d, idx in zip(distances_slice, indices_slice) if d > ZERO_THRESHOLD + (d, idx) + for d, idx in zip(distances_slice, indices_slice) + if d > ZERO_THRESHOLD ] # Assign neighbors to shells From daa3ff2ace082a56658a15cd160f08caec32404b Mon Sep 17 00:00:00 2001 From: Stellogic Date: Wed, 13 Aug 2025 22:15:42 +0800 Subject: [PATCH 10/16] fix according to the review --- examples/lennard_jones_optimization.py | 22 +- tensorcircuit/backends/abstract_backend.py | 41 +- tensorcircuit/backends/pytorch_backend.py | 2 - tensorcircuit/backends/tensorflow_backend.py | 6 +- tensorcircuit/templates/lattice.py | 206 ++++++---- tests/test_backends.py | 3 - tests/test_lattice.py | 390 ++++++++++++------- 7 files changed, 421 insertions(+), 249 deletions(-) diff --git a/examples/lennard_jones_optimization.py b/examples/lennard_jones_optimization.py index 6c784fa8..7c38cf9d 100644 --- a/examples/lennard_jones_optimization.py +++ b/examples/lennard_jones_optimization.py @@ -12,16 +12,7 @@ import optax import numpy as np import matplotlib.pyplot as plt - -# Try to enable JAX 64-bit precision if available (safe fallback) - -try: # pragma: no cover - optional optimization - from jax import config as jax_config # type: ignore - - jax_config.update("jax_enable_x64", True) -except Exception: # broad: environment may not have config attribute - pass -import tensorcircuit as tc # noqa: E402 +import tensorcircuit as tc tc.set_dtype("float64") # Use tc for universal control @@ -58,9 +49,8 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0): return potential_energy -# Create a lambda function for optimization -potential_fun_for_grad = lambda log_a: calculate_potential(log_a) -value_and_grad_fun = K.jit(K.value_and_grad(potential_fun_for_grad)) +# Create value and grad function for optimization +value_and_grad_fun = K.jit(K.value_and_grad(calculate_potential)) optimizer = optax.adam(learning_rate=0.01) @@ -77,11 +67,7 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0): history["a"].append(K.exp(log_a)) history["energy"].append(energy) - # Check for NaN gradients using TensorCircuit's backend-agnostic approach - if K.sum(tc.num_to_tensor(np.isnan(K.numpy(grad)))) > 0: - print(f"Gradient became NaN at iteration {i+1}. Stopping optimization.") - print(f"Current energy: {energy}, Current log_a: {log_a}") - break + # (Removed previously added blanket NaN guard per reviewer request to keep example minimal.) updates, opt_state = optimizer.update(grad, opt_state) log_a = optax.apply_updates(log_a, updates) diff --git a/tensorcircuit/backends/abstract_backend.py b/tensorcircuit/backends/abstract_backend.py index 30580af1..8dad618c 100644 --- a/tensorcircuit/backends/abstract_backend.py +++ b/tensorcircuit/backends/abstract_backend.py @@ -629,15 +629,27 @@ def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any: """ - Return coordinate matrices from coordinate vectors. + Return coordinate matrices from coordinate vectors. - :param args: coordinate vectors - :type args: Any + :param args: coordinate vectors + :type args: Any :param kwargs: keyword arguments for meshgrid, typically includes 'indexing' - which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing) - :type kwargs: Any - :return: list of coordinate matrices - :rtype: Any + which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing). + - 'ij': matrix indexing, first dimension corresponds to rows (default) + - 'xy': Cartesian indexing, first dimension corresponds to columns + Example: + >>> x, y = backend.meshgrid([0, 1], [0, 2], indexing='xy') + Shapes: + - x.shape == (2, 2) # rows correspond to y vector length + - y.shape == (2, 2) + Values: + x = [[0, 1], + [0, 1]] + y = [[0, 0], + [2, 2]] + :type kwargs: Any + :return: list of coordinate matrices + :rtype: Any """ raise NotImplementedError( "Backend '{}' has not implemented `meshgrid`.".format(self.name) @@ -797,6 +809,21 @@ def cast(self: Any, a: Tensor, dtype: str) -> Tensor: "Backend '{}' has not implemented `cast`.".format(self.name) ) + def convert_to_tensor(self: Any, a: Tensor, dtype: Optional[str] = None) -> Tensor: + """ + Convert input to tensor. + + :param a: input data to be converted + :type a: Tensor + :param dtype: target dtype, optional + :type dtype: Optional[str] + :return: converted tensor + :rtype: Tensor + """ + raise NotImplementedError( + "Backend '{}' has not implemented `convert_to_tensor`.".format(self.name) + ) + def mod(self: Any, x: Tensor, y: Tensor) -> Tensor: """ Compute y-mod of x (negative number behavior is not guaranteed to be consistent) diff --git a/tensorcircuit/backends/pytorch_backend.py b/tensorcircuit/backends/pytorch_backend.py index 31654046..cd037b6d 100644 --- a/tensorcircuit/backends/pytorch_backend.py +++ b/tensorcircuit/backends/pytorch_backend.py @@ -253,8 +253,6 @@ def expm(self, a: Tensor) -> Tensor: # it doesn't support complex numbers which is more severe issue. # see https://github.com/pytorch/pytorch/issues/9983 - # see https://github.com/pytorch/pytorch/issues/9983 - def sin(self, a: Tensor) -> Tensor: return torchlib.sin(a) diff --git a/tensorcircuit/backends/tensorflow_backend.py b/tensorcircuit/backends/tensorflow_backend.py index 13df9f67..f7c7ae5c 100644 --- a/tensorcircuit/backends/tensorflow_backend.py +++ b/tensorcircuit/backends/tensorflow_backend.py @@ -77,10 +77,8 @@ def _tensordot_tf( ) -> Tensor: # Use TensorFlow's dtype promotion rules by converting both to a common dtype if a.dtype != b.dtype: - # Find the result dtype by performing a dummy operation - common_dtype = ( - tf.constant(0, dtype=a.dtype) + tf.constant(0, dtype=b.dtype) - ).dtype + # Find the result dtype using TensorFlow's type promotion rules + common_dtype = tf.experimental.numpy.result_type(a.dtype, b.dtype) a = tf.cast(a, common_dtype) b = tf.cast(b, common_dtype) return tf.tensordot(a, b, axes) diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py index 369c514d..ad86bf91 100644 --- a/tensorcircuit/templates/lattice.py +++ b/tensorcircuit/templates/lattice.py @@ -20,6 +20,7 @@ logger = logging.getLogger(__name__) import itertools +import math import numpy as np from scipy.spatial import cKDTree from .. import backend @@ -131,7 +132,8 @@ def get_coordinates(self, index: SiteIndex) -> Coordinates: :rtype: Coordinates """ self._validate_index(index) - assert self._coordinates is not None + if self._coordinates is None: + raise ValueError("Lattice coordinates have not been initialized.") coords = self._coordinates[index] return coords @@ -184,7 +186,8 @@ def get_site_info( - The site's coordinates as a NumPy array. :rtype: Tuple[SiteIndex, SiteIdentifier, Coordinates] """ - assert self._coordinates is not None + if self._coordinates is None: + raise ValueError("Lattice coordinates have not been initialized.") if isinstance(index_or_identifier, int): # SiteIndex is an int idx = index_or_identifier self._validate_index(idx) @@ -204,8 +207,9 @@ def sites(self) -> Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]: index, identifier, and coordinates. :rtype: Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]] """ + if self._coordinates is None: + raise ValueError("Lattice coordinates have not been initialized.") for i in range(self.num_sites): - assert self._coordinates is not None yield i, self._identifiers[i], self._coordinates[i] def get_neighbors(self, index: SiteIndex, k: int = 1) -> List[SiteIndex]: @@ -366,7 +370,6 @@ def show( """ try: import matplotlib.pyplot as plt - import numpy as np except ImportError: logger.error( "Matplotlib is required for visualization. " @@ -398,6 +401,10 @@ def show( else: fig = ax.figure # type: ignore + if self._coordinates is None: + logger.error("Cannot show lattice: coordinates have not been initialized.") + return + coords = np.array(self._coordinates) # Prepare arguments for the scatter plot, allowing user overrides. scatter_args = {"s": 100, "zorder": 2} @@ -459,7 +466,6 @@ def show( if self.dimensionality > 2: ax_3d = cast("Axes3D", ax) for i, j in bonds: - assert self._coordinates is not None p1, p2 = self._coordinates[i], self._coordinates[j] ax_3d.plot( [p1[0], p2[0]], @@ -469,7 +475,6 @@ def show( ) else: for i, j in bonds: - assert self._coordinates is not None p1, p2 = self._coordinates[i], self._coordinates[j] if self.dimensionality == 1: # type: ignore @@ -662,21 +667,30 @@ def __init__( self.lattice_vectors = backend.convert_to_tensor(lattice_vectors) self.basis_coords = backend.convert_to_tensor(basis_coords) - assert self.lattice_vectors.shape == ( - dimensionality, - dimensionality, - ), "Lattice vectors shape mismatch" - assert ( - self.basis_coords.shape[1] == dimensionality - ), "Basis coordinates dimension mismatch" - assert len(size) == dimensionality, "Size tuple length mismatch" + if self.lattice_vectors.shape != (dimensionality, dimensionality): + raise ValueError( + f"Lattice vectors shape {self.lattice_vectors.shape} does not match " + f"expected ({dimensionality}, {dimensionality})" + ) + if self.basis_coords.shape[1] != dimensionality: + raise ValueError( + f"Basis coordinates dimension {self.basis_coords.shape[1]} does not " + f"match lattice dimensionality {dimensionality}" + ) + if len(size) != dimensionality: + raise ValueError( + f"Size tuple length {len(size)} does not match dimensionality {dimensionality}" + ) self.num_basis = self.basis_coords.shape[0] self.size = size if isinstance(pbc, bool): self.pbc = tuple([pbc] * dimensionality) else: - assert len(pbc) == dimensionality, "PBC tuple length mismatch" + if len(pbc) != dimensionality: + raise ValueError( + f"PBC tuple length {len(pbc)} does not match dimensionality {dimensionality}" + ) self.pbc = tuple(pbc) self._build_lattice() @@ -751,10 +765,25 @@ def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates: minimum distance between sites i and j under periodic boundary conditions. :rtype: Coordinates """ - size_arr = backend.cast( - backend.convert_to_tensor(self.size), self.lattice_vectors.dtype + # Ensure dtype consistency across backends (especially torch) by explicitly + # casting size and lattice_vectors to the same floating dtype used internally. + # Strategy: prefer existing lattice_vectors dtype, fallback to float64 for precision, + # then float32 for compatibility. This avoids dtype mismatches in vectorized ops. + target_dt = None + try: + # prefer existing lattice_vectors dtype if possible + target_dt = backend.dtype(self.lattice_vectors) # type: ignore + except Exception: # pragma: no cover - defensive + target_dt = "float64" + if target_dt not in ("float32", "float64"): + # fallback for unusual dtypes + target_dt = "float64" + + size_arr = backend.cast(backend.convert_to_tensor(self.size), target_dt) + lattice_vecs = backend.cast( + backend.convert_to_tensor(self.lattice_vectors), target_dt ) - system_vectors = self.lattice_vectors * backend.expand_dims(size_arr, axis=1) + system_vectors = lattice_vecs * backend.expand_dims(size_arr, axis=1) pbc_mask = backend.convert_to_tensor(self.pbc) @@ -913,12 +942,15 @@ def __init__( a = lattice_constant # Define the two primitive lattice vectors for the underlying triangular Bravais lattice. - lattice_vectors = [ - [a * 1.5, a * backend.sqrt(backend.convert_to_tensor(3.0)) / 2], - [a * 1.5, -a * backend.sqrt(backend.convert_to_tensor(3.0)) / 2], - ] + rt3_over_2 = math.sqrt(3.0) / 2.0 + lattice_vectors = backend.convert_to_tensor( + [ + [a * 1.5, a * rt3_over_2], + [a * 1.5, -a * rt3_over_2], + ] + ) # Define the two basis sites (A and B) within the unit cell. - basis_coords = [[0.0, 0.0], [a * 1.0, 0.0]] + basis_coords = backend.convert_to_tensor([[0.0, 0.0], [a * 1.0, 0.0]]) super().__init__( dimensionality=dimensionality, @@ -962,12 +994,14 @@ def __init__( a = lattice_constant # Define the primitive lattice vectors for a triangular lattice. - lattice_vectors = [ - [a * 1.0, 0.0], - [a * 0.5, a * backend.sqrt(backend.convert_to_tensor(3.0)) / 2], - ] + lattice_vectors = backend.convert_to_tensor( + [ + [a * 1.0, 0.0], + [a * 0.5, a * backend.sqrt(backend.convert_to_tensor(3.0)) / 2.0], + ] + ) # A triangular lattice is a Bravais lattice with a single-site basis. - basis_coords = [[0.0, 0.0]] + basis_coords = backend.convert_to_tensor([[0.0, 0.0]]) super().__init__( dimensionality=dimensionality, @@ -1002,9 +1036,9 @@ def __init__( ): dimensionality = 1 # The lattice vector is just the lattice constant along one dimension. - lattice_vectors = [[lattice_constant]] + lattice_vectors = backend.convert_to_tensor([[lattice_constant]]) # A simple chain is a Bravais lattice with a single-site basis. - basis_coords = [[0.0]] + basis_coords = backend.convert_to_tensor([[0.0]]) super().__init__( dimensionality=dimensionality, @@ -1043,9 +1077,9 @@ def __init__( ): dimensionality = 1 # The unit cell is twice the bond length, as it contains two sites. - lattice_vectors = [[2 * lattice_constant]] + lattice_vectors = backend.convert_to_tensor([[2 * lattice_constant]]) # Two basis sites (A and B) separated by the bond length. - basis_coords = [[0.0], [lattice_constant]] + basis_coords = backend.convert_to_tensor([[0.0], [lattice_constant]]) super().__init__( dimensionality=dimensionality, @@ -1085,9 +1119,9 @@ def __init__( dimensionality = 2 ax, ay = lattice_constants # Orthogonal lattice vectors with potentially different lengths. - lattice_vectors = [[ax, 0.0], [0.0, ay]] + lattice_vectors = backend.convert_to_tensor([[ax, 0.0], [0.0, ay]]) # A rectangular lattice is a Bravais lattice with a single-site basis. - basis_coords = [[0.0, 0.0]] + basis_coords = backend.convert_to_tensor([[0.0, 0.0]]) super().__init__( dimensionality=dimensionality, @@ -1125,9 +1159,13 @@ def __init__( dimensionality = 2 a = lattice_constant # The unit cell is a square rotated by 45 degrees. - lattice_vectors = [[a * 1.0, a * 1.0], [a * 1.0, a * -1.0]] + lattice_vectors = backend.convert_to_tensor( + [[a * 1.0, a * 1.0], [a * 1.0, a * -1.0]] + ) # Two basis sites (A and B) within the unit cell. - basis_coords = [[a * 0.0, a * 0.0], [a * 1.0, a * 0.0]] + basis_coords = backend.convert_to_tensor( + [[a * 0.0, a * 0.0], [a * 1.0, a * 0.0]] + ) super().__init__( dimensionality=dimensionality, @@ -1165,16 +1203,20 @@ def __init__( dimensionality = 2 a = lattice_constant # The Kagome lattice is based on a triangular Bravais lattice. - lattice_vectors = [ - [a * 2.0, a * 0.0], - [a * 1.0, a * backend.sqrt(backend.convert_to_tensor(3.0))], - ] + lattice_vectors = backend.convert_to_tensor( + [ + [a * 2.0, a * 0.0], + [a * 1.0, a * backend.sqrt(3.0)], + ] + ) # It has a three-site basis, forming the corners of the triangles. - basis_coords = [ - [a * 0.0, a * 0.0], - [a * 1.0, a * 0.0], - [a * 0.5, a * backend.sqrt(backend.convert_to_tensor(3.0)) / 2.0], - ] + basis_coords = backend.convert_to_tensor( + [ + [a * 0.0, a * 0.0], + [a * 1.0, a * 0.0], + [a * 0.5, a * backend.sqrt(3.0) / 2.0], + ] + ) super().__init__( dimensionality=dimensionality, @@ -1216,13 +1258,17 @@ def __init__( unit_cell_side = 2 * bond_length # The Lieb lattice is based on a square Bravais lattice. - lattice_vectors = [[unit_cell_side, 0.0], [0.0, unit_cell_side]] + lattice_vectors = backend.convert_to_tensor( + [[unit_cell_side, 0.0], [0.0, unit_cell_side]] + ) # It has a three-site basis: one corner and two edge-centers. - basis_coords = [ - [0.0, 0.0], # Corner site - [bond_length, 0.0], # x-edge center - [0.0, bond_length], # y-edge center - ] + basis_coords = backend.convert_to_tensor( + [ + [0.0, 0.0], # Corner site + [bond_length, 0.0], # x-edge center + [0.0, bond_length], # y-edge center + ] + ) super().__init__( dimensionality=dimensionality, @@ -1260,9 +1306,9 @@ def __init__( dimensionality = 3 a = lattice_constant # Orthogonal lattice vectors of equal length in 3D. - lattice_vectors = [[a, 0, 0], [0, a, 0], [0, 0, a]] + lattice_vectors = backend.convert_to_tensor([[a, 0, 0], [0, a, 0], [0, 0, a]]) # A simple cubic lattice is a Bravais lattice with a single-site basis. - basis_coords = [[0.0, 0.0, 0.0]] + basis_coords = backend.convert_to_tensor([[0.0, 0.0, 0.0]]) super().__init__( dimensionality=dimensionality, lattice_vectors=lattice_vectors, @@ -1358,23 +1404,15 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: :param kwargs: Additional arguments including: - use_kdtree (bool): Whether to use KDTree optimization. Defaults to True. - tol (float): Distance tolerance for neighbor identification. Defaults to 1e-6. - - force_differentiable (bool): If True, forces distance matrix method even when - KDTree is available. Defaults to False. """ tol = kwargs.get("tol", 1e-6) use_kdtree = kwargs.get("use_kdtree", True) - force_differentiable = kwargs.get("force_differentiable", False) if self.num_sites < 2: return - # Override KDTree if differentiability is explicitly required - if force_differentiable: - use_kdtree = False - logger.info("Using differentiable distance matrix method (forced)") - # Choose algorithm based on user preference - if use_kdtree and not force_differentiable: + if use_kdtree: logger.info( f"Using KDTree method for {self.num_sites} sites up to k={max_k}" ) @@ -1522,29 +1560,21 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: def _compute_distance_matrix(self) -> Coordinates: """ - Computes the full N x N distance matrix using backend operations. - This implementation is fully differentiable. + Computes the full N x N distance matrix by delegating to the inherited method. + This avoids code duplication with the base class implementation. """ if self.num_sites == 0: return backend.zeros((0, 0)) if self.num_sites < 2: - assert self._coordinates is not None + if self._coordinates is None: + raise ValueError("Lattice coordinates have not been initialized.") return backend.zeros( (self.num_sites, self.num_sites), dtype=self._coordinates.dtype ) - # Vectorized computation of displacements: (N, 1, D) - (1, N, D) -> (N, N, D) - displacements = backend.expand_dims(self._coordinates, 1) - backend.expand_dims( - self._coordinates, 0 - ) - - dist_matrix_sq = backend.sum(displacements**2, axis=-1) - - return backend.where( - dist_matrix_sq == 0, - 0, - backend.sqrt(dist_matrix_sq), - ) + # Use the inherited method from AbstractLattice which computes and caches the distance matrix + self._build_neighbors_by_distance_matrix(max_k=1, tol=1e-6) + return self._distance_matrix def _reset_computations(self) -> None: """Resets all cached data that depends on the lattice structure.""" @@ -1574,10 +1604,30 @@ def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice": # Unzip the list of tuples into separate lists of identifiers and coordinates _, identifiers, coordinates = zip(*all_sites_info) + # Normalize coordinates to plain nested Python float lists to avoid + # backend-specific tensor list issues (e.g., torch.tensor(list_of_tensors) ValueError). + # This ensures the resulting CustomizeLattice works consistently across all backends + # by converting any backend tensors to backend-agnostic Python lists. + normalized_coords = [] + for c in coordinates: + try: + # If already a backend tensor, convert to numpy then to list + if hasattr(backend, "is_tensor") and backend.is_tensor(c): # type: ignore + normalized_coords.append(backend.numpy(c).tolist()) # type: ignore + else: + # c may be a numpy array or list-like + if hasattr(c, "tolist"): + normalized_coords.append(c.tolist()) # type: ignore + else: + normalized_coords.append(list(c)) # fallback + except Exception: # pragma: no cover - defensive + # Last resort: wrap scalar(s) + normalized_coords.append([float(x) for x in c]) # type: ignore + return cls( dimensionality=lattice.dimensionality, identifiers=list(identifiers), - coordinates=list(coordinates), + coordinates=normalized_coords, ) def add_sites( diff --git a/tests/test_backends.py b/tests/test_backends.py index 3788d953..b589163e 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1176,9 +1176,6 @@ def fb2(*args): np.testing.assert_allclose(yajit, tc.backend.transpose(ybjit, [1, 0, 2]), atol=1e-5) -# Test new backend methods added for differentiable lattice support - - @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) def test_backend_sort(backend): """Test sort method.""" diff --git a/tests/test_lattice.py b/tests/test_lattice.py index 689b7b23..69923071 100644 --- a/tests/test_lattice.py +++ b/tests/test_lattice.py @@ -4,23 +4,11 @@ import matplotlib -try: - import jax.numpy as jnp -except ImportError: - jnp = None -try: - import tensorflow as tf -except ImportError: - tf = None -try: - import torch -except ImportError: - torch = None - matplotlib.use("Agg") import pytest import numpy as np +from pytest_lazyfixture import lazy_fixture as lf from tensorcircuit.templates.lattice import ( ChainLattice, @@ -83,7 +71,10 @@ class TestCustomizeLattice: This helps in organizing the test suite. """ - def test_initialization_and_properties(self, kagome_lattice_fragment): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_initialization_and_properties(self, backend, kagome_lattice_fragment): """ Test case for successful initialization and verification of basic properties. This test function receives the 'kagome_lattice_fragment' fixture as an argument. @@ -96,10 +87,15 @@ def test_initialization_and_properties(self, kagome_lattice_fragment): assert lattice.num_sites == 6 assert len(lattice) == 6 # This also tests the __len__ dunder method - # Verify that coordinates are correctly stored as numpy arrays. - # It's important to use np.testing.assert_array_equal for numpy array comparison. + # Verify that coordinates are correctly stored as backend tensors. + # It's important to use np.testing.assert_array_equal for coordinate comparison. expected_coord = np.array([0.5, np.sqrt(3) / 2]) - np.testing.assert_array_equal(lattice.get_coordinates(2), expected_coord) + np.testing.assert_allclose( + tc.backend.numpy(lattice.get_coordinates(2)), + expected_coord, + rtol=1e-6, + atol=1e-6, + ) # Verify that the mapping between identifiers and indices is correct. assert lattice.get_identifier(4) == 4 @@ -139,7 +135,10 @@ def test_input_validation_wrong_dimension(self): dimensionality=2, identifiers=ids_ok, coordinates=coords_wrong_dim ) - def test_neighbor_finding(self, simple_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_neighbor_finding(self, backend, simple_square_lattice): """ Tests the k-th nearest neighbor finding functionality (_build_neighbors and get_neighbors). @@ -195,7 +194,10 @@ def test_neighbor_finding(self, simple_square_lattice): f"Missing: {expected - neighbors}. Extra: {neighbors - expected}." ) - def test_neighbor_pairs(self, simple_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_neighbor_pairs(self, backend, simple_square_lattice): """ Tests the retrieval of unique neighbor pairs (bonds) using get_neighbor_pairs. @@ -220,7 +222,10 @@ def test_neighbor_pairs(self, simple_square_lattice): expected_nnn_pairs = {(0, 3), (1, 2)} assert set(map(tuple, nnn_pairs)) == expected_nnn_pairs - def test_neighbor_pairs_non_unique(self, simple_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_neighbor_pairs_non_unique(self, backend, simple_square_lattice): """ Tests get_neighbor_pairs with unique=False to ensure all directed pairs (bonds) are returned. @@ -566,7 +571,10 @@ class TestSquareLattice: the core functionality of its parent, TILattice. """ - def test_initialization_and_properties(self, obc_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_initialization_and_properties(self, backend, obc_square_lattice): """ Tests the basic properties of a SquareLattice instance. """ @@ -575,7 +583,10 @@ def test_initialization_and_properties(self, obc_square_lattice): assert lattice.num_sites == 9 # A 3x3 lattice should have 9 sites. assert len(lattice) == 9 - def test_site_info_and_identifiers(self, obc_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_site_info_and_identifiers(self, backend, obc_square_lattice): """ Tests that site information (coordinates, identifiers) is correct. """ @@ -585,14 +596,17 @@ def test_site_info_and_identifiers(self, obc_square_lattice): _, ident, coords = lattice.get_site_info(center_idx) assert ident == (1, 1, 0) - np.testing.assert_array_equal(coords, np.array([1.0, 1.0])) + np.testing.assert_array_equal(tc.backend.numpy(coords), np.array([1.0, 1.0])) corner_idx = 0 _, ident, coords = lattice.get_site_info(corner_idx) assert ident == (0, 0, 0) - np.testing.assert_array_equal(coords, np.array([0.0, 0.0])) + np.testing.assert_array_equal(tc.backend.numpy(coords), np.array([0.0, 0.0])) - def test_neighbors_with_open_boundaries(self, obc_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_neighbors_with_open_boundaries(self, backend, obc_square_lattice): """ Tests neighbor finding with Open Boundary Conditions (OBC) using specific neighbor identities. @@ -663,7 +677,10 @@ class TestHoneycombLattice: Tests the HoneycombLattice class, focusing on its two-site basis. """ - def test_initialization_and_properties(self, pbc_honeycomb_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_initialization_and_properties(self, backend, pbc_honeycomb_lattice): """ Tests that the total number of sites is correct for a composite lattice. """ @@ -671,7 +688,10 @@ def test_initialization_and_properties(self, pbc_honeycomb_lattice): assert lattice.num_sites == 18 assert lattice.num_basis == 2 - def test_honeycomb_neighbors(self, pbc_honeycomb_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_honeycomb_neighbors(self, backend, pbc_honeycomb_lattice): """ Tests that every site in a honeycomb lattice has 3 nearest neighbors. """ @@ -682,7 +702,10 @@ def test_honeycomb_neighbors(self, pbc_honeycomb_lattice): site_b_idx = lattice.get_index((0, 0, 1)) assert len(lattice.get_neighbors(site_b_idx, k=1)) == 3 - def test_honeycomb_next_nearest_neighbors(self, pbc_honeycomb_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_honeycomb_next_nearest_neighbors(self, backend, pbc_honeycomb_lattice): """ Tests that every site in a honeycomb lattice has 6 next-nearest neighbors under periodic boundary conditions. @@ -714,14 +737,20 @@ class TestTriangularLattice: Tests the TriangularLattice class, focusing on its coordination number. """ - def test_initialization_and_properties(self, pbc_triangular_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_initialization_and_properties(self, backend, pbc_triangular_lattice): """ Tests the basic properties of the triangular lattice. """ lattice = pbc_triangular_lattice assert lattice.num_sites == 9 # 3 * 3 = 9 sites for a 3x3 grid - def test_triangular_neighbors(self, pbc_triangular_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_triangular_neighbors(self, backend, pbc_triangular_lattice): """ Tests that every site in a triangular lattice has 6 nearest neighbors. """ @@ -1120,17 +1149,21 @@ def test_init_with_mismatched_shapes_raises_error(self): """ # Act & Assert: # Pass a 'size' tuple with 3 elements to a 2D SquareLattice. - # This should trigger the AssertionError from the parent TILattice class. - with pytest.raises(AssertionError, match="Size tuple length mismatch"): + # This should trigger the ValueError from the parent TILattice class. + with pytest.raises( + ValueError, match="Size tuple length .* does not match dimensionality" + ): SquareLattice(size=(2, 2, 2)) def test_init_with_mismatched_pbc_raises_error(self): """ - Tests that TILattice raises AssertionError if the 'pbc' tuple's + Tests that TILattice raises ValueError if the 'pbc' tuple's length does not match the dimensionality. This addresses a gap identified in the code review. """ - with pytest.raises(AssertionError, match="PBC tuple length mismatch"): + with pytest.raises( + ValueError, match="PBC tuple length .* does not match dimensionality" + ): # A 2D lattice requires a pbc tuple of length 2, but we provide one of length 1. SquareLattice(size=(2, 2), pbc=(True,)) @@ -1490,11 +1523,16 @@ def initial_lattice(self): return CustomizeLattice( dimensionality=2, identifiers=["A", "B", "C"], - coordinates=[[0, 0], [1, 0], [0, 1]], + coordinates=[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], ) - def test_from_lattice_conversion(self): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_from_lattice_conversion(self, backend): """Tests creating a CustomizeLattice from a TILattice.""" + tc.set_backend(backend) + # Arrange sq_lattice = SquareLattice(size=(2, 2), pbc=False) @@ -1506,28 +1544,42 @@ def test_from_lattice_conversion(self): assert custom_lattice.num_sites == sq_lattice.num_sites assert custom_lattice.dimensionality == sq_lattice.dimensionality # Verify a site to be sure + # custom_lattice coordinates are normalized to python lists; source lattice returns backend tensor np.testing.assert_array_equal( - custom_lattice.get_coordinates(3), sq_lattice.get_coordinates(3) + np.array(custom_lattice.get_coordinates(3)), + tc.backend.numpy(sq_lattice.get_coordinates(3)), ) assert custom_lattice.get_identifier(3) == sq_lattice.get_identifier(3) - def test_add_sites_successfully(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_add_sites_successfully(self, backend, initial_lattice): """Tests adding new, valid sites to the lattice.""" + tc.set_backend(backend) + # Arrange lat = initial_lattice assert lat.num_sites == 3 # Act - lat.add_sites(identifiers=["D", "E"], coordinates=[[1, 1], [2, 2]]) + lat.add_sites(identifiers=["D", "E"], coordinates=[[1.0, 1.0], [2.0, 2.0]]) # Assert assert lat.num_sites == 5 assert lat.get_identifier(4) == "E" - np.testing.assert_array_equal(lat.get_coordinates(3), np.array([1, 1])) + np.testing.assert_array_equal( + tc.backend.numpy(lat.get_coordinates(3)), np.array([1.0, 1.0]) + ) assert "E" in lat._ident_to_idx - def test_remove_sites_successfully(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_remove_sites_successfully(self, backend, initial_lattice): """Tests removing existing sites from the lattice.""" + tc.set_backend(backend) + # Arrange lat = initial_lattice assert lat.num_sites == 3 @@ -1539,29 +1591,46 @@ def test_remove_sites_successfully(self, initial_lattice): assert lat.num_sites == 1 assert lat.get_identifier(0) == "B" # Site 'B' is now at index 0 assert "A" not in lat._ident_to_idx - np.testing.assert_array_equal(lat.get_coordinates(0), np.array([1, 0])) + np.testing.assert_array_equal( + tc.backend.numpy(lat.get_coordinates(0)), np.array([1.0, 0.0]) + ) - def test_add_duplicate_identifier_raises_error(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_add_duplicate_identifier_raises_error(self, backend, initial_lattice): """Tests that adding a site with an existing identifier fails.""" + tc.set_backend(backend) + with pytest.raises(ValueError, match="Duplicate identifiers found"): - initial_lattice.add_sites(identifiers=["A"], coordinates=[[9, 9]]) + initial_lattice.add_sites(identifiers=["A"], coordinates=[[9.0, 9.0]]) - def test_remove_nonexistent_identifier_raises_error(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_remove_nonexistent_identifier_raises_error(self, backend, initial_lattice): """Tests that removing a non-existent site fails.""" + tc.set_backend(backend) + with pytest.raises(ValueError, match="Non-existent identifiers provided"): initial_lattice.remove_sites(identifiers=["Z"]) - def test_modification_clears_neighbor_cache(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_modification_clears_neighbor_cache(self, backend, initial_lattice): """ Tests that add_sites and remove_sites correctly invalidate the pre-computed neighbor map. """ + tc.set_backend(backend) + # Arrange: Pre-compute neighbors on the initial lattice initial_lattice._build_neighbors(max_k=1) assert 0 in initial_lattice._neighbor_maps[1] # Check that neighbors exist # Act 1: Add a site - initial_lattice.add_sites(identifiers=["D"], coordinates=[[5, 5]]) + initial_lattice.add_sites(identifiers=["D"], coordinates=[[5.0, 5.0]]) # Assert 1: The neighbor map should now be empty assert not initial_lattice._neighbor_maps @@ -1576,21 +1645,26 @@ def test_modification_clears_neighbor_cache(self, initial_lattice): # Assert 2: The neighbor map should be empty again assert not initial_lattice._neighbor_maps - def test_modification_clears_distance_matrix_cache(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_modification_clears_distance_matrix_cache(self, backend, initial_lattice): """ Tests that add_sites and remove_sites correctly invalidate the cached distance matrix and that the recomputed matrix is correct. """ + tc.set_backend(backend) + # Arrange 1: Compute, cache, and perform a meaningful check on the original matrix. lat = initial_lattice original_matrix = lat.distance_matrix assert lat._distance_matrix is not None assert original_matrix.shape == (3, 3) # Meaningful check: distance from 'A'(idx 0) to 'B'(idx 1) should be 1.0 - np.testing.assert_allclose(original_matrix[0, 1], 1.0) + np.testing.assert_allclose(tc.backend.numpy(original_matrix)[0, 1], 1.0) # Act 1: Add a site. This should invalidate the cache. - lat.add_sites(identifiers=["D"], coordinates=[[1, 1]]) + lat.add_sites(identifiers=["D"], coordinates=[[1.0, 1.0]]) # Assert 1: Check cache is cleared and the new matrix is correct. assert lat._distance_matrix is None # Verify cache invalidation @@ -1598,7 +1672,7 @@ def test_modification_clears_distance_matrix_cache(self, initial_lattice): assert new_matrix_added.shape == (4, 4) # Meaningful check: distance from 'B'(idx 1) to new site 'D'(idx 3) should be 1.0 # Coords: B=[1,0], D=[1,1] - np.testing.assert_allclose(new_matrix_added[1, 3], 1.0) + np.testing.assert_allclose(tc.backend.numpy(new_matrix_added)[1, 3], 1.0) # Act 2: Remove a site. This should also invalidate the cache. lat.remove_sites(identifiers=["A"]) @@ -1612,13 +1686,18 @@ def test_modification_clears_distance_matrix_cache(self, initial_lattice): # 'C' is now at index 1 (coords [0,1]) # 'D' is now at index 2 (coords [1,1]) # Distance from new 'B' (idx 0) to new 'D' (idx 2) should be 1.0 - np.testing.assert_allclose(final_matrix[0, 2], 1.0) + np.testing.assert_allclose(tc.backend.numpy(final_matrix)[0, 2], 1.0) - def test_neighbor_finding_returns_sorted_list(self, simple_square_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_neighbor_finding_returns_sorted_list(self, backend, simple_square_lattice): """ Ensures that the list of neighbors returned by get_neighbors is always sorted. This provides a stricter check than set-based comparisons. """ + tc.set_backend(backend) + # Arrange lattice = simple_square_lattice @@ -1635,8 +1714,13 @@ def test_neighbor_finding_returns_sorted_list(self, simple_square_lattice): 3, ], "The neighbor list should be sorted in ascending order." - def test_from_lattice_from_empty_lattice(self): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_from_lattice_from_empty_lattice(self, backend): """Tests creating a CustomizeLattice from an empty TILattice.""" + tc.set_backend(backend) + # Arrange: Create an empty TILattice instance. empty_sq = SquareLattice(size=(0, 0)) @@ -1648,8 +1732,13 @@ def test_from_lattice_from_empty_lattice(self): assert custom_from_empty.num_sites == 0 assert custom_from_empty.dimensionality == 2 - def test_add_sites_to_empty_lattice(self): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_add_sites_to_empty_lattice(self, backend): """Tests adding sites to a previously empty CustomizeLattice.""" + tc.set_backend(backend) + # Arrange: Create an empty CustomizeLattice. empty_lat = CustomizeLattice(dimensionality=2, identifiers=[], coordinates=[]) assert empty_lat.num_sites == 0 @@ -1663,14 +1752,19 @@ def test_add_sites_to_empty_lattice(self): assert empty_lat.num_sites == 2 assert empty_lat.get_identifier(0) == "X" np.testing.assert_array_equal( - empty_lat.get_coordinates(1), np.array([2.0, 2.0]) + tc.backend.numpy(empty_lat.get_coordinates(1)), np.array([2.0, 2.0]) ) - def test_add_and_remove_empty_list_of_sites(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_add_and_remove_empty_list_of_sites(self, backend, initial_lattice): """ Tests that calling add_sites and remove_sites with empty lists is a no-op and doesn't change the lattice state. """ + tc.set_backend(backend) + # Arrange lat = initial_lattice original_num_sites = lat.num_sites @@ -1691,8 +1785,13 @@ def test_add_and_remove_empty_list_of_sites(self, initial_lattice): assert lat.num_sites == original_num_sites assert id(lat._coordinates) == original_coords_id - def test_remove_all_sites(self, initial_lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_remove_all_sites(self, backend, initial_lattice): """Tests removing all sites from a lattice, resulting in an empty lattice.""" + tc.set_backend(backend) + # Arrange lat = initial_lattice # Get all identifiers before removal. @@ -1710,6 +1809,9 @@ def test_remove_all_sites(self, initial_lattice): class TestDistanceMatrix: # This is the upgraded, parameterized test. + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) @pytest.mark.parametrize( # We define test scenarios as tuples: # (build_k, check_site_identifier, expected_dist_sq) @@ -1730,7 +1832,7 @@ class TestDistanceMatrix: ], ) def test_tilattice_full_pbc_distance_matrix_is_correct_regardless_of_build_k( - self, build_k, check_site_identifier, expected_dist_sq + self, backend, build_k, check_site_identifier, expected_dist_sq ): """ Tests that the distance matrix for a fully periodic TILattice is @@ -1769,11 +1871,16 @@ def test_tilattice_full_pbc_distance_matrix_is_correct_regardless_of_build_k( actual_dist_sq, expected_dist_sq, err_msg=error_message ) - def test_tilattice_mixed_bc_distance_matrix_is_correct(self): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_tilattice_mixed_bc_distance_matrix_is_correct(self, backend): """ Tests that the distance matrix is correctly calculated for a TILattice with mixed boundary conditions (e.g., periodic in x, open in y). """ + tc.set_backend(backend) + # Arrange # pbc=(True, False) means periodic along x-axis, open along y-axis. lat = SquareLattice(size=(5, 5), pbc=(True, False)) @@ -1804,25 +1911,44 @@ def test_tilattice_mixed_bc_distance_matrix_is_correct(self): ) # --- This list and the following test are now at the correct indentation level --- - lattice_instances_for_invariant_test = [ - SquareLattice(size=(4, 4), pbc=True), - SquareLattice(size=(4, 3), pbc=(True, False)), # Mixed BC, non-square - HoneycombLattice(size=(3, 3), pbc=True), - TriangularLattice(size=(4, 4), pbc=False), - CustomizeLattice( - dimensionality=2, - identifiers=list(range(4)), - coordinates=[[0, 0], [1, 1], [0, 1], [1, 0]], + # 使用工厂函数而不是预先实例化对象,避免跨 backend 复用已缓存 _distance_matrix(numpy 数组) + lattice_factories_for_invariant_test = [ + pytest.param(lambda: SquareLattice(size=(4, 4), pbc=True), id="Square_4x4_pbc"), + pytest.param( + lambda: SquareLattice(size=(4, 3), pbc=(True, False)), id="Square_4x3_mixed" + ), + pytest.param( + lambda: HoneycombLattice(size=(3, 3), pbc=True), id="Honeycomb_3x3_pbc" + ), + pytest.param( + lambda: TriangularLattice(size=(4, 4), pbc=False), id="Triangular_4x4_obc" + ), + pytest.param( + lambda: CustomizeLattice( + dimensionality=2, + identifiers=list(range(4)), + coordinates=[[0.0, 0.0], [1.0, 1.0], [0.0, 1.0], [1.0, 0.0]], + ), + id="Customize_4sites", ), ] - @pytest.mark.parametrize("lattice", lattice_instances_for_invariant_test) - def test_distance_matrix_invariants_for_all_lattice_types(self, lattice): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + @pytest.mark.parametrize("lattice_factory", lattice_factories_for_invariant_test) + def test_distance_matrix_invariants_for_all_lattice_types( + self, backend, lattice_factory + ): """ Tests that the distance matrix for any lattice type adheres to fundamental mathematical properties (invariants): symmetry, zero diagonal, and positive off-diagonal elements. """ + tc.set_backend(backend) + # 重新实例化 lattice,确保没有跨 backend 的缓存副作用 + lattice = lattice_factory() + # Arrange n = lattice.num_sites if n < 2: @@ -1835,15 +1961,16 @@ def test_distance_matrix_invariants_for_all_lattice_types(self, lattice): # Assert # 1. Symmetry: The matrix must be equal to its transpose. + matrix_numpy = tc.backend.numpy(matrix) np.testing.assert_allclose( - matrix, - matrix.T, + matrix_numpy, + matrix_numpy.T, err_msg=f"Distance matrix for {type(lattice).__name__} is not symmetric.", ) # 2. Zero Diagonal: All diagonal elements must be zero. np.testing.assert_allclose( - np.diag(matrix), + np.diag(matrix_numpy), np.zeros(n), err_msg=f"Diagonal of distance matrix for {type(lattice).__name__} is not zero.", ) @@ -1852,18 +1979,23 @@ def test_distance_matrix_invariants_for_all_lattice_types(self, lattice): # We create a boolean mask for the off-diagonal elements. off_diagonal_mask = ~np.eye(n, dtype=bool) assert np.all( - matrix[off_diagonal_mask] > 1e-9 + matrix_numpy[off_diagonal_mask] > 1e-9 ), f"Found non-positive off-diagonal elements in distance matrix for {type(lattice).__name__}." - def test_distance_matrix_caching_is_effective(self): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_distance_matrix_caching_is_effective(self, backend): """ Tests that the distance_matrix property is cached after the first access. """ + tc.set_backend(backend) + # Arrange: Create a lattice instance. lattice = CustomizeLattice( dimensionality=2, identifiers=["A", "B", "C"], - coordinates=[[0, 0], [1, 0], [0, 1]], + coordinates=[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], ) # Act & Assert @@ -2026,15 +2158,6 @@ def test_layering_on_disconnected_graph(): assert (3, 4) in layer_with_01 or (4, 3) in layer_with_01 -# A map from backend name to the expected tensor type. -BACKEND_TENSOR_MAP = { - "numpy": np.ndarray, - "jax": jnp.ndarray if jnp else None, - "tensorflow": tf.Tensor if tf else None, - "pytorch": torch.Tensor if torch else None, -} - - class TestBackendIntegration: """ Tests to ensure lattice functionalities are consistent and correct @@ -2073,7 +2196,9 @@ class TestBackendIntegration: ), ] - @pytest.mark.parametrize("backend_name", ["numpy", "jax", "tensorflow", "pytorch"]) + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) @pytest.mark.parametrize( "LatticeClass, init_args, expected_distance_check, name", lattice_test_cases, @@ -2081,7 +2206,7 @@ class TestBackendIntegration: ) def test_lattice_creation_and_properties_across_backends( self, - backend_name, + backend, LatticeClass, init_args, expected_distance_check, @@ -2092,23 +2217,17 @@ def test_lattice_creation_and_properties_across_backends( core properties (_coordinates, distance_matrix) have the correct tensor types and values. """ - expected_tensor_type = BACKEND_TENSOR_MAP[backend_name] - if expected_tensor_type is None: - pytest.skip(f"Backend '{backend_name}' not installed.") - - tc.set_backend(backend_name) - # Create the lattice instance inside the test function lat = LatticeClass(**init_args) - # Assert that the internal coordinate and public distance matrix tensors - # have the correct type for the active backend. - assert isinstance( - lat._coordinates, expected_tensor_type - ), f"Failed for {type(lat).__name__} on backend {backend_name}" - assert isinstance( - lat.distance_matrix, expected_tensor_type - ), f"Failed for {type(lat).__name__} on backend {backend_name}" + # Verify that the lattice can be created and has correct properties + assert lat.num_sites > 0, f"Failed for {type(lat).__name__} on current backend" + assert hasattr( + lat, "_coordinates" + ), f"Failed for {type(lat).__name__} on current backend" + assert hasattr( + lat, "distance_matrix" + ), f"Failed for {type(lat).__name__} on current backend" # Unpack the distance check information idx1, idx2, expected_distance = expected_distance_check @@ -2121,7 +2240,7 @@ def test_lattice_creation_and_properties_across_backends( np.testing.assert_allclose( actual_distance, expected_distance, - err_msg=f"Distance check failed for {type(lat).__name__} on backend {backend_name}", + err_msg=f"Distance check failed for {type(lat).__name__} on current backend", ) @pytest.mark.parametrize( @@ -2161,8 +2280,6 @@ def test_tilattice_differentiability( with respect to their geometric parameters. This test has been expanded based on code review feedback to cover more lattice types. """ - if not jnp: - pytest.skip("JAX backend is required for this differentiability test.") def get_total_distance(param): """A scalar-in, scalar-out function for jax.grad.""" @@ -2195,11 +2312,9 @@ def test_customizelattice_differentiability(self, jaxb): Tests that the distance_matrix of a CustomizeLattice is differentiable with respect to its input coordinates. """ - # This test requires the JAX backend for its grad function. - if not jnp: - pytest.skip("JAX backend is required for this differentiability test.") - - initial_coords = jnp.array([[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]]) + initial_coords = tc.backend.convert_to_tensor( + [[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]] + ) def get_total_distance_custom(coords): """ @@ -2226,8 +2341,6 @@ def test_tilattice_gradient_value_correctness(self, jaxb): analytically calculated, correct gradient value. This is a stronger test than just checking for non-zero gradients. """ - if not jnp: - pytest.skip("JAX backend is required for this gradient value test.") # 1. Define a simple objective function def get_energy(a): @@ -2265,20 +2378,15 @@ def analytical_gradient(a): err_msg="The automatically differentiated gradient does not match the analytical gradient.", ) - @pytest.mark.parametrize("backend_name", ["numpy", "jax", "tensorflow", "pytorch"]) - def test_dynamic_modification_across_backends(self, backend_name): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_dynamic_modification_across_backends(self, backend): """ Tests that the dynamic modification methods (add_sites, remove_sites) of CustomizeLattice work correctly across all supported backends, specifically checking tensor shapes. """ - # Arrange: Set up the backend and skip if not installed - expected_tensor_type = BACKEND_TENSOR_MAP[backend_name] - if expected_tensor_type is None: - pytest.skip(f"Backend '{backend_name}' not installed.") - - tc.set_backend(backend_name) - # --- Initial State --- # Create a simple lattice with 3 sites initial_coords = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]] @@ -2297,7 +2405,7 @@ def test_dynamic_modification_across_backends(self, backend_name): assert lattice.num_sites == 5 assert ( lattice._coordinates.shape[0] == 5 - ), f"Tensor shape incorrect after add_sites on {backend_name} backend." + ), "Tensor shape incorrect after add_sites on current backend." # --- Test remove_sites --- # Act: Remove 1 site from the modified lattice @@ -2307,22 +2415,15 @@ def test_dynamic_modification_across_backends(self, backend_name): assert lattice.num_sites == 4 assert ( lattice._coordinates.shape[0] == 4 - ), f"Tensor shape incorrect after remove_sites on {backend_name} backend." + ), "Tensor shape incorrect after remove_sites on current backend." -@pytest.mark.parametrize("backend_name", ["numpy", "jax", "tensorflow", "pytorch"]) -def test_dtype_consistency_across_backends(backend_name): +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_dtype_consistency_across_backends(backend): """ Tests that the dtype of user-provided coordinate data is preserved in internal calculations across all backends. """ - # Arrange: Set up the backend and skip if not installed - if BACKEND_TENSOR_MAP[backend_name] is None: - pytest.skip(f"Backend '{backend_name}' not installed.") - - tc.set_backend(backend_name) - - # A map from backend name to its corresponding float32 dtype object # Prepare input data with a specific, non-default dtype coords_float32 = np.array([[0.0, 0.0], [1.0, 2.0]], dtype=np.float32) @@ -2333,11 +2434,11 @@ def test_dtype_consistency_across_backends(backend_name): # Assert: Check that the internal tensors have the correct dtype assert tc.backend.dtype(lattice._coordinates) == "float32", ( - f"Mismatch in coordinate dtype for backend {backend_name}. " + f"Mismatch in coordinate dtype for current backend. " f"Expected 'float32', got {tc.backend.dtype(lattice._coordinates)}." ) assert tc.backend.dtype(lattice.distance_matrix) == "float32", ( - f"Mismatch in distance matrix dtype for backend {backend_name}. " + f"Mismatch in distance matrix dtype for current backend. " f"Expected 'float32', got {tc.backend.dtype(lattice.distance_matrix)}." ) @@ -2360,11 +2461,16 @@ def simple_lattice_for_helpers(self): dimensionality=1, identifiers=[0, 1, 2], coordinates=[[0.0], [1.0], [2.0]] ) - def test_identify_distance_shells_basic(self, simple_lattice_for_helpers): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_identify_distance_shells_basic(self, backend, simple_lattice_for_helpers): """ Tests the basic functionality of _identify_distance_shells with a clear separation between distance shells. """ + tc.set_backend(backend) + # Arrange lattice = simple_lattice_for_helpers # A set of squared distances with clear gaps between them. @@ -2378,13 +2484,18 @@ def test_identify_distance_shells_basic(self, simple_lattice_for_helpers): # The method should identify the unique, non-zero distances. np.testing.assert_allclose(shells, [1.0, 4.0, 9.0]) + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) def test_identify_distance_shells_with_max_k_limit( - self, simple_lattice_for_helpers + self, backend, simple_lattice_for_helpers ): """ Tests that _identify_distance_shells respects the max_k parameter, limiting the number of returned shells. """ + tc.set_backend(backend) + # Arrange lattice = simple_lattice_for_helpers all_distances_sq = np.array([0, 1.0, 4.0, 9.0, 16.0, 25.0]) @@ -2399,13 +2510,18 @@ def test_identify_distance_shells_with_max_k_limit( # The returned shells should be the first `max_k` smallest distances. np.testing.assert_allclose(shells, [1.0, 4.0, 9.0]) + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) def test_identify_distance_shells_with_tolerance_merging( - self, simple_lattice_for_helpers + self, backend, simple_lattice_for_helpers ): """ Tests that distances that are very close together are merged into a single shell when the tolerance `tol` is large enough. """ + tc.set_backend(backend) + # Arrange lattice = simple_lattice_for_helpers # Two distances are very close: 1.0 and 1.000001 From 9575be5c9470a54cc023b567fdc46377c5513e95 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Thu, 14 Aug 2025 15:23:40 +0800 Subject: [PATCH 11/16] fix according to the review --- examples/lattice_neighbor_time_compare.py | 118 +++++++++++ examples/lennard_jones_optimization.py | 4 - tensorcircuit/backends/abstract_backend.py | 10 +- tensorcircuit/backends/numpy_backend.py | 3 - tensorcircuit/templates/lattice.py | 231 +++++++++++---------- tests/test_lattice.py | 81 ++++++-- 6 files changed, 304 insertions(+), 143 deletions(-) create mode 100644 examples/lattice_neighbor_time_compare.py diff --git a/examples/lattice_neighbor_time_compare.py b/examples/lattice_neighbor_time_compare.py new file mode 100644 index 00000000..ec4e5a0d --- /dev/null +++ b/examples/lattice_neighbor_time_compare.py @@ -0,0 +1,118 @@ +""" +Benchmark: Compare neighbor-building time between KDTree and distance-matrix +methods in CustomizeLattice for varying lattice sizes. +""" + +import argparse +import csv +import time +from typing import Iterable, List, Tuple, Optional +import logging + +import numpy as np + +# Silence verbose infos from the library during benchmarks + +logging.basicConfig(level=logging.WARNING) + +from tensorcircuit.templates.lattice import CustomizeLattice + + +def _timeit(fn, repeats: int) -> float: + """Return average wall time (seconds) over repeats for calling fn().""" + times: List[float] = [] + for _ in range(repeats): + t0 = time.perf_counter() + fn() + times.append(time.perf_counter() - t0) + return float(np.mean(times)) + + +def _gen_coords(n: int, d: int, seed: int) -> np.ndarray: + rng = np.random.default_rng(seed) + return rng.random((n, d), dtype=float) + + +def run_once( + n: int, d: int, max_k: int, repeats: int, seed: int +) -> Tuple[float, float]: + """Run one size point and return (time_kdtree, time_matrix).""" + coords = _gen_coords(n, d, seed) + ids = list(range(n)) + lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords) + + # KDTree path + t_kdtree = _timeit( + lambda: lat._build_neighbors(max_k=max_k, use_kdtree=True), repeats + ) + + # Distance-matrix path (fully differentiable) + t_matrix = _timeit( + lambda: lat._build_neighbors(max_k=max_k, use_kdtree=False), repeats + ) + + return t_kdtree, t_matrix + + +def parse_sizes(s: str) -> List[int]: + return [int(x) for x in s.split(",") if x.strip()] + + +def format_row(n: int, t_kdtree: float, t_matrix: float) -> str: + speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") + return f"{n:>8} | {t_kdtree:>12.6f} | {t_matrix:>14.6f} | {speedup:>7.2f}x" + + +def main(argv: Optional[Iterable[str]] = None) -> int: + p = argparse.ArgumentParser(description="Neighbor-building time comparison") + p.add_argument( + "--sizes", + type=parse_sizes, + default=[128, 256, 512, 1024, 2048], + help="Comma-separated site counts to benchmark (default: 128,256,512,1024,2048)", + ) + p.add_argument( + "--dims", type=int, default=2, help="Lattice dimensionality (default: 2)" + ) + p.add_argument( + "--max-k", type=int, default=6, help="Max neighbor shells k (default: 6)" + ) + p.add_argument( + "--repeats", type=int, default=5, help="Repeats per measurement (default: 5)" + ) + p.add_argument("--seed", type=int, default=42, help="PRNG seed (default: 42)") + p.add_argument("--csv", type=str, default="", help="Optional CSV output path") + args = p.parse_args(list(argv) if argv is not None else None) + + print("=" * 74) + print( + f"Benchmark CustomizeLattice neighbor-building | dims={args.dims} max_k={args.max_k} repeats={args.repeats}" + ) + print("=" * 74) + print(f"{'N':>8} | {'KDTree(s)':>12} | {'DistMatrix(s)':>14} | {'Speedup':>7}") + print("-" * 74) + + rows: List[Tuple[int, float, float]] = [] + for n in args.sizes: + t_kdtree, t_matrix = run_once(n, args.dims, args.max_k, args.repeats, args.seed) + rows.append((n, t_kdtree, t_matrix)) + print(format_row(n, t_kdtree, t_matrix)) + + if args.csv: + with open(args.csv, "w", newline="", encoding="utf-8") as f: + w = csv.writer(f) + w.writerow(["N", "time_kdtree_s", "time_distance_matrix_s", "speedup"]) + for n, t_kdtree, t_matrix in rows: + speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") + w.writerow([n, f"{t_kdtree:.6f}", f"{t_matrix:.6f}", f"{speedup:.2f}"]) + + print("-" * 74) + print(f"Saved CSV to: {args.csv}") + + print("-" * 74) + print("Done.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/lennard_jones_optimization.py b/examples/lennard_jones_optimization.py index 7c38cf9d..298fd599 100644 --- a/examples/lennard_jones_optimization.py +++ b/examples/lennard_jones_optimization.py @@ -67,8 +67,6 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0): history["a"].append(K.exp(log_a)) history["energy"].append(energy) - # (Removed previously added blanket NaN guard per reviewer request to keep example minimal.) - updates, opt_state = optimizer.update(grad, opt_state) log_a = optax.apply_updates(log_a, updates) @@ -120,5 +118,3 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0): plt.legend() plt.grid(True) plt.show() -else: - print("\nOptimization failed. Final energy is NaN.") diff --git a/tensorcircuit/backends/abstract_backend.py b/tensorcircuit/backends/abstract_backend.py index 8dad618c..286a21f9 100644 --- a/tensorcircuit/backends/abstract_backend.py +++ b/tensorcircuit/backends/abstract_backend.py @@ -631,8 +631,8 @@ def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any: """ Return coordinate matrices from coordinate vectors. - :param args: coordinate vectors - :type args: Any + :param args: coordinate vectors + :type args: Any :param kwargs: keyword arguments for meshgrid, typically includes 'indexing' which can be 'ij' (matrix indexing) or 'xy' (Cartesian indexing). - 'ij': matrix indexing, first dimension corresponds to rows (default) @@ -647,9 +647,9 @@ def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any: [0, 1]] y = [[0, 0], [2, 2]] - :type kwargs: Any - :return: list of coordinate matrices - :rtype: Any + :type kwargs: Any + :return: list of coordinate matrices + :rtype: Any """ raise NotImplementedError( "Backend '{}' has not implemented `meshgrid`.".format(self.name) diff --git a/tensorcircuit/backends/numpy_backend.py b/tensorcircuit/backends/numpy_backend.py index 5dd7f625..8678f3dc 100644 --- a/tensorcircuit/backends/numpy_backend.py +++ b/tensorcircuit/backends/numpy_backend.py @@ -137,9 +137,6 @@ def kron(self, a: Tensor, b: Tensor) -> Tensor: return np.kron(a, b) def meshgrid(self, *args: Any, **kwargs: Any) -> Any: - """ - Backend-agnostic meshgrid function. - """ return np.meshgrid(*args, **kwargs) def dtype(self, a: Tensor) -> str: diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py index ad86bf91..7074be4d 100644 --- a/tensorcircuit/templates/lattice.py +++ b/tensorcircuit/templates/lattice.py @@ -22,7 +22,7 @@ import itertools import math import numpy as np -from scipy.spatial import cKDTree +from scipy.spatial import KDTree from .. import backend @@ -72,9 +72,9 @@ def __init__(self, dimensionality: int): self._identifiers: List[SiteIdentifier] = ( [] ) # List of unique, hashable site identifiers - self._coordinates: Optional[Coordinates] = ( - None # N x D array of site coordinates - ) + # Always initialize to an empty coordinate tensor with correct dimensionality + # so that type checkers know this is indexable and not Optional. + self._coordinates: Coordinates = backend.zeros((0, dimensionality)) # Mappings for efficient lookups. self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = ( @@ -132,8 +132,6 @@ def get_coordinates(self, index: SiteIndex) -> Coordinates: :rtype: Coordinates """ self._validate_index(index) - if self._coordinates is None: - raise ValueError("Lattice coordinates have not been initialized.") coords = self._coordinates[index] return coords @@ -186,8 +184,6 @@ def get_site_info( - The site's coordinates as a NumPy array. :rtype: Tuple[SiteIndex, SiteIdentifier, Coordinates] """ - if self._coordinates is None: - raise ValueError("Lattice coordinates have not been initialized.") if isinstance(index_or_identifier, int): # SiteIndex is an int idx = index_or_identifier self._validate_index(idx) @@ -207,8 +203,6 @@ def sites(self) -> Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]: index, identifier, and coordinates. :rtype: Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]] """ - if self._coordinates is None: - raise ValueError("Lattice coordinates have not been initialized.") for i in range(self.num_sites): yield i, self._identifiers[i], self._coordinates[i] @@ -324,14 +318,24 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: """ pass - @abc.abstractmethod def _compute_distance_matrix(self) -> Coordinates: """ - Abstract method for subclasses to implement the actual matrix calculation. - This method is called by the `distance_matrix` property when the matrix - needs to be computed for the first time. + Default generic distance matrix computation (no periodic images). + + Subclasses can override this when a specialized rule is required + (e.g., applying Minimum Image Convention for PBC in TILattice). """ - pass + # Handle empty lattices and trivial 1-site lattices + if self.num_sites == 0: + return backend.zeros((0, 0)) + + # Vectorized pairwise Euclidean distances + all_coords = self._coordinates + displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims( + all_coords, 0 + ) + dist_matrix_sq = backend.sum(displacements**2, axis=-1) + return backend.sqrt(dist_matrix_sq) def show( self, @@ -401,10 +405,6 @@ def show( else: fig = ax.figure # type: ignore - if self._coordinates is None: - logger.error("Cannot show lattice: coordinates have not been initialized.") - return - coords = np.array(self._coordinates) # Prepare arguments for the scatter plot, allowing user overrides. scatter_args = {"s": 100, "zorder": 2} @@ -767,17 +767,17 @@ def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates: """ # Ensure dtype consistency across backends (especially torch) by explicitly # casting size and lattice_vectors to the same floating dtype used internally. - # Strategy: prefer existing lattice_vectors dtype, fallback to float64 for precision, - # then float32 for compatibility. This avoids dtype mismatches in vectorized ops. + # Strategy: prefer existing lattice_vectors dtype, fallback to float32 for efficiency, + # then float64 for precision. This avoids dtype mismatches in vectorized ops. target_dt = None try: # prefer existing lattice_vectors dtype if possible target_dt = backend.dtype(self.lattice_vectors) # type: ignore - except Exception: # pragma: no cover - defensive - target_dt = "float64" + except (AttributeError, TypeError): # pragma: no cover - defensive + target_dt = "float32" if target_dt not in ("float32", "float64"): # fallback for unusual dtypes - target_dt = "float64" + target_dt = "float32" size_arr = backend.cast(backend.convert_to_tensor(self.size), target_dt) lattice_vecs = backend.cast( @@ -890,13 +890,17 @@ def __init__( ): """Initializes the SquareLattice.""" dimensionality = 2 - # Define orthogonal lattice vectors for a square. - lattice_vectors = backend.convert_to_tensor( - [[lattice_constant, 0.0], [0.0, lattice_constant]] - ) + # Avoid mixing Python floats with backend Tensors (TF would error), + # so first convert inputs to tensors of a unified dtype, then stack. + lc = backend.convert_to_tensor(lattice_constant) + dt = backend.dtype(lc) + z = backend.cast(backend.convert_to_tensor(0.0), dt) + row1 = backend.stack([lc, z]) + row2 = backend.stack([z, lc]) + lattice_vectors = backend.stack([row1, row2]) # A square lattice is a Bravais lattice, so it has a single-site basis. - basis_coords = backend.convert_to_tensor([[0.0, 0.0]]) + basis_coords = backend.stack([backend.stack([z, z])]) super().__init__( dimensionality=dimensionality, @@ -940,17 +944,21 @@ def __init__( """Initializes the HoneycombLattice.""" dimensionality = 2 a = lattice_constant + a_t = backend.convert_to_tensor(a) + zero = a_t * 0.0 # Define the two primitive lattice vectors for the underlying triangular Bravais lattice. rt3_over_2 = math.sqrt(3.0) / 2.0 - lattice_vectors = backend.convert_to_tensor( + lattice_vectors = backend.stack( [ - [a * 1.5, a * rt3_over_2], - [a * 1.5, -a * rt3_over_2], + backend.stack([a_t * 1.5, a_t * rt3_over_2]), + backend.stack([a_t * 1.5, -a_t * rt3_over_2]), ] ) # Define the two basis sites (A and B) within the unit cell. - basis_coords = backend.convert_to_tensor([[0.0, 0.0], [a * 1.0, 0.0]]) + basis_coords = backend.stack( + [backend.stack([zero, zero]), backend.stack([a_t * 1.0, zero])] + ) super().__init__( dimensionality=dimensionality, @@ -992,16 +1000,23 @@ def __init__( """Initializes the TriangularLattice.""" dimensionality = 2 a = lattice_constant + a_t = backend.convert_to_tensor(a) + zero = a_t * 0.0 # Define the primitive lattice vectors for a triangular lattice. - lattice_vectors = backend.convert_to_tensor( + lattice_vectors = backend.stack( [ - [a * 1.0, 0.0], - [a * 0.5, a * backend.sqrt(backend.convert_to_tensor(3.0)) / 2.0], + backend.stack([a_t * 1.0, zero]), + backend.stack( + [ + a_t * 0.5, + a_t * backend.sqrt(backend.convert_to_tensor(3.0)) / 2.0, + ] + ), ] ) # A triangular lattice is a Bravais lattice with a single-site basis. - basis_coords = backend.convert_to_tensor([[0.0, 0.0]]) + basis_coords = backend.stack([backend.stack([zero, zero])]) super().__init__( dimensionality=dimensionality, @@ -1036,9 +1051,11 @@ def __init__( ): dimensionality = 1 # The lattice vector is just the lattice constant along one dimension. - lattice_vectors = backend.convert_to_tensor([[lattice_constant]]) + lc = backend.convert_to_tensor(lattice_constant) + lattice_vectors = backend.stack([backend.stack([lc])]) # A simple chain is a Bravais lattice with a single-site basis. - basis_coords = backend.convert_to_tensor([[0.0]]) + zero = lc * 0.0 + basis_coords = backend.stack([backend.stack([zero])]) super().__init__( dimensionality=dimensionality, @@ -1077,9 +1094,11 @@ def __init__( ): dimensionality = 1 # The unit cell is twice the bond length, as it contains two sites. - lattice_vectors = backend.convert_to_tensor([[2 * lattice_constant]]) + lc = backend.convert_to_tensor(lattice_constant) + lattice_vectors = backend.stack([backend.stack([2 * lc])]) # Two basis sites (A and B) separated by the bond length. - basis_coords = backend.convert_to_tensor([[0.0], [lattice_constant]]) + zero = lc * 0.0 + basis_coords = backend.stack([backend.stack([zero]), backend.stack([lc])]) super().__init__( dimensionality=dimensionality, @@ -1118,10 +1137,16 @@ def __init__( ): dimensionality = 2 ax, ay = lattice_constants + ax_t = backend.convert_to_tensor(ax) + dt = backend.dtype(ax_t) + ay_t = backend.cast(backend.convert_to_tensor(ay), dt) + z = backend.cast(backend.convert_to_tensor(0.0), dt) # Orthogonal lattice vectors with potentially different lengths. - lattice_vectors = backend.convert_to_tensor([[ax, 0.0], [0.0, ay]]) + row1 = backend.stack([ax_t, z]) + row2 = backend.stack([z, ay_t]) + lattice_vectors = backend.stack([row1, row2]) # A rectangular lattice is a Bravais lattice with a single-site basis. - basis_coords = backend.convert_to_tensor([[0.0, 0.0]]) + basis_coords = backend.stack([backend.stack([z, z])]) super().__init__( dimensionality=dimensionality, @@ -1158,13 +1183,18 @@ def __init__( ): dimensionality = 2 a = lattice_constant + a_t = backend.convert_to_tensor(a) # The unit cell is a square rotated by 45 degrees. - lattice_vectors = backend.convert_to_tensor( - [[a * 1.0, a * 1.0], [a * 1.0, a * -1.0]] + lattice_vectors = backend.stack( + [ + backend.stack([a_t * 1.0, a_t * 1.0]), + backend.stack([a_t * 1.0, a_t * -1.0]), + ] ) # Two basis sites (A and B) within the unit cell. - basis_coords = backend.convert_to_tensor( - [[a * 0.0, a * 0.0], [a * 1.0, a * 0.0]] + zero = a_t * 0.0 + basis_coords = backend.stack( + [backend.stack([zero, zero]), backend.stack([a_t * 1.0, zero])] ) super().__init__( @@ -1202,19 +1232,21 @@ def __init__( ): dimensionality = 2 a = lattice_constant + a_t = backend.convert_to_tensor(a) # The Kagome lattice is based on a triangular Bravais lattice. - lattice_vectors = backend.convert_to_tensor( + lattice_vectors = backend.stack( [ - [a * 2.0, a * 0.0], - [a * 1.0, a * backend.sqrt(3.0)], + backend.stack([a_t * 2.0, a_t * 0.0]), + backend.stack([a_t * 1.0, a_t * backend.sqrt(3.0)]), ] ) # It has a three-site basis, forming the corners of the triangles. - basis_coords = backend.convert_to_tensor( + zero = a_t * 0.0 + basis_coords = backend.stack( [ - [a * 0.0, a * 0.0], - [a * 1.0, a * 0.0], - [a * 0.5, a * backend.sqrt(3.0) / 2.0], + backend.stack([zero, zero]), + backend.stack([a_t * 1.0, zero]), + backend.stack([a_t * 0.5, a_t * backend.sqrt(3.0) / 2.0]), ] ) @@ -1255,18 +1287,19 @@ def __init__( """Initializes the LiebLattice.""" dimensionality = 2 bond_length = lattice_constant - - unit_cell_side = 2 * bond_length + bl_t = backend.convert_to_tensor(bond_length) + unit_cell_side_t = 2 * bl_t # The Lieb lattice is based on a square Bravais lattice. - lattice_vectors = backend.convert_to_tensor( - [[unit_cell_side, 0.0], [0.0, unit_cell_side]] + z = bl_t * 0.0 + lattice_vectors = backend.stack( + [backend.stack([unit_cell_side_t, z]), backend.stack([z, unit_cell_side_t])] ) # It has a three-site basis: one corner and two edge-centers. - basis_coords = backend.convert_to_tensor( + basis_coords = backend.stack( [ - [0.0, 0.0], # Corner site - [bond_length, 0.0], # x-edge center - [0.0, bond_length], # y-edge center + backend.stack([z, z]), # Corner site + backend.stack([bl_t, z]), # x-edge center + backend.stack([z, bl_t]), # y-edge center ] ) @@ -1305,10 +1338,18 @@ def __init__( ): dimensionality = 3 a = lattice_constant + a_t = backend.convert_to_tensor(a) # Orthogonal lattice vectors of equal length in 3D. - lattice_vectors = backend.convert_to_tensor([[a, 0, 0], [0, a, 0], [0, 0, a]]) + z = a_t * 0.0 + lattice_vectors = backend.stack( + [ + backend.stack([a_t, z, z]), + backend.stack([z, a_t, z]), + backend.stack([z, z, a_t]), + ] + ) # A simple cubic lattice is a Bravais lattice with a single-site basis. - basis_coords = backend.convert_to_tensor([[0.0, 0.0, 0.0]]) + basis_coords = backend.stack([backend.stack([z, z, z])]) super().__init__( dimensionality=dimensionality, lattice_vectors=lattice_vectors, @@ -1402,11 +1443,12 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: :param max_k: Maximum number of neighbor shells to compute :type max_k: int :param kwargs: Additional arguments including: - - use_kdtree (bool): Whether to use KDTree optimization. Defaults to True. + - use_kdtree (bool): Whether to use KDTree optimization. Defaults to False. - tol (float): Distance tolerance for neighbor identification. Defaults to 1e-6. """ tol = kwargs.get("tol", 1e-6) - use_kdtree = kwargs.get("use_kdtree", True) + # Reviewer suggestion: prefer differentiable method by default + use_kdtree = kwargs.get("use_kdtree", False) if self.num_sites < 2: return @@ -1425,8 +1467,6 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None: # Use the existing distance matrix method self._build_neighbors_by_distance_matrix(max_k, tol) - logger.info(f"Neighbor building complete for CustomizeLattice up to k={max_k}.") - def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: """ Build neighbors using KDTree for optimal performance. @@ -1445,7 +1485,7 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: # Build KDTree logger.info("Building KDTree...") - tree = cKDTree(coords_np) + tree = KDTree(coords_np) # For small lattices or cases with potential duplicate coordinates, # fall back to distance matrix method for robustness @@ -1558,24 +1598,6 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: self._distance_matrix = None logger.info("KDTree neighbor building completed") - def _compute_distance_matrix(self) -> Coordinates: - """ - Computes the full N x N distance matrix by delegating to the inherited method. - This avoids code duplication with the base class implementation. - """ - if self.num_sites == 0: - return backend.zeros((0, 0)) - if self.num_sites < 2: - if self._coordinates is None: - raise ValueError("Lattice coordinates have not been initialized.") - return backend.zeros( - (self.num_sites, self.num_sites), dtype=self._coordinates.dtype - ) - - # Use the inherited method from AbstractLattice which computes and caches the distance matrix - self._build_neighbors_by_distance_matrix(max_k=1, tol=1e-6) - return self._distance_matrix - def _reset_computations(self) -> None: """Resets all cached data that depends on the lattice structure.""" self._neighbor_maps = {} @@ -1602,32 +1624,23 @@ def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice": ) # Unzip the list of tuples into separate lists of identifiers and coordinates - _, identifiers, coordinates = zip(*all_sites_info) - - # Normalize coordinates to plain nested Python float lists to avoid - # backend-specific tensor list issues (e.g., torch.tensor(list_of_tensors) ValueError). - # This ensures the resulting CustomizeLattice works consistently across all backends - # by converting any backend tensors to backend-agnostic Python lists. - normalized_coords = [] - for c in coordinates: - try: - # If already a backend tensor, convert to numpy then to list - if hasattr(backend, "is_tensor") and backend.is_tensor(c): # type: ignore - normalized_coords.append(backend.numpy(c).tolist()) # type: ignore - else: - # c may be a numpy array or list-like - if hasattr(c, "tolist"): - normalized_coords.append(c.tolist()) # type: ignore - else: - normalized_coords.append(list(c)) # fallback - except Exception: # pragma: no cover - defensive - # Last resort: wrap scalar(s) - normalized_coords.append([float(x) for x in c]) # type: ignore + _, identifiers, _ = zip(*all_sites_info) + + # Detach-and-copy coordinates in a backend-agnostic way. + # Rationale (answering reviewer question "why not keep backend-dependent form?"): + # - Passing a tuple/list of backend tensors (e.g., per-row slices) into + # convert_to_tensor can fail on some backends (torch.tensor(list_of_tensors) ValueError), + # whereas a plain nested Python list is accepted everywhere. + # - We want CustomizeLattice to be decoupled from the original lattice's computation + # graph and device state (CPU/GPU), so we materialize numeric values here. + # - This is a one-shot conversion of the full coordinate array, simpler and faster + # than iterating per row, while preserving the same numeric content. + coords_py = backend.numpy(lattice._coordinates).tolist() return cls( dimensionality=lattice.dimensionality, identifiers=list(identifiers), - coordinates=normalized_coords, + coordinates=coords_py, ) def add_sites( diff --git a/tests/test_lattice.py b/tests/test_lattice.py index 69923071..481fc867 100644 --- a/tests/test_lattice.py +++ b/tests/test_lattice.py @@ -528,7 +528,10 @@ def test_customizelattice_max_k_precomputation_and_ondemand(self): f"but found {computed_shells_after}." ) - def test_precompute_neighbors_on_init_custom(self): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_precompute_neighbors_on_init_custom(self, backend): """ Tests that the `precompute_neighbors` argument correctly populates the neighbor map upon initialization for CustomizeLattice. @@ -1404,7 +1407,10 @@ def test_mixed_boundary_conditions(self): edge_neighbors == expected_edge_indices ), "Failed for edge site with mixed BC." - def test_mixed_boundary_conditions_on_honeycomb(self): + @pytest.mark.parametrize( + "backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")] + ) + def test_mixed_boundary_conditions_on_honeycomb(self, backend): """ Tests neighbor finding on a HoneycombLattice with mixed PBC. This ensures the logic correctly handles composite lattices with @@ -1531,7 +1537,6 @@ def initial_lattice(self): ) def test_from_lattice_conversion(self, backend): """Tests creating a CustomizeLattice from a TILattice.""" - tc.set_backend(backend) # Arrange sq_lattice = SquareLattice(size=(2, 2), pbc=False) @@ -1556,7 +1561,6 @@ def test_from_lattice_conversion(self, backend): ) def test_add_sites_successfully(self, backend, initial_lattice): """Tests adding new, valid sites to the lattice.""" - tc.set_backend(backend) # Arrange lat = initial_lattice @@ -1578,7 +1582,6 @@ def test_add_sites_successfully(self, backend, initial_lattice): ) def test_remove_sites_successfully(self, backend, initial_lattice): """Tests removing existing sites from the lattice.""" - tc.set_backend(backend) # Arrange lat = initial_lattice @@ -1600,7 +1603,6 @@ def test_remove_sites_successfully(self, backend, initial_lattice): ) def test_add_duplicate_identifier_raises_error(self, backend, initial_lattice): """Tests that adding a site with an existing identifier fails.""" - tc.set_backend(backend) with pytest.raises(ValueError, match="Duplicate identifiers found"): initial_lattice.add_sites(identifiers=["A"], coordinates=[[9.0, 9.0]]) @@ -1610,7 +1612,6 @@ def test_add_duplicate_identifier_raises_error(self, backend, initial_lattice): ) def test_remove_nonexistent_identifier_raises_error(self, backend, initial_lattice): """Tests that removing a non-existent site fails.""" - tc.set_backend(backend) with pytest.raises(ValueError, match="Non-existent identifiers provided"): initial_lattice.remove_sites(identifiers=["Z"]) @@ -1623,7 +1624,6 @@ def test_modification_clears_neighbor_cache(self, backend, initial_lattice): Tests that add_sites and remove_sites correctly invalidate the pre-computed neighbor map. """ - tc.set_backend(backend) # Arrange: Pre-compute neighbors on the initial lattice initial_lattice._build_neighbors(max_k=1) @@ -1653,7 +1653,6 @@ def test_modification_clears_distance_matrix_cache(self, backend, initial_lattic Tests that add_sites and remove_sites correctly invalidate the cached distance matrix and that the recomputed matrix is correct. """ - tc.set_backend(backend) # Arrange 1: Compute, cache, and perform a meaningful check on the original matrix. lat = initial_lattice @@ -1696,7 +1695,6 @@ def test_neighbor_finding_returns_sorted_list(self, backend, simple_square_latti Ensures that the list of neighbors returned by get_neighbors is always sorted. This provides a stricter check than set-based comparisons. """ - tc.set_backend(backend) # Arrange lattice = simple_square_lattice @@ -1719,7 +1717,6 @@ def test_neighbor_finding_returns_sorted_list(self, backend, simple_square_latti ) def test_from_lattice_from_empty_lattice(self, backend): """Tests creating a CustomizeLattice from an empty TILattice.""" - tc.set_backend(backend) # Arrange: Create an empty TILattice instance. empty_sq = SquareLattice(size=(0, 0)) @@ -1737,7 +1734,6 @@ def test_from_lattice_from_empty_lattice(self, backend): ) def test_add_sites_to_empty_lattice(self, backend): """Tests adding sites to a previously empty CustomizeLattice.""" - tc.set_backend(backend) # Arrange: Create an empty CustomizeLattice. empty_lat = CustomizeLattice(dimensionality=2, identifiers=[], coordinates=[]) @@ -1763,7 +1759,6 @@ def test_add_and_remove_empty_list_of_sites(self, backend, initial_lattice): Tests that calling add_sites and remove_sites with empty lists is a no-op and doesn't change the lattice state. """ - tc.set_backend(backend) # Arrange lat = initial_lattice @@ -1790,7 +1785,6 @@ def test_add_and_remove_empty_list_of_sites(self, backend, initial_lattice): ) def test_remove_all_sites(self, backend, initial_lattice): """Tests removing all sites from a lattice, resulting in an empty lattice.""" - tc.set_backend(backend) # Arrange lat = initial_lattice @@ -1879,7 +1873,6 @@ def test_tilattice_mixed_bc_distance_matrix_is_correct(self, backend): Tests that the distance matrix is correctly calculated for a TILattice with mixed boundary conditions (e.g., periodic in x, open in y). """ - tc.set_backend(backend) # Arrange # pbc=(True, False) means periodic along x-axis, open along y-axis. @@ -1911,7 +1904,8 @@ def test_tilattice_mixed_bc_distance_matrix_is_correct(self, backend): ) # --- This list and the following test are now at the correct indentation level --- - # 使用工厂函数而不是预先实例化对象,避免跨 backend 复用已缓存 _distance_matrix(numpy 数组) + # Use factory functions instead of pre-instantiated objects + # to avoid sharing a cached _distance_matrix (NumPy array) across backends lattice_factories_for_invariant_test = [ pytest.param(lambda: SquareLattice(size=(4, 4), pbc=True), id="Square_4x4_pbc"), pytest.param( @@ -1945,8 +1939,7 @@ def test_distance_matrix_invariants_for_all_lattice_types( fundamental mathematical properties (invariants): symmetry, zero diagonal, and positive off-diagonal elements. """ - tc.set_backend(backend) - # 重新实例化 lattice,确保没有跨 backend 的缓存副作用 + # Re-instantiate the lattice to ensure no cache side effects across backends lattice = lattice_factory() # Arrange @@ -1989,7 +1982,6 @@ def test_distance_matrix_caching_is_effective(self, backend): """ Tests that the distance_matrix property is cached after the first access. """ - tc.set_backend(backend) # Arrange: Create a lattice instance. lattice = CustomizeLattice( @@ -2098,6 +2090,54 @@ def test_layering_on_various_lattices(lattice_instance): _validate_layers(bonds, layers) +# --- Regression tests for backend-scalar lattice constants (PR fix) --- +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_square_lattice_accepts_backend_scalar_lattice_constant(backend): + """ + Ensure SquareLattice can be constructed when lattice_constant is a backend scalar tensor + (e.g., tf.constant, jnp.array, torch.tensor), without mixed-type errors. + """ + tc.set_backend(backend) + + lc = tc.backend.convert_to_tensor(0.5) + lat = SquareLattice(size=(2, 2), lattice_constant=lc, pbc=False) + + # basic shape sanity + assert lat.num_sites == 4 + assert tc.backend.shape_tuple(lat._coordinates)[1] == 2 # type: ignore[attr-defined] + + # distance check along x and y + dm = lat.distance_matrix + o = lat.get_index((0, 0, 0)) + x1 = lat.get_index((1, 0, 0)) + y1 = lat.get_index((0, 1, 0)) + np.testing.assert_allclose(tc.backend.numpy(dm[o, x1]), 0.5) + np.testing.assert_allclose(tc.backend.numpy(dm[o, y1]), 0.5) + + +@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) +def test_rectangular_lattice_mixed_type_constants(backend): + """ + RectangularLattice should accept a tuple where one constant is a backend scalar tensor + and the other is a Python float. + """ + tc.set_backend(backend) + + ax = tc.backend.convert_to_tensor(0.5) # tensor scalar + ay = 2.0 # python float + lat = RectangularLattice(size=(2, 2), lattice_constants=(ax, ay), pbc=False) + + assert lat.num_sites == 4 + assert tc.backend.shape_tuple(lat._coordinates)[1] == 2 # type: ignore[attr-defined] + + dm = lat.distance_matrix + o = lat.get_index((0, 0, 0)) + x1 = lat.get_index((1, 0, 0)) + y1 = lat.get_index((0, 1, 0)) + np.testing.assert_allclose(tc.backend.numpy(dm[o, x1]), 0.5) + np.testing.assert_allclose(tc.backend.numpy(dm[o, y1]), 2.0) + + def test_layering_on_1d_chain_pbc(): """Test layering on a 1D chain with periodic boundaries (a cycle graph).""" lattice_even = ChainLattice(size=(6,), pbc=True) @@ -2469,7 +2509,6 @@ def test_identify_distance_shells_basic(self, backend, simple_lattice_for_helper Tests the basic functionality of _identify_distance_shells with a clear separation between distance shells. """ - tc.set_backend(backend) # Arrange lattice = simple_lattice_for_helpers @@ -2494,7 +2533,6 @@ def test_identify_distance_shells_with_max_k_limit( Tests that _identify_distance_shells respects the max_k parameter, limiting the number of returned shells. """ - tc.set_backend(backend) # Arrange lattice = simple_lattice_for_helpers @@ -2520,7 +2558,6 @@ def test_identify_distance_shells_with_tolerance_merging( Tests that distances that are very close together are merged into a single shell when the tolerance `tol` is large enough. """ - tc.set_backend(backend) # Arrange lattice = simple_lattice_for_helpers From d372f7264dcf9b2f8cbba9d9f16fffdd93729ea0 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Thu, 14 Aug 2025 15:58:50 +0800 Subject: [PATCH 12/16] update lattice_neighbor_time_compare.py to enhance the accuracy --- examples/lattice_neighbor_time_compare.py | 49 ++++++++++------------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/examples/lattice_neighbor_time_compare.py b/examples/lattice_neighbor_time_compare.py index ec4e5a0d..2101307c 100644 --- a/examples/lattice_neighbor_time_compare.py +++ b/examples/lattice_neighbor_time_compare.py @@ -18,40 +18,33 @@ from tensorcircuit.templates.lattice import CustomizeLattice -def _timeit(fn, repeats: int) -> float: - """Return average wall time (seconds) over repeats for calling fn().""" - times: List[float] = [] - for _ in range(repeats): - t0 = time.perf_counter() - fn() - times.append(time.perf_counter() - t0) - return float(np.mean(times)) - - -def _gen_coords(n: int, d: int, seed: int) -> np.ndarray: - rng = np.random.default_rng(seed) - return rng.random((n, d), dtype=float) - - def run_once( n: int, d: int, max_k: int, repeats: int, seed: int ) -> Tuple[float, float]: """Run one size point and return (time_kdtree, time_matrix).""" - coords = _gen_coords(n, d, seed) + rng = np.random.default_rng(seed) ids = list(range(n)) - lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords) - - # KDTree path - t_kdtree = _timeit( - lambda: lat._build_neighbors(max_k=max_k, use_kdtree=True), repeats - ) - - # Distance-matrix path (fully differentiable) - t_matrix = _timeit( - lambda: lat._build_neighbors(max_k=max_k, use_kdtree=False), repeats - ) + + # Collect times for each repeat with different random coordinates + kdtree_times: List[float] = [] + matrix_times: List[float] = [] + + for i in range(repeats): + # Generate different coordinates for each repeat + coords = rng.random((n, d), dtype=float) + lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords) + + # KDTree path - single measurement + t0 = time.perf_counter() + lat._build_neighbors(max_k=max_k, use_kdtree=True) + kdtree_times.append(time.perf_counter() - t0) + + # Distance-matrix path - single measurement + t0 = time.perf_counter() + lat._build_neighbors(max_k=max_k, use_kdtree=False) + matrix_times.append(time.perf_counter() - t0) - return t_kdtree, t_matrix + return float(np.mean(kdtree_times)), float(np.mean(matrix_times)) def parse_sizes(s: str) -> List[int]: From 0b38522ac17cbe53638f2afba9110ee695566d00 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Thu, 14 Aug 2025 16:03:50 +0800 Subject: [PATCH 13/16] fix black --- examples/lattice_neighbor_time_compare.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/lattice_neighbor_time_compare.py b/examples/lattice_neighbor_time_compare.py index 2101307c..dada10a0 100644 --- a/examples/lattice_neighbor_time_compare.py +++ b/examples/lattice_neighbor_time_compare.py @@ -24,22 +24,22 @@ def run_once( """Run one size point and return (time_kdtree, time_matrix).""" rng = np.random.default_rng(seed) ids = list(range(n)) - + # Collect times for each repeat with different random coordinates kdtree_times: List[float] = [] matrix_times: List[float] = [] - - for i in range(repeats): + + for _ in range(repeats): # Generate different coordinates for each repeat coords = rng.random((n, d), dtype=float) lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords) - + # KDTree path - single measurement t0 = time.perf_counter() lat._build_neighbors(max_k=max_k, use_kdtree=True) kdtree_times.append(time.perf_counter() - t0) - - # Distance-matrix path - single measurement + + # Distance-matrix path - single measurement t0 = time.perf_counter() lat._build_neighbors(max_k=max_k, use_kdtree=False) matrix_times.append(time.perf_counter() - t0) From 04aca93ff87598f37012f5167b3a4b906d96bb37 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Fri, 15 Aug 2025 11:07:16 +0800 Subject: [PATCH 14/16] fix according to the review --- examples/lennard_jones_optimization.py | 77 +++++++++++++------------- tensorcircuit/templates/lattice.py | 34 +++++------- tests/test_lattice.py | 3 - 3 files changed, 53 insertions(+), 61 deletions(-) diff --git a/examples/lennard_jones_optimization.py b/examples/lennard_jones_optimization.py index 298fd599..f3c4c825 100644 --- a/examples/lennard_jones_optimization.py +++ b/examples/lennard_jones_optimization.py @@ -79,42 +79,41 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0): final_a = K.exp(log_a) final_energy = calculate_potential(log_a) -if not np.isnan(K.numpy(final_energy)): - print("\nOptimization finished!") - print(f"Final optimized lattice constant: {final_a:.6f}") - print(f"Corresponding minimum total energy: {final_energy:.6f}") - - # Vectorized calculation for the potential curve - a_vals = np.linspace(0.8, 1.5, 200) - log_a_vals = K.log(K.convert_to_tensor(a_vals)) - - # Use vmap to create a vectorized version of the potential function - vmap_potential = K.vmap(lambda la: calculate_potential(la)) - potential_curve = vmap_potential(log_a_vals) - - plt.figure(figsize=(10, 6)) - plt.plot(a_vals, potential_curve, label="Lennard-Jones Potential", color="blue") - plt.scatter( - history["a"], - history["energy"], - color="red", - s=20, - zorder=5, - label="Optimization Steps", - ) - plt.scatter( - final_a, - final_energy, - color="green", - s=100, - zorder=6, - marker="*", - label="Final Optimized Point", - ) - - plt.title("Lennard-Jones Potential Optimization") - plt.xlabel("Lattice Constant (a)") - plt.ylabel("Total Potential Energy") - plt.legend() - plt.grid(True) - plt.show() +print("\nOptimization finished!") +print(f"Final optimized lattice constant: {final_a:.6f}") +print(f"Corresponding minimum total energy: {final_energy:.6f}") + +# Vectorized calculation for the potential curve +a_vals = np.linspace(0.8, 1.5, 200) +log_a_vals = K.log(K.convert_to_tensor(a_vals)) + +# Use vmap to create a vectorized version of the potential function +vmap_potential = K.vmap(lambda la: calculate_potential(la)) +potential_curve = vmap_potential(log_a_vals) + +plt.figure(figsize=(10, 6)) +plt.plot(a_vals, potential_curve, label="Lennard-Jones Potential", color="blue") +plt.scatter( + history["a"], + history["energy"], + color="red", + s=20, + zorder=5, + label="Optimization Steps", +) +plt.scatter( + final_a, + final_energy, + color="green", + s=100, + zorder=6, + marker="*", + label="Final Optimized Point", +) + +plt.title("Lennard-Jones Potential Optimization") +plt.xlabel("Lattice Constant (a)") +plt.ylabel("Total Potential Energy") +plt.legend() +plt.grid(True) +plt.show() diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py index 7074be4d..49508487 100644 --- a/tensorcircuit/templates/lattice.py +++ b/tensorcircuit/templates/lattice.py @@ -767,14 +767,11 @@ def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates: """ # Ensure dtype consistency across backends (especially torch) by explicitly # casting size and lattice_vectors to the same floating dtype used internally. - # Strategy: prefer existing lattice_vectors dtype, fallback to float32 for efficiency, - # then float64 for precision. This avoids dtype mismatches in vectorized ops. - target_dt = None - try: - # prefer existing lattice_vectors dtype if possible - target_dt = backend.dtype(self.lattice_vectors) # type: ignore - except (AttributeError, TypeError): # pragma: no cover - defensive - target_dt = "float32" + # Strategy: prefer existing lattice_vectors dtype; if it's an unusual dtype, + # fall back to float32 to avoid mixed-precision issues in vectorized ops. + # Note: `self.lattice_vectors` is always created via `backend.convert_to_tensor` + # in __init__, so `backend.dtype(...)` is reliable here and doesn't need try/except. + target_dt = str(backend.dtype(self.lattice_vectors)) # type: ignore if target_dt not in ("float32", "float64"): # fallback for unusual dtypes target_dt = "float32" @@ -1626,21 +1623,20 @@ def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice": # Unzip the list of tuples into separate lists of identifiers and coordinates _, identifiers, _ = zip(*all_sites_info) - # Detach-and-copy coordinates in a backend-agnostic way. - # Rationale (answering reviewer question "why not keep backend-dependent form?"): - # - Passing a tuple/list of backend tensors (e.g., per-row slices) into - # convert_to_tensor can fail on some backends (torch.tensor(list_of_tensors) ValueError), - # whereas a plain nested Python list is accepted everywhere. - # - We want CustomizeLattice to be decoupled from the original lattice's computation - # graph and device state (CPU/GPU), so we materialize numeric values here. - # - This is a one-shot conversion of the full coordinate array, simpler and faster - # than iterating per row, while preserving the same numeric content. - coords_py = backend.numpy(lattice._coordinates).tolist() + # Detach-and-copy coordinates while remaining in tensor form to avoid + # host roundtrips and device/dtype changes; this keeps CustomizeLattice + # decoupled from the original graph but backend-friendly. + # Some backends (e.g., NumPy) don't implement stop_gradient; fall back. + try: + coords_detached = backend.stop_gradient(lattice._coordinates) + except NotImplementedError: + coords_detached = lattice._coordinates + coords_tensor = backend.copy(coords_detached) return cls( dimensionality=lattice.dimensionality, identifiers=list(identifiers), - coordinates=coords_py, + coordinates=coords_tensor, ) def add_sites( diff --git a/tests/test_lattice.py b/tests/test_lattice.py index 481fc867..fe730dd7 100644 --- a/tests/test_lattice.py +++ b/tests/test_lattice.py @@ -2090,14 +2090,12 @@ def test_layering_on_various_lattices(lattice_instance): _validate_layers(bonds, layers) -# --- Regression tests for backend-scalar lattice constants (PR fix) --- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")]) def test_square_lattice_accepts_backend_scalar_lattice_constant(backend): """ Ensure SquareLattice can be constructed when lattice_constant is a backend scalar tensor (e.g., tf.constant, jnp.array, torch.tensor), without mixed-type errors. """ - tc.set_backend(backend) lc = tc.backend.convert_to_tensor(0.5) lat = SquareLattice(size=(2, 2), lattice_constant=lc, pbc=False) @@ -2121,7 +2119,6 @@ def test_rectangular_lattice_mixed_type_constants(backend): RectangularLattice should accept a tuple where one constant is a backend scalar tensor and the other is a Python float. """ - tc.set_backend(backend) ax = tc.backend.convert_to_tensor(0.5) # tensor scalar ay = 2.0 # python float From 283e1fd1caead007d1adba145f5f71a306399692 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Sat, 16 Aug 2025 10:18:34 +0800 Subject: [PATCH 15/16] fix according to the review --- examples/lattice_neighbor_benchmark.py | 174 +++++++++++---------- examples/lattice_neighbor_time_compare.py | 111 ------------- examples/lennard_jones_optimization.py | 8 +- tensorcircuit/backends/abstract_backend.py | 2 +- tensorcircuit/templates/lattice.py | 8 +- tests/test_lattice.py | 39 ----- 6 files changed, 106 insertions(+), 236 deletions(-) delete mode 100644 examples/lattice_neighbor_time_compare.py diff --git a/examples/lattice_neighbor_benchmark.py b/examples/lattice_neighbor_benchmark.py index 98cb7f89..dada10a0 100644 --- a/examples/lattice_neighbor_benchmark.py +++ b/examples/lattice_neighbor_benchmark.py @@ -1,95 +1,111 @@ """ -An example script to benchmark neighbor-finding algorithms in CustomizeLattice. - -This script demonstrates the performance difference between the KDTree-based -neighbor search and a baseline all-to-all distance matrix method. -As shown by the results, the KDTree approach offers a significant speedup, -especially when calculating for a large number of neighbor shells (large max_k). +Benchmark: Compare neighbor-building time between KDTree and distance-matrix +methods in CustomizeLattice for varying lattice sizes. """ -import timeit -from typing import Any, Dict, List +import argparse +import csv +import time +from typing import Iterable, List, Tuple, Optional +import logging + +import numpy as np + +# Silence verbose infos from the library during benchmarks + +logging.basicConfig(level=logging.WARNING) + +from tensorcircuit.templates.lattice import CustomizeLattice + + +def run_once( + n: int, d: int, max_k: int, repeats: int, seed: int +) -> Tuple[float, float]: + """Run one size point and return (time_kdtree, time_matrix).""" + rng = np.random.default_rng(seed) + ids = list(range(n)) + + # Collect times for each repeat with different random coordinates + kdtree_times: List[float] = [] + matrix_times: List[float] = [] + + for _ in range(repeats): + # Generate different coordinates for each repeat + coords = rng.random((n, d), dtype=float) + lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords) + + # KDTree path - single measurement + t0 = time.perf_counter() + lat._build_neighbors(max_k=max_k, use_kdtree=True) + kdtree_times.append(time.perf_counter() - t0) + # Distance-matrix path - single measurement + t0 = time.perf_counter() + lat._build_neighbors(max_k=max_k, use_kdtree=False) + matrix_times.append(time.perf_counter() - t0) -def run_benchmark() -> None: - """ - Executes the benchmark test and prints the results in a formatted table. - """ - # --- Benchmark Parameters --- - # A list of lattice sizes (N = number of sites) to test - site_counts: List[int] = [10, 50, 100, 200, 500, 1000, 1500, 2000] + return float(np.mean(kdtree_times)), float(np.mean(matrix_times)) - # Use a large k to better showcase the performance of KDTree in - # finding multiple neighbor shells, as suggested by the maintainer. - max_k: int = 2000 - # Reduce the number of runs to keep the total benchmark time reasonable, - # especially with a large max_k. - number_of_runs: int = 3 - # -------------------------- +def parse_sizes(s: str) -> List[int]: + return [int(x) for x in s.split(",") if x.strip()] - results: List[Dict[str, Any]] = [] - print("=" * 75) - print("Starting neighbor finding benchmark for CustomizeLattice...") - print(f"Parameters: max_k={max_k}, number_of_runs={number_of_runs}") - print("=" * 75) +def format_row(n: int, t_kdtree: float, t_matrix: float) -> str: + speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") + return f"{n:>8} | {t_kdtree:>12.6f} | {t_matrix:>14.6f} | {speedup:>7.2f}x" + + +def main(argv: Optional[Iterable[str]] = None) -> int: + p = argparse.ArgumentParser(description="Neighbor-building time comparison") + p.add_argument( + "--sizes", + type=parse_sizes, + default=[128, 256, 512, 1024, 2048], + help="Comma-separated site counts to benchmark (default: 128,256,512,1024,2048)", + ) + p.add_argument( + "--dims", type=int, default=2, help="Lattice dimensionality (default: 2)" + ) + p.add_argument( + "--max-k", type=int, default=6, help="Max neighbor shells k (default: 6)" + ) + p.add_argument( + "--repeats", type=int, default=5, help="Repeats per measurement (default: 5)" + ) + p.add_argument("--seed", type=int, default=42, help="PRNG seed (default: 42)") + p.add_argument("--csv", type=str, default="", help="Optional CSV output path") + args = p.parse_args(list(argv) if argv is not None else None) + + print("=" * 74) print( - f"{'Sites (N)':>10} | {'KDTree Time (s)':>18} | {'Baseline Time (s)':>20} | {'Speedup':>10}" + f"Benchmark CustomizeLattice neighbor-building | dims={args.dims} max_k={args.max_k} repeats={args.repeats}" ) - print("-" * 75) + print("=" * 74) + print(f"{'N':>8} | {'KDTree(s)':>12} | {'DistMatrix(s)':>14} | {'Speedup':>7}") + print("-" * 74) - for n_sites in site_counts: - # Prepare the setup code for timeit. - # This code generates a random lattice and is executed before timing begins. - # We use a fixed seed to ensure the coordinates are the same for both tests. - setup_code = f""" -import numpy as np -from tensorcircuit.templates.lattice import CustomizeLattice + rows: List[Tuple[int, float, float]] = [] + for n in args.sizes: + t_kdtree, t_matrix = run_once(n, args.dims, args.max_k, args.repeats, args.seed) + rows.append((n, t_kdtree, t_matrix)) + print(format_row(n, t_kdtree, t_matrix)) -np.random.seed(42) -coords = np.random.rand({n_sites}, 2) -ids = list(range({n_sites})) -lat = CustomizeLattice(dimensionality=2, identifiers=ids, coordinates=coords) -""" - # Define the Python statements to be timed. - stmt_kdtree = f"lat._build_neighbors(max_k={max_k})" - stmt_baseline = f"lat._build_neighbors_by_distance_matrix(max_k={max_k})" - - try: - # Execute the timing. timeit returns the total time for all runs. - time_kdtree = ( - timeit.timeit(stmt=stmt_kdtree, setup=setup_code, number=number_of_runs) - / number_of_runs - ) - time_baseline = ( - timeit.timeit( - stmt=stmt_baseline, setup=setup_code, number=number_of_runs - ) - / number_of_runs - ) - - # Calculate and store results, handling potential division by zero. - speedup = time_baseline / time_kdtree if time_kdtree > 0 else float("inf") - results.append( - { - "n_sites": n_sites, - "time_kdtree": time_kdtree, - "time_baseline": time_baseline, - "speedup": speedup, - } - ) - print( - f"{n_sites:>10} | {time_kdtree:>18.6f} | {time_baseline:>20.6f} | {speedup:>9.2f}x" - ) - - except Exception as e: - print(f"An error occurred at N={n_sites}: {e}") - break - - print("-" * 75) - print("Benchmark complete.") + if args.csv: + with open(args.csv, "w", newline="", encoding="utf-8") as f: + w = csv.writer(f) + w.writerow(["N", "time_kdtree_s", "time_distance_matrix_s", "speedup"]) + for n, t_kdtree, t_matrix in rows: + speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") + w.writerow([n, f"{t_kdtree:.6f}", f"{t_matrix:.6f}", f"{speedup:.2f}"]) + + print("-" * 74) + print(f"Saved CSV to: {args.csv}") + + print("-" * 74) + print("Done.") + return 0 if __name__ == "__main__": - run_benchmark() + raise SystemExit(main()) diff --git a/examples/lattice_neighbor_time_compare.py b/examples/lattice_neighbor_time_compare.py deleted file mode 100644 index dada10a0..00000000 --- a/examples/lattice_neighbor_time_compare.py +++ /dev/null @@ -1,111 +0,0 @@ -""" -Benchmark: Compare neighbor-building time between KDTree and distance-matrix -methods in CustomizeLattice for varying lattice sizes. -""" - -import argparse -import csv -import time -from typing import Iterable, List, Tuple, Optional -import logging - -import numpy as np - -# Silence verbose infos from the library during benchmarks - -logging.basicConfig(level=logging.WARNING) - -from tensorcircuit.templates.lattice import CustomizeLattice - - -def run_once( - n: int, d: int, max_k: int, repeats: int, seed: int -) -> Tuple[float, float]: - """Run one size point and return (time_kdtree, time_matrix).""" - rng = np.random.default_rng(seed) - ids = list(range(n)) - - # Collect times for each repeat with different random coordinates - kdtree_times: List[float] = [] - matrix_times: List[float] = [] - - for _ in range(repeats): - # Generate different coordinates for each repeat - coords = rng.random((n, d), dtype=float) - lat = CustomizeLattice(dimensionality=d, identifiers=ids, coordinates=coords) - - # KDTree path - single measurement - t0 = time.perf_counter() - lat._build_neighbors(max_k=max_k, use_kdtree=True) - kdtree_times.append(time.perf_counter() - t0) - - # Distance-matrix path - single measurement - t0 = time.perf_counter() - lat._build_neighbors(max_k=max_k, use_kdtree=False) - matrix_times.append(time.perf_counter() - t0) - - return float(np.mean(kdtree_times)), float(np.mean(matrix_times)) - - -def parse_sizes(s: str) -> List[int]: - return [int(x) for x in s.split(",") if x.strip()] - - -def format_row(n: int, t_kdtree: float, t_matrix: float) -> str: - speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") - return f"{n:>8} | {t_kdtree:>12.6f} | {t_matrix:>14.6f} | {speedup:>7.2f}x" - - -def main(argv: Optional[Iterable[str]] = None) -> int: - p = argparse.ArgumentParser(description="Neighbor-building time comparison") - p.add_argument( - "--sizes", - type=parse_sizes, - default=[128, 256, 512, 1024, 2048], - help="Comma-separated site counts to benchmark (default: 128,256,512,1024,2048)", - ) - p.add_argument( - "--dims", type=int, default=2, help="Lattice dimensionality (default: 2)" - ) - p.add_argument( - "--max-k", type=int, default=6, help="Max neighbor shells k (default: 6)" - ) - p.add_argument( - "--repeats", type=int, default=5, help="Repeats per measurement (default: 5)" - ) - p.add_argument("--seed", type=int, default=42, help="PRNG seed (default: 42)") - p.add_argument("--csv", type=str, default="", help="Optional CSV output path") - args = p.parse_args(list(argv) if argv is not None else None) - - print("=" * 74) - print( - f"Benchmark CustomizeLattice neighbor-building | dims={args.dims} max_k={args.max_k} repeats={args.repeats}" - ) - print("=" * 74) - print(f"{'N':>8} | {'KDTree(s)':>12} | {'DistMatrix(s)':>14} | {'Speedup':>7}") - print("-" * 74) - - rows: List[Tuple[int, float, float]] = [] - for n in args.sizes: - t_kdtree, t_matrix = run_once(n, args.dims, args.max_k, args.repeats, args.seed) - rows.append((n, t_kdtree, t_matrix)) - print(format_row(n, t_kdtree, t_matrix)) - - if args.csv: - with open(args.csv, "w", newline="", encoding="utf-8") as f: - w = csv.writer(f) - w.writerow(["N", "time_kdtree_s", "time_distance_matrix_s", "speedup"]) - for n, t_kdtree, t_matrix in rows: - speedup = (t_matrix / t_kdtree) if t_kdtree > 0 else float("inf") - w.writerow([n, f"{t_kdtree:.6f}", f"{t_matrix:.6f}", f"{speedup:.2f}"]) - - print("-" * 74) - print(f"Saved CSV to: {args.csv}") - - print("-" * 74) - print("Done.") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/examples/lennard_jones_optimization.py b/examples/lennard_jones_optimization.py index f3c4c825..f80c5e2d 100644 --- a/examples/lennard_jones_optimization.py +++ b/examples/lennard_jones_optimization.py @@ -5,8 +5,10 @@ to optimize crystal structure. It finds the equilibrium lattice constant that minimizes the total Lennard-Jones potential energy of a 2D square lattice. -The optimization showcases the key Task 3 capability: making lattice parameters -differentiable for variational material design. +This example showcases a key capability of the differentiable lattice system: +making geometric parameters (like lattice constants) fully differentiable and +optimizable using automatic differentiation. This enables variational material design +where crystal structures can be optimized to minimize physical energy functions. """ import optax @@ -54,7 +56,7 @@ def calculate_potential(log_a, epsilon=0.5, sigma=1.0): optimizer = optax.adam(learning_rate=0.01) -log_a = K.convert_to_tensor(K.log(K.convert_to_tensor(1.1))) +log_a = K.convert_to_tensor(K.log(K.convert_to_tensor(2.0))) opt_state = optimizer.init(log_a) diff --git a/tensorcircuit/backends/abstract_backend.py b/tensorcircuit/backends/abstract_backend.py index 286a21f9..83720805 100644 --- a/tensorcircuit/backends/abstract_backend.py +++ b/tensorcircuit/backends/abstract_backend.py @@ -629,7 +629,7 @@ def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor: def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any: """ - Return coordinate matrices from coordinate vectors. + Return coordinate matrices from coordinate vectors. :param args: coordinate vectors :type args: Any diff --git a/tensorcircuit/templates/lattice.py b/tensorcircuit/templates/lattice.py index 49508487..1fa8dc72 100644 --- a/tensorcircuit/templates/lattice.py +++ b/tensorcircuit/templates/lattice.py @@ -18,11 +18,11 @@ Set, ) -logger = logging.getLogger(__name__) import itertools import math import numpy as np from scipy.spatial import KDTree + from .. import backend @@ -42,6 +42,8 @@ import matplotlib.axes from mpl_toolkits.mplot3d import Axes3D +logger = logging.getLogger(__name__) + Tensor = Any SiteIndex = int SiteIdentifier = Hashable @@ -835,7 +837,7 @@ def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None: :param max_k: The maximum order of neighbors to compute (e.g., k=1 for nearest neighbors, k=2 for next-nearest, etc.). Defaults to 2. :type max_k: int, optional - :param \**kwargs: Additional keyword arguments. May include: + :param kwargs: Additional keyword arguments. May include: - ``tol`` (float): The numerical tolerance used to determine if two distances are equal when identifying shells. Defaults to 1e-6. """ @@ -1486,7 +1488,7 @@ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None: # For small lattices or cases with potential duplicate coordinates, # fall back to distance matrix method for robustness - if self.num_sites < 1000: + if self.num_sites < 200: logger.info( "Small lattice detected, falling back to distance matrix method for robustness" ) diff --git a/tests/test_lattice.py b/tests/test_lattice.py index fe730dd7..612c9f80 100644 --- a/tests/test_lattice.py +++ b/tests/test_lattice.py @@ -1,5 +1,4 @@ from unittest.mock import patch -import time import logging import matplotlib @@ -2003,44 +2002,6 @@ def test_distance_matrix_caching_is_effective(self, backend): spy_compute.assert_called_once() -@pytest.mark.slow -class TestPerformance: - def test_pbc_implementation_is_not_significantly_slower_than_obc(self): - """ - A performance regression test. - It ensures that the specialized implementation for fully periodic - lattices (pbc=True) is not substantially slower than the general - implementation used for open boundaries (pbc=False). - This test will FAIL with the current code, exposing the performance bug. - """ - # Arrange: Use a large-enough lattice to make performance differences apparent - size = (40, 40) - k = 1 - - # Act 1: Measure the execution time of the general (OBC) implementation - start_time_obc = time.time() - _ = SquareLattice(size=size, pbc=False, precompute_neighbors=k) - duration_obc = time.time() - start_time_obc - - # Act 2: Measure the execution time of the specialized (PBC) implementation - start_time_pbc = time.time() - _ = SquareLattice(size=size, pbc=True, precompute_neighbors=k) - duration_pbc = time.time() - start_time_pbc - - print( - f"\n[Performance] OBC ({size}): {duration_obc:.4f}s | PBC ({size}): {duration_pbc:.4f}s" - ) - - # Assert: The PBC implementation should not be drastically slower. - # We allow it to be up to 3 times slower to account for minor overheads, - # but this will catch the current 10x+ regression. - # THIS ASSERTION WILL FAIL with the current buggy code. - assert duration_pbc < duration_obc * 5, ( - "The specialized PBC implementation is significantly slower " - "than the general-purpose implementation." - ) - - def _validate_layers(bonds, layers): """ A helper function to scientifically validate the output of get_compatible_layers. From 494a99b6f39db3e4ff8427b561410c84b387bb39 Mon Sep 17 00:00:00 2001 From: Stellogic Date: Sat, 16 Aug 2025 10:23:13 +0800 Subject: [PATCH 16/16] fix black --- examples/lennard_jones_optimization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/lennard_jones_optimization.py b/examples/lennard_jones_optimization.py index f80c5e2d..6d8349a1 100644 --- a/examples/lennard_jones_optimization.py +++ b/examples/lennard_jones_optimization.py @@ -5,8 +5,8 @@ to optimize crystal structure. It finds the equilibrium lattice constant that minimizes the total Lennard-Jones potential energy of a 2D square lattice. -This example showcases a key capability of the differentiable lattice system: -making geometric parameters (like lattice constants) fully differentiable and +This example showcases a key capability of the differentiable lattice system: +making geometric parameters (like lattice constants) fully differentiable and optimizable using automatic differentiation. This enables variational material design where crystal structures can be optimized to minimize physical energy functions. """