Skip to content

Commit d0ddbdf

Browse files
Use ase bulk for ar base state
1 parent 27f5c23 commit d0ddbdf

File tree

3 files changed

+19
-61
lines changed

3 files changed

+19
-61
lines changed

tests/conftest.py

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -65,47 +65,9 @@ def si_double_base_state(si_atoms: Atoms, device: torch.device) -> Any:
6565
@pytest.fixture
6666
def ar_base_state(device: torch.device) -> BaseState:
6767
"""Create a face-centered cubic (FCC) Argon structure."""
68-
# 5.26 Å is a typical lattice constant for Ar
69-
a = 5.26 # Lattice constant
70-
N = 4 # Supercell size
71-
n_atoms = 4 * N * N * N # Total number of atoms (4 atoms per unit cell)
72-
dtype = torch.float64
73-
74-
# Create positions tensor directly
75-
positions = torch.zeros((n_atoms, 3), device=device, dtype=dtype)
76-
idx = 0
77-
for i in range(N):
78-
for j in range(N):
79-
for k in range(N):
80-
# Add base FCC positions with offset
81-
positions[idx] = torch.tensor([i, j, k], device=device, dtype=dtype) * a
82-
positions[idx + 1] = (
83-
torch.tensor([i, j + 0.5, k + 0.5], device=device, dtype=dtype) * a
84-
)
85-
positions[idx + 2] = (
86-
torch.tensor([i + 0.5, j, k + 0.5], device=device, dtype=dtype) * a
87-
)
88-
positions[idx + 3] = (
89-
torch.tensor([i + 0.5, j + 0.5, k], device=device, dtype=dtype) * a
90-
)
91-
idx += 4
92-
93-
# Create cell tensor with shape (1, 3, 3) to match atoms_to_state format
94-
cell = torch.eye(3, device=device, dtype=dtype).unsqueeze(0) * (N * a)
95-
96-
# Create batch indices
97-
batch = torch.zeros(n_atoms, device=device, dtype=torch.long)
98-
99-
return BaseState(
100-
positions=positions,
101-
masses=torch.full((n_atoms,), 39.95, device=device, dtype=dtype), # Ar mass
102-
cell=cell, # Cubic cell
103-
pbc=True,
104-
atomic_numbers=torch.full(
105-
(n_atoms,), 18, device=device, dtype=torch.long
106-
), # Ar atomic number
107-
batch=batch,
108-
)
68+
# Create FCC Ar using ASE, with 4x4x4 supercell
69+
ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([2, 2, 2])
70+
return atoms_to_state(ar_atoms, device, torch.float64)
10971

11072

11173
@pytest.fixture

tests/models/test_lennard_jones.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22

33
import pytest
44
import torch
5+
from ase.build import bulk
56

67
from torchsim.models.interface import validate_model_outputs
78
from torchsim.models.lennard_jones import (
89
UnbatchedLennardJonesModel,
910
lennard_jones_pair,
1011
lennard_jones_pair_force,
1112
)
13+
from torchsim.runners import atoms_to_state
1214
from torchsim.state import BaseState
1315

1416

@@ -130,9 +132,21 @@ def test_lennard_jones_force_energy_consistency() -> None:
130132
assert torch.allclose(force_direct, force_from_grad, rtol=1e-4, atol=1e-4)
131133

132134

135+
# NOTE: This is a large system to test the neighbor list and direct calculation
136+
# are consistent. Direct calculation uses minimal image convention, which
137+
# is not used in the neighbor list calculation. So to get correct results,
138+
# we need a system that is large enough (2*cutoff).
139+
@pytest.fixture
140+
def ar_base_state_large(device: torch.device) -> BaseState:
141+
"""Create a face-centered cubic (FCC) Argon structure."""
142+
# Create FCC Ar using ASE, with 4x4x4 supercell
143+
ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([4, 4, 4])
144+
return atoms_to_state(ar_atoms, device, torch.float64)
145+
146+
133147
@pytest.fixture
134148
def calculators(
135-
ar_base_state: BaseState,
149+
ar_base_state_large: BaseState,
136150
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
137151
"""Create both neighbor list and direct calculators with Argon parameters."""
138152
calc_params = {
@@ -152,7 +166,7 @@ def calculators(
152166
use_neighbor_list=False, cutoff=cutoff, **calc_params
153167
)
154168

155-
positions, cell = ar_base_state.positions, ar_base_state.cell.squeeze(0)
169+
positions, cell = ar_base_state_large.positions, ar_base_state_large.cell.squeeze(0)
156170
return calc_nl(positions, cell), calc_direct(positions, cell)
157171

158172

tests/test_integrators.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Any
22

3-
import pytest
43
import torch
54

65
from torchsim.integrators import batched_initialize_momenta, nve, nvt_langevin
@@ -11,23 +10,6 @@
1110
from torchsim.units import MetalUnits
1211

1312

14-
@pytest.fixture
15-
def ar_double_base_state(ar_base_state: BaseState) -> BaseState:
16-
"""Create a basic state from ar_fcc_base_state."""
17-
batch = torch.repeat_interleave(torch.arange(2), ar_base_state.positions.shape[0])
18-
19-
return BaseState(
20-
positions=torch.cat([ar_base_state.positions, ar_base_state.positions], dim=0),
21-
cell=torch.cat([ar_base_state.cell, ar_base_state.cell], dim=0),
22-
masses=torch.cat([ar_base_state.masses, ar_base_state.masses], dim=0),
23-
atomic_numbers=torch.cat(
24-
[ar_base_state.atomic_numbers, ar_base_state.atomic_numbers], dim=0
25-
),
26-
batch=batch,
27-
pbc=ar_base_state.pbc,
28-
)
29-
30-
3113
def batched_initialize_momenta_loop(
3214
positions: torch.Tensor, # shape: (n_batches, n_atoms_per_batch, 3)
3315
masses: torch.Tensor, # shape: (n_batches, n_atoms_per_batch)

0 commit comments

Comments
 (0)