Skip to content

Commit 30a4350

Browse files
committed
fixes to tests
1 parent 2279d6c commit 30a4350

File tree

12 files changed

+90
-63
lines changed

12 files changed

+90
-63
lines changed

examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,11 @@
148148

149149

150150
stress = model(state)["stress"]
151-
calc_kinetic_energy = calc_kinetic_energy(
151+
kinetic_energy = calc_kinetic_energy(
152152
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
153153
)
154154
volume = torch.linalg.det(state.cell)
155-
pressure = get_pressure(stress, calc_kinetic_energy, volume)
155+
pressure = get_pressure(stress, kinetic_energy, volume)
156156
pressure = pressure.item() / Units.pressure
157157
print(f"Final {pressure=:.4f}")
158158
print(stress * UnitConversion.eV_per_Ang3_to_GPa)

examples/scripts/4_High_level_api/4.1_high_level_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@
5454
prop_calculators = {
5555
10: {"potential_energy": lambda state: state.energy},
5656
20: {
57-
"kinetic_energy": lambda state: calc_kinetic_energy(state.momenta, state.masses)
57+
"kinetic_energy": lambda state: calc_kinetic_energy(
58+
momenta=state.momenta, masses=state.masses
59+
)
5860
},
5961
}
6062

examples/scripts/hi.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

examples/tutorials/high_level_tutorial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@
132132
10: {"potential_energy": lambda state: state.energy},
133133
20: {
134134
"kinetic_energy": lambda state: ts.calc_kinetic_energy(
135-
state.momenta, state.masses
135+
momenta=state.momenta, masses=state.masses
136136
)
137137
},
138138
}

tests/test_integrators.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ def test_npt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo
109109
state = update_fn(state=state)
110110

111111
# Calculate instantaneous temperature from kinetic energy
112-
temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx)
112+
temp = calc_kT(
113+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
114+
)
113115
energies.append(state.energy)
114116
temperatures.append(temp / MetalUnits.temperature)
115117

@@ -172,7 +174,9 @@ def test_npt_langevin_multi_kt(
172174
state = update_fn(state=state)
173175

174176
# Calculate instantaneous temperature from kinetic energy
175-
temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx)
177+
temp = calc_kT(
178+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
179+
)
176180
energies.append(state.energy)
177181
temperatures.append(temp / MetalUnits.temperature)
178182

@@ -213,7 +217,9 @@ def test_nvt_langevin(ar_double_sim_state: ts.SimState, lj_model: LennardJonesMo
213217
state = update_fn(state=state)
214218

215219
# Calculate instantaneous temperature from kinetic energy
216-
temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx)
220+
temp = calc_kT(
221+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
222+
)
217223
energies.append(state.energy)
218224
temperatures.append(temp / MetalUnits.temperature)
219225

@@ -273,7 +279,9 @@ def test_nvt_langevin_multi_kt(
273279
state = update_fn(state=state)
274280

275281
# Calculate instantaneous temperature from kinetic energy
276-
temp = calc_kT(state.momenta, state.masses, system_idx=state.system_idx)
282+
temp = calc_kT(
283+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
284+
)
277285
energies.append(state.energy)
278286
temperatures.append(temp / MetalUnits.temperature)
279287

tests/test_runners.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ def test_integrate_nve(
2323
filenames=traj_file,
2424
state_frequency=1,
2525
prop_calculators={
26-
1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)}
26+
1: {
27+
"ke": lambda state: calc_kinetic_energy(
28+
momenta=state.momenta, masses=state.masses
29+
)
30+
}
2731
},
2832
)
2933

@@ -56,7 +60,11 @@ def test_integrate_single_nvt(
5660
filenames=traj_file,
5761
state_frequency=1,
5862
prop_calculators={
59-
1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)}
63+
1: {
64+
"ke": lambda state: calc_kinetic_energy(
65+
momenta=state.momenta, masses=state.masses
66+
)
67+
}
6068
},
6169
)
6270

@@ -108,7 +116,11 @@ def test_integrate_double_nvt_with_reporter(
108116
filenames=trajectory_files,
109117
state_frequency=1,
110118
prop_calculators={
111-
1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)}
119+
1: {
120+
"ke": lambda state: calc_kinetic_energy(
121+
momenta=state.momenta, masses=state.masses
122+
)
123+
}
112124
},
113125
)
114126

@@ -155,7 +167,11 @@ def test_integrate_many_nvt(
155167
filenames=trajectory_files,
156168
state_frequency=1,
157169
prop_calculators={
158-
1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)}
170+
1: {
171+
"ke": lambda state: calc_kinetic_energy(
172+
momenta=state.momenta, masses=state.masses
173+
)
174+
}
159175
},
160176
)
161177

@@ -346,7 +362,11 @@ def test_batched_optimize_fire(
346362
filenames=trajectory_files,
347363
state_frequency=1,
348364
prop_calculators={
349-
1: {"ke": lambda state: calc_kinetic_energy(state.momenta, state.masses)}
365+
1: {
366+
"ke": lambda state: calc_kinetic_energy(
367+
velocities=state.velocities, masses=state.masses
368+
)
369+
}
350370
},
351371
)
352372

tests/test_state.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -530,14 +530,6 @@ def deform_grad_state(device: torch.device) -> DeformState:
530530
)
531531

532532

533-
def test_deform_grad_momenta(deform_grad_state: DeformState) -> None:
534-
"""Test momenta calculation in DeformGradMixin."""
535-
expected_momenta = deform_grad_state.velocities * deform_grad_state.masses.unsqueeze(
536-
-1
537-
)
538-
assert torch.allclose(deform_grad_state.momenta, expected_momenta)
539-
540-
541533
def test_deform_grad_reference_cell(deform_grad_state: DeformState) -> None:
542534
"""Test reference cell getter/setter in DeformGradMixin."""
543535
original_ref_cell = deform_grad_state.reference_cell.clone()

torch_sim/integrators/npt.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,7 +1222,9 @@ def compute_cell_force(
12221222
if system_mask.any():
12231223
system_momenta = momenta[system_mask]
12241224
system_masses = masses[system_mask]
1225-
KE_per_system[b] = calc_kinetic_energy(system_momenta, system_masses)
1225+
KE_per_system[b] = calc_kinetic_energy(
1226+
masses=system_masses, momenta=system_momenta
1227+
)
12261228

12271229
# Get stress tensor and compute trace per system
12281230
# Handle stress tensor with batch dimension
@@ -1431,7 +1433,7 @@ def npt_nose_hoover_init(
14311433
cell_mass = cell_mass.to(device=device, dtype=dtype)
14321434

14331435
# Calculate cell kinetic energy (using first system for initialization)
1434-
KE_cell = calc_kinetic_energy(cell_momentum[:1], cell_mass[:1])
1436+
KE_cell = calc_kinetic_energy(masses=cell_mass[:1], momenta=cell_momentum[:1])
14351437

14361438
# Ensure reference_cell has proper system dimensions
14371439
if state.cell.ndim == 2:
@@ -1486,7 +1488,9 @@ def npt_nose_hoover_init(
14861488
# Initialize thermostat
14871489
npt_state.momenta = momenta
14881490
KE = calc_kinetic_energy(
1489-
npt_state.momenta, npt_state.masses, system_idx=npt_state.system_idx
1491+
momenta=npt_state.momenta,
1492+
masses=npt_state.masses,
1493+
system_idx=npt_state.system_idx,
14901494
)
14911495
npt_state.thermostat = thermostat_fns.initialize(
14921496
npt_state.positions.numel(), KE, kT
@@ -1543,10 +1547,12 @@ def npt_nose_hoover_update(
15431547
)
15441548

15451549
# Update kinetic energies for thermostats
1546-
KE = calc_kinetic_energy(state.momenta, state.masses, system_idx=state.system_idx)
1550+
KE = calc_kinetic_energy(
1551+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
1552+
)
15471553
state.thermostat.kinetic_energy = KE
15481554

1549-
KE_cell = calc_kinetic_energy(state.cell_momentum, state.cell_mass)
1555+
KE_cell = calc_kinetic_energy(masses=state.cell_mass, momenta=state.cell_momentum)
15501556
state.barostat.kinetic_energy = KE_cell
15511557

15521558
# Second half step of thermostat chains
@@ -1598,7 +1604,7 @@ def npt_nose_hoover_invariant(
15981604

15991605
# Calculate kinetic energy of particles per system
16001606
e_kin_per_system = calc_kinetic_energy(
1601-
state.momenta, state.masses, system_idx=state.system_idx
1607+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
16021608
)
16031609

16041610
# Calculate degrees of freedom per system

torch_sim/integrators/nvt.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,9 @@ def nvt_nose_hoover_init(
368368
)
369369

370370
# Calculate initial kinetic energy per system
371-
KE = calc_kinetic_energy(momenta, state.masses, system_idx=state.system_idx)
371+
KE = calc_kinetic_energy(
372+
masses=state.masses, momenta=momenta, system_idx=state.system_idx
373+
)
372374

373375
# Calculate degrees of freedom per system
374376
n_atoms_per_system = torch.bincount(state.system_idx)
@@ -434,7 +436,9 @@ def nvt_nose_hoover_update(
434436
state = velocity_verlet(state=state, dt=dt, model=model)
435437

436438
# Update chain kinetic energy per system
437-
KE = calc_kinetic_energy(state.momenta, state.masses, system_idx=state.system_idx)
439+
KE = calc_kinetic_energy(
440+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
441+
)
438442
chain.kinetic_energy = KE
439443

440444
# Second half-step of chain evolution
@@ -478,7 +482,9 @@ def nvt_nose_hoover_invariant(
478482
"""
479483
# Calculate system energy terms per system
480484
e_pot = state.energy
481-
e_kin = calc_kinetic_energy(state.momenta, state.masses, system_idx=state.system_idx)
485+
e_kin = calc_kinetic_energy(
486+
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
487+
)
482488

483489
# Get system degrees of freedom per system
484490
n_atoms_per_system = torch.bincount(state.system_idx)

torch_sim/models/soft_sphere.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -594,11 +594,11 @@ def __init__(
594594
with type 0).
595595
"""
596596
super().__init__()
597-
self.device = device or torch.device("cpu")
598-
self.dtype = dtype
597+
self._device = device or torch.device("cpu")
598+
self._dtype = dtype
599599
self.pbc = pbc
600-
self.compute_forces = compute_forces
601-
self.compute_stress = compute_stress
600+
self._compute_forces = compute_forces
601+
self._compute_stress = compute_stress
602602
self.per_atom_energies = per_atom_energies
603603
self.per_atom_stresses = per_atom_stresses
604604
self.use_neighbor_list = use_neighbor_list

0 commit comments

Comments
 (0)