feat(lattice): Make lattice geometries differentiable and backend-agn… #30

Merged: 17 commits, Aug 16, 2025
Changes from 2 commits

116 changes: 116 additions & 0 deletions examples/lennard_jones_optimization.py
@@ -0,0 +1,116 @@
import optax
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax
import tensorcircuit as tc


jax.config.update("jax_enable_x64", True)
K = tc.set_backend("jax")


def calculate_potential(log_a, base_distance_matrix, epsilon=0.5, sigma=1.0):
"""
Calculate the total Lennard-Jones potential energy for a given logarithm of the lattice constant (log_a).
"""
lattice_constant = jnp.exp(log_a)
d = base_distance_matrix * lattice_constant
d_safe = jnp.where(d > 1e-9, d, 1e-9)
Review comment (Member): use K throughout instead of mixing jax and K? (see the K-only sketch after this file's diff)


term12 = (sigma / d_safe) ** 12
term6 = (sigma / d_safe) ** 6
potential_matrix = 4 * epsilon * (term12 - term6)

num_sites = d.shape[0]
potential_matrix = potential_matrix * (
1 - K.eye(num_sites, dtype=potential_matrix.dtype)
)

potential_energy = K.sum(potential_matrix) / 2.0

return potential_energy


# Pre-calculate the base distance matrix (for lattice_constant=1.0)
size = (10, 10)
lat_base = tc.templates.lattice.SquareLattice(size, lattice_constant=1.0, pbc=True)
base_distance_matrix = lat_base.distance_matrix

# Create a lambda function to pass the base distance matrix to the potential function
potential_fun_for_grad = lambda log_a: calculate_potential(log_a, base_distance_matrix)
value_and_grad_fun = K.jit(K.value_and_grad(potential_fun_for_grad))

optimizer = optax.adam(learning_rate=0.01)

log_a = K.convert_to_tensor(jnp.log(1.1))

opt_state = optimizer.init(log_a)

history = {"a": [], "energy": []}

print("Starting optimization of lattice constant...")
for i in range(200):
energy, grad = value_and_grad_fun(log_a)

history["a"].append(jnp.exp(log_a))
history["energy"].append(energy)

if jnp.isnan(grad):
print(f"Gradient became NaN at iteration {i+1}. Stopping optimization.")
print(f"Current energy: {energy}, Current log_a: {log_a}")
break

updates, opt_state = optimizer.update(grad, opt_state)
log_a = optax.apply_updates(log_a, updates)

if (i + 1) % 20 == 0:
current_a = jnp.exp(log_a)
print(
f"Iteration {i+1}/200: Total Energy = {energy:.4f}, Lattice Constant = {current_a:.4f}"
)

final_a = jnp.exp(log_a)
final_energy = calculate_potential(K.convert_to_tensor(log_a), base_distance_matrix)

if not jnp.isnan(final_energy):
print("\nOptimization finished!")
print(f"Final optimized lattice constant: {final_a:.6f}")
print(f"Corresponding minimum total energy: {final_energy:.6f}")

# Vectorized calculation for the potential curve
a_vals = np.linspace(0.8, 1.5, 200)
log_a_vals = np.log(a_vals)

# Use vmap to create a vectorized version of the potential function
vmap_potential = jax.vmap(lambda la: calculate_potential(la, base_distance_matrix))
potential_curve = vmap_potential(K.convert_to_tensor(log_a_vals))

plt.figure(figsize=(10, 6))
plt.plot(a_vals, potential_curve, label="Lennard-Jones Potential", color="blue")
plt.scatter(
history["a"],
history["energy"],
color="red",
s=20,
zorder=5,
label="Optimization Steps",
)
plt.scatter(
final_a,
final_energy,
color="green",
s=100,
zorder=6,
marker="*",
label="Final Optimized Point",
)

plt.title("Lennard-Jones Potential Optimization")
plt.xlabel("Lattice Constant (a)")
plt.ylabel("Total Potential Energy")
plt.legend()
plt.grid(True)
plt.show()
else:
print("\nOptimization failed. Final energy is NaN.")
115 changes: 115 additions & 0 deletions tensorcircuit/backends/abstract_backend.py
@@ -596,6 +596,84 @@ def argsort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
"Backend '{}' has not implemented `argsort`.".format(self.name)
)

def sort(self: Any, a: Tensor, axis: int = -1) -> Tensor:
"""
Sort a tensor along the given axis.

:param a: Input tensor
:type a: Tensor
:param axis: Axis along which to sort, defaults to -1 (the last axis)
:type axis: int, optional
:return: Sorted copy of the input tensor
:rtype: Tensor
"""
raise NotImplementedError(
"Backend '{}' has not implemented `sort`.".format(self.name)
)

def all(self: Any, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
"""
Test whether all array elements along a given axis evaluate to True.

:param a: Input tensor
:type a: Tensor
:param axis: Axis or axes along which a logical AND reduction is performed,
defaults to None
:type axis: Optional[Sequence[int]], optional
:return: A boolean scalar or tensor resulting from the AND reduction
:rtype: Tensor
"""
raise NotImplementedError(
"Backend '{}' has not implemented `all`.".format(self.name)
)

def meshgrid(self: Any, *args: Any, **kwargs: Any) -> Any:
"""
Return coordinate matrices from coordinate vectors.

:param args: coordinate vectors
:type args: Any
:param kwargs: keyword arguments for meshgrid
:type kwargs: Any
:return: list of coordinate matrices
:rtype: Any
"""
raise NotImplementedError(
"Backend '{}' has not implemented `meshgrid`.".format(self.name)
)

def expand_dims(self: Any, a: Tensor, axis: int) -> Tensor:
"""
Expand the shape of a tensor.
Insert a new axis that will appear at the `axis` position in the expanded
tensor shape.

:param a: Input tensor
:type a: Tensor
:param axis: Position in the expanded axes where the new axis is placed
:type axis: int
:return: Output tensor with the number of dimensions increased by one.
:rtype: Tensor
"""
raise NotImplementedError(
"Backend '{}' has not implemented `expand_dims`.".format(self.name)
)

def power(self: Any, a: Tensor, b: Union[Tensor, float]) -> Tensor:
"""
First array elements raised to powers from second array, element-wise.

:param a: The bases
:type a: Tensor
:param b: The exponents
:type b: Union[Tensor, float]
:return: The bases in `a` raised to the powers in `b`.
:rtype: Tensor
"""
raise NotImplementedError(
"Backend '{}' has not implemented `power`.".format(self.name)
)

def unique_with_counts(self: Any, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
"""
Find the unique elements and their corresponding counts of the given tensor ``a``.
@@ -1404,6 +1482,43 @@ def cond(
"Backend '{}' has not implemented `cond`.".format(self.name)
)

def where(
self: Any,
condition: Tensor,
x: Optional[Tensor] = None,
y: Optional[Tensor] = None,
) -> Tensor:
"""
Return a tensor of elements selected from either x or y, depending on condition.
When x and y are both omitted, return the indices where condition is True instead.

:param condition: Where True, yield x, otherwise yield y.
:type condition: Tensor (bool)
:param x: Values from which to choose when condition is True, defaults to None
:type x: Optional[Tensor], optional
:param y: Values from which to choose when condition is False, defaults to None
:type y: Optional[Tensor], optional
:return: A tensor with elements from x where condition is True and from y otherwise,
or the indices of True entries when x and y are omitted.
:rtype: Tensor
"""
raise NotImplementedError(
"Backend '{}' has not implemented `where`.".format(self.name)
)

def equal(self: Any, x1: Tensor, x2: Tensor) -> Tensor:
"""
Return the truth value of (x1 == x2) element-wise.

:param x1: Input tensor.
:type x1: Tensor
:param x2: Input tensor.
:type x2: Tensor
:return: Output tensor, element-wise truth value of (x1 == x2).
:rtype: Tensor (bool)
"""
raise NotImplementedError(
"Backend '{}' has not implemented `equal`.".format(self.name)
)

def switch(
self: Any, index: Tensor, branches: Sequence[Callable[[], Tensor]]
) -> Tensor:
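For orientation, a short usage sketch of the backend-agnostic primitives declared above, exercised through the active backend. It assumes comparison operators on backend tensors return boolean tensors and that scalars broadcast as in NumPy/JAX; it is illustrative only.

import tensorcircuit as tc

K = tc.set_backend("numpy")  # "jax" exposes the same primitives via the implementations below

a = K.convert_to_tensor([[3.0, 1.0], [2.0, 4.0]])
print(K.sort(a, axis=-1))            # row-wise sort
print(K.all(K.equal(a, a)))          # elementwise equality, then AND-reduce over all entries
print(K.power(a, 2.0))               # elementwise power
print(K.expand_dims(a, 0).shape)     # (1, 2, 2)
print(K.where(a > 2.0, a, 0.0 * a))  # keep entries > 2, zero out the rest
xg, yg = K.meshgrid(
    K.convert_to_tensor([0.0, 1.0]),
    K.convert_to_tensor([0.0, 1.0, 2.0]),
    indexing="ij",
)
print(xg.shape, yg.shape)            # (2, 3) (2, 3)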
35 changes: 34 additions & 1 deletion tensorcircuit/backends/jax_backend.py
@@ -243,8 +243,10 @@ def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor:
def copy(self, tensor: Tensor) -> Tensor:
return jnp.array(tensor, copy=True)

def convert_to_tensor(self, tensor: Tensor) -> Tensor:
def convert_to_tensor(self, tensor: Tensor, **kwargs: Any) -> Tensor:
result = jnp.asarray(tensor)
if "dtype" in kwargs and kwargs["dtype"] is not None:
result = self.cast(result, kwargs["dtype"])
return result

def abs(self, a: Tensor) -> Tensor:
@@ -353,6 +355,9 @@ def expm(self, a: Tensor) -> Tensor:
# currently expm in jax doesn't support AD, it will raise an AssertError,
# see https://github.com/google/jax/issues/2645

def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor:
return jnp.power(a, b)

def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
return jnp.stack(a, axis=axis)

@@ -390,6 +395,9 @@ def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
return jnp.argsort(a, axis=axis)

def sort(self, a: Tensor, axis: int = -1) -> Tensor:
return jnp.sort(a, axis=axis)

def unique_with_counts( # type: ignore
self, a: Tensor, *, size: Optional[int] = None, fill_value: Optional[int] = None
) -> Tuple[Tensor, Tensor]:
@@ -410,6 +418,12 @@ def onehot(self, a: Tensor, num: int) -> Tensor:
def cumsum(self, a: Tensor, axis: Optional[int] = None) -> Tensor:
return jnp.cumsum(a, axis)

def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
return jnp.all(a, axis=axis)

def equal(self, x1: Tensor, x2: Tensor) -> Tensor:
return jnp.equal(x1, x2)

def is_tensor(self, a: Any) -> bool:
if not isinstance(a, jnp.ndarray):
return False
@@ -812,4 +826,23 @@ def wrapper(

vvag = vectorized_value_and_grad

def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
"""
Backend-agnostic meshgrid function.
"""
return jnp.meshgrid(*args, **kwargs)

optimizer = optax_optimizer

def expand_dims(self, a: Tensor, axis: int) -> Tensor:
return jnp.expand_dims(a, axis)

def where(
self,
condition: Tensor,
x: Optional[Tensor] = None,
y: Optional[Tensor] = None,
) -> Tensor:
if x is None and y is None:
return jnp.where(condition)
return jnp.where(condition, x, y)
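
A quick sketch of the dtype pass-through added to the JAX backend's convert_to_tensor; it assumes cast accepts a dtype string such as "float32", and the printed dtypes are only indicative.

import tensorcircuit as tc

K = tc.set_backend("jax")
t = K.convert_to_tensor([1, 2, 3], dtype="float32")  # cast is applied because dtype is given
print(t.dtype)  # float32
u = K.convert_to_tensor([1, 2, 3])  # no dtype keyword: plain jnp.asarray behaviour
print(u.dtype)  # int32 (or int64 when jax_enable_x64 is on)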
34 changes: 33 additions & 1 deletion tensorcircuit/backends/numpy_backend.py
@@ -35,7 +35,7 @@ def _sum_numpy(
# see https://github.com/google/TensorNetwork/issues/952


def _convert_to_tensor_numpy(self: Any, a: Tensor) -> Tensor:
def _convert_to_tensor_numpy(self: Any, a: Tensor, **kwargs: Any) -> Tensor:
if not isinstance(a, np.ndarray) and not np.isscalar(a):
a = np.array(a)
a = np.asarray(a)
@@ -80,6 +80,9 @@ def copy(self, a: Tensor) -> Tensor:
def expm(self, a: Tensor) -> Tensor:
return expm(a)

def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor:
return np.power(a, b)

def abs(self, a: Tensor) -> Tensor:
return np.abs(a)

@@ -132,6 +135,12 @@ def eigvalsh(self, a: Tensor) -> Tensor:
def kron(self, a: Tensor, b: Tensor) -> Tensor:
return np.kron(a, b)

def meshgrid(self, *args: Any, **kwargs: Any) -> Any:
"""
Backend-agnostic meshgrid function.
"""
return np.meshgrid(*args, **kwargs)

def dtype(self, a: Tensor) -> str:
return a.dtype.__str__() # type: ignore

@@ -151,6 +160,9 @@ def i(self, dtype: Any = None) -> Tensor:
dtype = getattr(np, dtype)
return np.array(1j, dtype=dtype)

def expand_dims(self, a: Tensor, axis: int) -> Tensor:
return np.expand_dims(a, axis)

def stack(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
return np.stack(a, axis=axis)

@@ -173,6 +185,9 @@ def std(
) -> Tensor:
return np.std(a, axis=axis, keepdims=keepdims)

def all(self, a: Tensor, axis: Optional[Sequence[int]] = None) -> Tensor:
return np.all(a, axis=axis)

def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
return np.unique(a, return_counts=True) # type: ignore

@@ -188,6 +203,9 @@ def argmax(self, a: Tensor, axis: int = 0) -> Tensor:
def argmin(self, a: Tensor, axis: int = 0) -> Tensor:
return np.argmin(a, axis=axis)

def sort(self, a: Tensor, axis: int = -1) -> Tensor:
return np.sort(a, axis=axis)

def sigmoid(self, a: Tensor) -> Tensor:
return expit(a)

@@ -345,6 +363,20 @@ def to_dense(self, sp_a: Tensor) -> Tensor:
def is_sparse(self, a: Tensor) -> bool:
return issparse(a) # type: ignore

def where(
self,
condition: Tensor,
x: Optional[Tensor] = None,
y: Optional[Tensor] = None,
) -> Tensor:
if x is None and y is None:
return np.where(condition)
assert x is not None and y is not None
return np.where(condition, x, y)

def equal(self, x: Tensor, y: Tensor) -> Tensor:
return np.equal(x, y)

def cond(
self,
pred: bool,
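
Both backends implement the one-argument form of where (index lookup, as with np.nonzero) alongside the three-argument selection form; a minimal illustrative sketch with the NumPy backend:

import numpy as np
import tensorcircuit as tc

K = tc.set_backend("numpy")
mask = K.convert_to_tensor(np.array([True, False, True]))
print(K.where(mask))  # (array([0, 2]),) -- index tuple, as with np.nonzero
print(K.where(mask, np.array([1, 2, 3]), np.array([0, 0, 0])))  # [1 0 3]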