Skip to content

Commit 97e513c

Browse files
More robust tests (#12)
* More robust tests for integrators, lj for argon without ASE / Pymatgen * Add invariant check for nose-hoover * NPT test
1 parent 244eb80 commit 97e513c

File tree

4 files changed

+210
-92
lines changed

4 files changed

+210
-92
lines changed

tests/conftest.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,29 +56,70 @@ def fe_fcc_state(device: torch.device) -> Any:
5656
return atoms_to_state(fe_atoms, device, torch.float64)
5757

5858

59-
@pytest.fixture
60-
def ar_fcc_state(device: torch.device) -> Any:
61-
ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([4, 4, 4])
62-
return atoms_to_state(ar_atoms, device, torch.float64)
63-
64-
6559
@pytest.fixture
6660
def si_double_base_state(si_atoms: Atoms, device: torch.device) -> Any:
6761
"""Create a basic state from si_structure."""
6862
return atoms_to_state([si_atoms, si_atoms], device, torch.float64)
6963

7064

65+
@pytest.fixture
66+
def ar_fcc_base_state(device: torch.device) -> BaseState:
67+
"""Create a face-centered cubic (FCC) Argon structure."""
68+
# 5.26 Å is a typical lattice constant for Ar
69+
a = 5.26 # Lattice constant
70+
N = 4 # Supercell size
71+
n_atoms = 4 * N * N * N # Total number of atoms (4 atoms per unit cell)
72+
dtype = torch.float64
73+
74+
# Create positions tensor directly
75+
positions = torch.zeros((n_atoms, 3), device=device, dtype=dtype)
76+
idx = 0
77+
for i in range(N):
78+
for j in range(N):
79+
for k in range(N):
80+
# Add base FCC positions with offset
81+
positions[idx] = torch.tensor([i, j, k], device=device, dtype=dtype) * a
82+
positions[idx + 1] = (
83+
torch.tensor([i, j + 0.5, k + 0.5], device=device, dtype=dtype) * a
84+
)
85+
positions[idx + 2] = (
86+
torch.tensor([i + 0.5, j, k + 0.5], device=device, dtype=dtype) * a
87+
)
88+
positions[idx + 3] = (
89+
torch.tensor([i + 0.5, j + 0.5, k], device=device, dtype=dtype) * a
90+
)
91+
idx += 4
92+
93+
# Create cell tensor with shape (1, 3, 3) to match atoms_to_state format
94+
cell = torch.eye(3, device=device, dtype=dtype).unsqueeze(0) * (N * a)
95+
96+
# Create batch indices
97+
batch = torch.zeros(n_atoms, device=device, dtype=torch.long)
98+
99+
return BaseState(
100+
positions=positions,
101+
masses=torch.full((n_atoms,), 39.95, device=device, dtype=dtype), # Ar mass
102+
cell=cell, # Cubic cell
103+
pbc=True,
104+
atomic_numbers=torch.full(
105+
(n_atoms,), 18, device=device, dtype=torch.long
106+
), # Ar atomic number
107+
batch=batch,
108+
)
109+
110+
71111
@pytest.fixture
72112
def unbatched_lj_calculator(device: torch.device) -> UnbatchedLennardJonesModel:
73-
"""Create a Lennard-Jones calculator with reasonable parameters for Si."""
113+
"""Create a Lennard-Jones calculator with reasonable parameters for Ar."""
74114
return UnbatchedLennardJonesModel(
75-
sigma=2.0, # Approximate for Si-Si interaction
76-
epsilon=0.1, # Small epsilon for stability during testing
115+
use_neighbor_list=True,
116+
sigma=3.405, # Approximate for Ar-Ar interaction
117+
epsilon=0.0104, # Small epsilon for stability during testing
77118
device=device,
78119
dtype=torch.float64,
79120
compute_force=True,
80121
compute_stress=True,
81-
cutoff=5.0,
122+
cutoff=2.5 * 3.405,
82123
)
83124

84125

tests/models/test_lennard_jones.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from torchsim.models.interface import validate_model_outputs
77
from torchsim.models.lennard_jones import (
8-
LennardJonesModel,
8+
UnbatchedLennardJonesModel,
99
lennard_jones_pair,
1010
lennard_jones_pair_force,
1111
)
@@ -132,7 +132,7 @@ def test_lennard_jones_force_energy_consistency() -> None:
132132

133133
@pytest.fixture
134134
def calculators(
135-
ar_fcc_state: BaseState,
135+
ar_fcc_base_state: BaseState,
136136
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
137137
"""Create both neighbor list and direct calculators with Argon parameters."""
138138
calc_params = {
@@ -145,10 +145,14 @@ def calculators(
145145
}
146146

147147
cutoff = 2.5 * 3.405 # Standard LJ cutoff * sigma
148-
calc_nl = LennardJonesModel(use_neighbor_list=True, cutoff=cutoff, **calc_params)
149-
calc_direct = LennardJonesModel(use_neighbor_list=False, cutoff=cutoff, **calc_params)
148+
calc_nl = UnbatchedLennardJonesModel(
149+
use_neighbor_list=True, cutoff=cutoff, **calc_params
150+
)
151+
calc_direct = UnbatchedLennardJonesModel(
152+
use_neighbor_list=False, cutoff=cutoff, **calc_params
153+
)
150154

151-
positions, cell = ar_fcc_state.positions, ar_fcc_state.cell
155+
positions, cell = ar_fcc_base_state.positions, ar_fcc_base_state.cell.squeeze(0)
152156
return calc_nl(positions, cell), calc_direct(positions, cell)
153157

154158

@@ -197,7 +201,7 @@ def test_stress_tensor_symmetry(
197201

198202

199203
def test_validate_model_outputs(
200-
lj_calculator: LennardJonesModel,
204+
lj_calculator: UnbatchedLennardJonesModel,
201205
device: torch.device,
202206
) -> None:
203207
"""Test that the model outputs are valid."""

0 commit comments

Comments
 (0)