
feat(lattice): Make lattice geometries differentiable and backend-agn… #30

Merged
merged 17 commits into from
Aug 16, 2025


Stellogic
Contributor

This commit introduces a major refactoring of the Lattice module to support automatic differentiation and backend-agnostic operations. The core motivation is to enable the optimization of lattice geometric parameters, such as the lattice constant, by minimizing physical observables like energy.

This is achieved by removing the direct dependency on NumPy for all geometric calculations and replacing them with operations from the abstract tc.backend.

Key Features & Changes:

  • Differentiable Geometry: All internal calculations in the Lattice classes, including distance matrices and neighbor finding, are now based on differentiable backend tensor operations. This allows for gradient-based optimization of lattice parameters.

  • Backend Abstraction: The Lattice module is now decoupled from any specific backend (JAX, TensorFlow, PyTorch), ensuring consistent behavior across all supported frameworks.

  • All-to-All Interactions: Added get_all_pairs to Lattice and an interaction_scope parameter to heisenberg_hamiltonian to support models requiring interactions beyond nearest-neighbors, such as Rydberg or Lennard-Jones potentials.

  • New Optimization Example: A new example, lennard_jones_optimization.py, demonstrates how to find the equilibrium lattice constant of a crystal by minimizing the Lennard-Jones potential.
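The idea behind the optimization example can be sketched without any backend at all. The block below is not the code from lennard_jones_optimization.py: it uses only the Python standard library, a hand-derived gradient in place of backend autodiff, and plain gradient descent instead of an optimizer such as Adam. All function names, the 4x4 open-boundary lattice, and the step size are assumptions chosen for illustration. It enumerates all site pairs, writes the Lennard-Jones energy as a function of log(a), and checks the result against the closed-form minimum.

```python
import itertools
import math


def unit_pair_distances(nx, ny):
    """All-pairs distances of an nx-by-ny open square lattice with a = 1."""
    sites = [(x, y) for x in range(nx) for y in range(ny)]
    return [math.dist(p, q) for p, q in itertools.combinations(sites, 2)]


def lj_energy(log_a, dists, eps=1.0, sigma=1.0):
    """Total Lennard-Jones energy; every pair distance scales linearly with a."""
    a = math.exp(log_a)
    return sum(
        4.0 * eps * ((sigma / (a * d)) ** 12 - (sigma / (a * d)) ** 6)
        for d in dists
    )


def lj_grad(log_a, dists, eps=1.0, sigma=1.0):
    """Hand-derived dE/d(log a); an autodiff backend would supply this instead."""
    a = math.exp(log_a)
    return sum(
        4.0 * eps * (-12.0 * (sigma / (a * d)) ** 12 + 6.0 * (sigma / (a * d)) ** 6)
        for d in dists
    )


def optimize_lattice_constant(dists, a0=2.0, lr=2e-4, steps=2000):
    """Plain gradient descent on log(a), which keeps a positive.

    The step size is an assumption tuned for the small 4x4 example below.
    """
    log_a = math.log(a0)
    for _ in range(steps):
        log_a -= lr * lj_grad(log_a, dists)
    return math.exp(log_a)


def closed_form_lattice_constant(dists, sigma=1.0):
    """Stationary point of E(a): a* = sigma * (2 * S12 / S6) ** (1/6)."""
    s12 = sum(d ** -12 for d in dists)
    s6 = sum(d ** -6 for d in dists)
    return sigma * (2.0 * s12 / s6) ** (1.0 / 6.0)


dists = unit_pair_distances(4, 4)
a_opt = optimize_lattice_constant(dists)
a_ref = closed_form_lattice_constant(dists)
print(f"gradient descent: a = {a_opt:.6f}, closed form: a = {a_ref:.6f}")
```

Parameterizing by log(a) rather than a mirrors the log_a variable that appears later in this review; it guarantees a positive lattice constant without a constraint.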

Backend API Enhancements:

  • The ExtendedBackend interface has been expanded with new abstract methods (sort, meshgrid, where, equal, etc.) to support the required geometric computations.
  • All backends (jax, numpy, pytorch, tensorflow) have been updated to implement these new methods.

Testing:

  • Added a new TestDifferentiability suite to test_lattice.py to verify gradient calculations.
  • Enhanced tests for CustomizeLattice and data type consistency across backends.
  • Added a new test for all-to-all interactions in the Heisenberg model.

There are many debug logs left in for debugging. Once the code is ready, we can delete them if necessary.
I think this PR may finish my project.

Member

@refraction-ray refraction-ray left a comment

Whenever possible, please keep the comments, naming conventions, etc. exactly the same as the original to minimize the diff. Currently there are too many meaningless changes: debug logs, variable renames, etc. Keep the changes as minimal as possible; don't let AI rewrite the good parts with meaningless changes (like removing/adding comments, debug info, etc.).

@@ -107,7 +141,7 @@ def test_input_validation_mismatched_lengths(self):
         # the specified exception is raised within the 'with' block.
         with pytest.raises(
             ValueError,
-            match="Identifiers and coordinates lists must have the same length.",
+            match="The number of identifiers must match the number of coordinates.",
Member

What is the point of changing the text? It only increases the diff. Please keep the diff as minimal as possible.

@Stellogic
Contributor Author

Stellogic commented Aug 12, 2025

@refraction-ray
Thanks for your review!
I have modified the code based on the comments.

Removed debug logs: There were many test failures after refactoring the code, and I was unsure of the cause, so I added a lot of debug logs. Sorry for the problems. I will remember to remove them before creating a pull request next time.

Deleted unnecessary methods.

Modified the build_neighbors method in CustomizeLattice: When differentiation is needed, kdtree is not used (as I believe kdtree does not support differentiation). When differentiation is not required, kdtree is used to improve efficiency.

Added tests for new backend-independent methods.

Removed the type hints from the test code. I forgot to do it when I was refactoring.

@refraction-ray
Member

Overall very nice work! I have left some more comments. More importantly, test_lattice.py needs some rewriting to use the same backend test fixture as the tests in test_backends.py. All tests in test_lattice.py should be checked with different backends, unless a specific test is not valid in all backends.

@Stellogic
Contributor Author

Stellogic commented Aug 13, 2025

Thanks again for your patient guidance!

I have revised the content of test_lattice based on the review comments and resolved the new fails that appeared after the changes to lattice.py.

Some assert statements have been modified and replaced with ValueError.

The tensorflow_backend issue has been fixed.

The checks for NaN in the example files have been reduced.

Unnecessary or duplicate comments have been removed.


def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
"""
Return coordinate matrices from coordinate vectors.
Member

one more space for the docstring?

Member

still one more space on the above line

Member

still one more space on the above line!

@Stellogic
Contributor Author

Thanks for your review,
I have fixed the code according to the review.

Summary of Changes:

Added lattice_neighbor_time_compare to compare the efficiency of the k-d tree and matrix methods. The k-d tree begins to outperform the matrix method at approximately N=512, and its efficiency advantage increases with N. At N=2048, the k-d tree is nearly four times more efficient than the matrix method.
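A comparison of this kind can be sketched as follows. This is not the benchmark script from the PR: the function names, the open-boundary square lattice, and the cutoff are assumptions made for illustration, and it requires NumPy and SciPy. It finds nearest-neighbor pairs once from the full distance matrix and once with a k-d tree radius query, then times both. No timings are quoted here since they depend on the machine.

```python
import time

import numpy as np
from scipy.spatial import KDTree
from scipy.spatial.distance import pdist, squareform


def square_lattice(size):
    """Coordinates of a size-by-size square lattice with unit spacing."""
    return np.array([(x, y) for x in range(size) for y in range(size)], dtype=float)


def nn_pairs_matrix(coords, cutoff):
    """Neighbor pairs (i < j) from the full N x N distance matrix: O(N^2) memory."""
    dist = squareform(pdist(coords))
    i, j = np.nonzero(np.triu(dist <= cutoff, k=1))
    return set(zip(i.tolist(), j.tolist()))


def nn_pairs_kdtree(coords, cutoff):
    """The same pairs (i < j) from a k-d tree radius query."""
    return KDTree(coords).query_pairs(r=cutoff)


def best_time(fn, *args, repeats=3):
    """Best wall-clock time of fn(*args) over a few repeats."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best


if __name__ == "__main__":
    cutoff = 1.0 + 1e-6  # nearest-neighbor distance plus a small tolerance
    for size in (8, 16, 32):
        coords = square_lattice(size)
        t_matrix = best_time(nn_pairs_matrix, coords, cutoff)
        t_tree = best_time(nn_pairs_kdtree, coords, cutoff)
        print(f"N={size * size:5d}  matrix={t_matrix:.5f}s  kdtree={t_tree:.5f}s")
```

The k-d tree path runs in SciPy on concrete coordinates, so it is not differentiable through a backend; that matches the earlier decision to use it only when gradients are not needed.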

Reduced unnecessary checks.

Implemented the _compute_distance_matrix method in the base class and removed unnecessary overrides.

Improved comments and code for better readability.

Ensured all numerical values share the same data type by consistently using backend tensor operations, which is expected to prevent errors caused by mixing Python types with backend tensors.

@Stellogic
Contributor Author

Thanks for your review!
Here are the changes:
Replaced numpy with a use of the stop_gradient method, falling back to the original tensor and making a copy if the method is not applicable.
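The fallback described above might look roughly like the sketch below. The helper name detach and both toy backend classes are hypothetical and exist only to make the example runnable; the actual code in lattice.py may be structured differently.

```python
import copy


def detach(backend, tensor):
    """Detach `tensor` from the autodiff graph when the backend supports it."""
    stop_gradient = getattr(backend, "stop_gradient", None)
    if callable(stop_gradient):
        try:
            return stop_gradient(tensor)
        except NotImplementedError:
            pass  # the backend declares the method but cannot apply it
    # Fallback: nothing to cut, so return an independent copy of the original.
    return copy.copy(tensor)


class ToyAutodiffBackend:
    """Hypothetical backend whose stop_gradient tags the value as detached."""

    def stop_gradient(self, tensor):
        return ("detached", tensor)


class ToyPlainBackend:
    """Hypothetical backend that has no stop_gradient method at all."""


print(detach(ToyAutodiffBackend(), [1.0, 2.0]))
print(detach(ToyPlainBackend(), [1.0, 2.0]))
```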

Confirmed that convert_to_tensor always has a dtype under normal conditions, so there is no need to throw an exception; float32 or float64 is now selected directly.

Removed unnecessary if-statements.

Member

@refraction-ray refraction-ray left a comment

The PR is very close to being merged. After addressing the following minor points, I think it can be merged.


optimizer = optax.adam(learning_rate=0.01)

log_a = K.convert_to_tensor(K.log(K.convert_to_tensor(1.1)))
Member

starting from 2.0

@@ -0,0 +1,111 @@
"""
Member

Please delete this file, and make sure the existing lattice neighbor example, lattice_neighbor_benchmark.py, is correct, i.e. that it compares the k-d tree against the baseline.

:param tol: The numerical tolerance for distance
comparisons. Defaults to 1e-6.
:type tol: float, optional
:param \**kwargs: Additional keyword arguments. May include:
Member

just kwargs, no \ **

prev_found = found_indices[i]
# For small lattices or cases with potential duplicate coordinates,
# fall back to distance matrix method for robustness
if self.num_sites < 1000:
Member

200 instead of 1000

from scipy.spatial import KDTree
from scipy.spatial.distance import pdist, squareform
Member

an empty line between external import and internal import

@@ -19,10 +19,11 @@
)

logger = logging.getLogger(__name__)
Member

move after the imports

spy_compute.assert_called_once()


@pytest.mark.slow
Member

gives warning for pytest

PytestUnknownMarkWarning: Unknown pytest.mark.slow - is this a typo?  You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
    @pytest.mark.slow

I think you can just delete the test performance function
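For reference, the alternative to deleting the test is to register the marker, as the warning suggests. A minimal sketch, assuming the project has or adds a conftest.py; the marker description text is made up for illustration:

```python
# conftest.py (sketch)
def pytest_configure(config):
    """Register the custom `slow` marker so pytest no longer warns about it."""
    config.addinivalue_line(
        "markers",
        "slow: long-running performance test (deselect with -m 'not slow')",
    )
```

With the marker registered, `pytest -m "not slow"` skips the marked tests. Deleting the performance test, as suggested here, avoids the extra configuration entirely.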

@refraction-ray
Member

still a space mismatch in meshgrid method in abstractbackend.py

to optimize crystal structure. It finds the equilibrium lattice constant that minimizes
the total Lennard-Jones potential energy of a 2D square lattice.

The optimization showcases the key Task 3 capability: making lattice parameters
Member

Users have no idea what Task 3 is; please rephrase.

@Stellogic
Contributor Author

Thanks for your review.
I have changed the code according to your review.

I have deleted lattice_neighbor_time_compare and copied its content to lattice_neighbor_benchmark.py.

I have deleted the performance test function.

I also made the other detailed changes under your guidance.

@refraction-ray
Member

LGTM, thanks for the nice contribution!

@refraction-ray refraction-ray merged commit 494a99b into tensorcircuit:master Aug 16, 2025
1 check passed
@refraction-ray
Member

Now, I believe the only remaining thing is to add a Jupyter tutorial here: https://github.com/tensorcircuit/tensorcircuit-ng/tree/master/docs/source/tutorials, to illustrate the basic usage and functionality of the newly added lattice module more systematically. You can include ingredients from the previous tests and examples. The final output of the ipynb will look like https://tensorcircuit-ng.readthedocs.io/en/latest/tutorials/qubo_problem.html, which makes it easier for users to learn about the new module.

@Stellogic
Contributor Author

Thanks!
I will learn how and then try to finish the Jupyter tutorial.
