Skip to content

Commit 9600cef

Browse files
orionarcherjanosh
andauthored
Excise openff dependency from openmm testing (#993)
* Excise openff dependency from openmm testing * Remove commmented out code * Update src/atomate2/openmm/jobs/base.py Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com> * Respond to Yanosh PR and fix type of OpenMM Flow * Fix typo, lint * Add dataclass tag where needed --------- Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
1 parent 96b2b82 commit 9600cef

File tree

11 files changed

+195
-90
lines changed

11 files changed

+195
-90
lines changed

src/atomate2/openmm/flows/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import TYPE_CHECKING
99

1010
from emmet.core.openmm import Calculation, OpenMMInterchange, OpenMMTaskDocument
11-
from jobflow import Flow, Job, Response
11+
from jobflow import Flow, Job, Maker, Response
1212
from monty.json import MontyDecoder, MontyEncoder
1313

1414
from atomate2.openmm.jobs.base import openmm_job
@@ -68,7 +68,7 @@ def collect_outputs(
6868

6969

7070
@dataclass
71-
class OpenMMFlowMaker:
71+
class OpenMMFlowMaker(Maker):
7272
"""Run a production simulation.
7373
7474
This flexible flow links together any flows of OpenMM jobs in

src/atomate2/openmm/jobs/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,11 @@ def _update_interchange(
524524
interchange.box = state.getPeriodicBoxVectors(asNumpy=True)
525525
elif isinstance(interchange, OpenMMInterchange):
526526
interchange.state = XmlSerializer.serialize(state)
527+
else:
528+
raise TypeError(
529+
f"Interchange must be an Interchange or "
530+
f"OpenMMInterchange object, got {type(interchange).__name__}"
531+
)
527532

528533
def _create_structure(
529534
self, sim: Simulation, prev_task: OpenMMTaskDocument | None = None
@@ -607,8 +612,10 @@ def _create_task_doc(
607612

608613
prev_task = prev_task or OpenMMTaskDocument()
609614

610-
interchange_json = interchange.json()
611-
# interchange_bytes = interchange_json.encode("utf-8")
615+
if isinstance(interchange, Interchange):
616+
interchange_json = interchange.json()
617+
else:
618+
interchange_json = interchange.model_dump_json()
612619

613620
return OpenMMTaskDocument(
614621
tags=tags,

src/atomate2/openmm/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,23 @@
22

33
from __future__ import annotations
44

5+
import io
56
import re
67
import tempfile
78
import time
89
import warnings
910
from pathlib import Path
1011
from typing import TYPE_CHECKING
1112

13+
import numpy as np
14+
import openmm.unit as omm_unit
15+
from emmet.core.openmm import OpenMMInterchange
16+
from openmm import LangevinMiddleIntegrator, XmlSerializer
17+
from openmm.app import PDBFile
18+
1219
if TYPE_CHECKING:
1320
from emmet.core.openmm import OpenMMTaskDocument
21+
from openff.interchange import Interchange
1422

1523

1624
def download_opls_xml(
@@ -132,3 +140,34 @@ def task_reports(task: OpenMMTaskDocument, traj_or_state: str = "traj") -> bool:
132140
else:
133141
raise ValueError("traj_or_state must be 'traj' or 'state'")
134142
return calc_input.n_steps >= report_freq
143+
144+
145+
def openff_to_openmm_interchange(
146+
openff_interchange: Interchange,
147+
) -> OpenMMInterchange:
148+
"""Convert an OpenFF Interchange object to an OpenMM Interchange object."""
149+
integrator = LangevinMiddleIntegrator(
150+
300 * omm_unit.kelvin,
151+
10.0 / omm_unit.picoseconds,
152+
1.0 * omm_unit.femtoseconds,
153+
)
154+
sim = openff_interchange.to_openmm_simulation(integrator)
155+
state = sim.context.getState(
156+
getPositions=True,
157+
getVelocities=True,
158+
enforcePeriodicBox=True,
159+
)
160+
with io.StringIO() as buffer:
161+
PDBFile.writeFile(
162+
sim.topology,
163+
np.zeros(shape=(sim.topology.getNumAtoms(), 3)),
164+
file=buffer,
165+
)
166+
buffer.seek(0)
167+
pdb = buffer.read()
168+
169+
return OpenMMInterchange(
170+
system=XmlSerializer.serialize(sim.system),
171+
state=XmlSerializer.serialize(state),
172+
topology=pdb,
173+
)

tests/openff_md/test_core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import pytest
12
from emmet.core.openff import ClassicalMDTaskDocument, MoleculeSpec
2-
from openff.interchange import Interchange
33

44
from atomate2.openff.core import generate_interchange
55

6+
pytest.importorskip("openff.toolkit")
7+
from openff.interchange import Interchange # noqa: E402
8+
69

710
def test_generate_interchange(mol_specs_small, run_job):
811
mass_density = 1

tests/openff_md/test_utils.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,8 @@
11
import numpy as np
2-
import openff.toolkit as tk
32
import pymatgen
43
import pytest
54
from emmet.core.openff import MoleculeSpec
6-
from openff.interchange import Interchange
7-
from openff.toolkit.topology import Topology
8-
from openff.toolkit.topology.molecule import Molecule
9-
from openff.units import Quantity
105
from pymatgen.analysis.graphs import MoleculeGraph
11-
from pymatgen.io.openff import (
12-
add_conformer,
13-
assign_partial_charges,
14-
create_openff_mol,
15-
get_atom_map,
16-
infer_openff_mol,
17-
mol_graph_to_openff_mol,
18-
)
196

207
from atomate2.openff.utils import (
218
counts_from_box_size,
@@ -24,6 +11,21 @@
2411
merge_specs_by_name_and_smiles,
2512
)
2613

14+
pytest.importorskip("openff.toolkit")
15+
import openff.toolkit as tk # noqa: E402
16+
from openff.interchange import Interchange # noqa: E402
17+
from openff.toolkit.topology import Topology # noqa: E402
18+
from openff.toolkit.topology.molecule import Molecule # noqa: E402
19+
from openff.units import Quantity # noqa: E402
20+
from pymatgen.io.openff import ( # noqa: E402
21+
add_conformer,
22+
assign_partial_charges,
23+
create_openff_mol,
24+
get_atom_map,
25+
infer_openff_mol,
26+
mol_graph_to_openff_mol,
27+
)
28+
2729

2830
def test_molgraph_to_openff_pf6(mol_files):
2931
"""transform a water MoleculeGraph to a OpenFF water molecule"""

tests/openmm_md/conftest.py

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,6 @@
1-
import openff.toolkit as tk
21
import pytest
2+
from emmet.core.openmm import OpenMMInterchange
33
from jobflow import run_locally
4-
from openff.interchange import Interchange
5-
from openff.interchange.components._packmol import pack_box
6-
from openff.toolkit import ForceField
7-
from openff.units import unit
8-
9-
from atomate2.openff.utils import create_mol_spec, merge_specs_by_name_and_smiles
104

115

126
@pytest.fixture
@@ -18,37 +12,62 @@ def run_job(job):
1812
return run_job
1913

2014

21-
@pytest.fixture
15+
@pytest.fixture(scope="package")
2216
def openmm_data(test_dir):
2317
return test_dir / "openmm"
2418

2519

2620
@pytest.fixture(scope="package")
27-
def interchange():
28-
o = create_mol_spec("O", 300, charge_method="mmff94")
29-
cco = create_mol_spec("CCO", 10, charge_method="mmff94")
30-
cco2 = create_mol_spec("CCO", 20, name="cco2", charge_method="mmff94")
31-
mol_specs = [o, cco, cco2]
32-
mol_specs.sort(
33-
key=lambda x: tk.Molecule.from_json(x.openff_mol).to_smiles() + x.name
34-
)
35-
36-
topology = pack_box(
37-
molecules=[tk.Molecule.from_json(spec.openff_mol) for spec in mol_specs],
38-
number_of_copies=[spec.count for spec in mol_specs],
39-
mass_density=0.8 * unit.grams / unit.milliliter,
40-
)
41-
42-
mol_specs = merge_specs_by_name_and_smiles(mol_specs)
43-
44-
return Interchange.from_smirnoff(
45-
force_field=ForceField("openff_unconstrained-2.1.1.offxml"),
46-
topology=topology,
47-
charge_from_molecules=[
48-
tk.Molecule.from_json(spec.openff_mol) for spec in mol_specs
49-
],
50-
allow_nonintegral_charges=True,
51-
)
21+
def interchange(openmm_data):
22+
# we use openff to generate the interchange object that we test on
23+
# but we don't want to create a logical dependency on openff, in
24+
# case the user has another way of generating the interchange object
25+
regenerate_test_data = False
26+
if regenerate_test_data:
27+
import openff.toolkit as tk
28+
from openff.interchange import Interchange
29+
from openff.interchange.components._packmol import pack_box
30+
from openff.toolkit import ForceField
31+
from openff.units import unit
32+
33+
from atomate2.openff.utils import (
34+
create_mol_spec,
35+
merge_specs_by_name_and_smiles,
36+
)
37+
from atomate2.openmm.utils import openff_to_openmm_interchange
38+
39+
o = create_mol_spec("O", 300, charge_method="mmff94")
40+
cco = create_mol_spec("CCO", 10, charge_method="mmff94")
41+
cco2 = create_mol_spec("CCO", 20, name="cco2", charge_method="mmff94")
42+
mol_specs = [o, cco, cco2]
43+
mol_specs.sort(
44+
key=lambda x: tk.Molecule.from_json(x.openff_mol).to_smiles() + x.name
45+
)
46+
47+
topology = pack_box(
48+
molecules=[tk.Molecule.from_json(spec.openff_mol) for spec in mol_specs],
49+
number_of_copies=[spec.count for spec in mol_specs],
50+
mass_density=0.8 * unit.grams / unit.milliliter,
51+
)
52+
53+
mol_specs = merge_specs_by_name_and_smiles(mol_specs)
54+
55+
openff_interchange = Interchange.from_smirnoff(
56+
force_field=ForceField("openff_unconstrained-2.1.1.offxml"),
57+
topology=topology,
58+
charge_from_molecules=[
59+
tk.Molecule.from_json(spec.openff_mol) for spec in mol_specs
60+
],
61+
allow_nonintegral_charges=True,
62+
)
63+
64+
openmm_interchange = openff_to_openmm_interchange(openff_interchange)
65+
66+
with open(openmm_data / "interchange.json", "w") as file:
67+
file.write(openmm_interchange.model_dump_json())
68+
69+
with open(openmm_data / "interchange.json") as file:
70+
return OpenMMInterchange.model_validate_json(file.read())
5271

5372

5473
@pytest.fixture

tests/openmm_md/flows/test_core.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from __future__ import annotations
22

3+
import io
34
import json
45
from pathlib import Path
56

67
import numpy as np
78
import pytest
8-
from emmet.core.openmm import OpenMMTaskDocument
9+
from emmet.core.openmm import OpenMMInterchange, OpenMMTaskDocument
910
from jobflow import Flow
1011
from MDAnalysis import Universe
1112
from monty.json import MontyDecoder
12-
from openff.interchange import Interchange
13+
from openmm.app import PDBFile
1314

1415
from atomate2.openmm.flows.core import OpenMMFlowMaker
1516
from atomate2.openmm.jobs import EnergyMinimizationMaker, NPTMaker, NVTMaker
@@ -156,22 +157,13 @@ def test_flow_maker(interchange, run_job):
156157
calc_output = task_doc.calcs_reversed[0].output
157158
assert len(calc_output.steps_reported) == 5
158159

159-
all_steps = [calc.output.steps_reported for calc in task_doc.calcs_reversed]
160-
assert all_steps == [
161-
[1, 2, 3, 4, 5],
162-
[1],
163-
[1, 2],
164-
[1, 2],
165-
[1, 2, 3, 4, 5],
166-
None,
167-
]
168160
# Test that the state interval is respected
169-
assert calc_output.steps_reported == list(range(1, 6))
161+
assert calc_output.steps_reported == list(range(11, 16))
170162
assert calc_output.traj_file == "trajectory5.dcd"
171163
assert calc_output.state_file == "state5.csv"
172164

173-
interchange = Interchange.parse_raw(task_doc.interchange)
174-
topology = interchange.to_openmm_topology()
165+
interchange = OpenMMInterchange.model_validate_json(task_doc.interchange)
166+
topology = PDBFile(io.StringIO(interchange.topology)).getTopology()
175167
u = Universe(topology, str(Path(task_doc.dir_name) / "trajectory5.dcd"))
176168

177169
assert len(u.trajectory) == 5
@@ -184,8 +176,9 @@ def test_traj_blob_embed(interchange, run_job, tmp_path):
184176
nvt_job = nvt.make(interchange)
185177
task_doc = run_job(nvt_job)
186178

187-
interchange = Interchange.parse_raw(task_doc.interchange)
188-
topology = interchange.to_openmm_topology()
179+
interchange = OpenMMInterchange.model_validate_json(task_doc.interchange)
180+
topology = PDBFile(io.StringIO(interchange.topology)).getTopology()
181+
189182
u = Universe(topology, str(Path(task_doc.dir_name) / "trajectory.dcd"))
190183

191184
assert len(u.trajectory) == 2

tests/openmm_md/jobs/test_base.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from emmet.core.openmm import Calculation, CalculationInput, OpenMMTaskDocument
77
from jobflow import Flow, Job
88
from mdareporter import MDAReporter
9+
from openmm import XmlSerializer
910
from openmm.app import Simulation, StateDataReporter
1011
from openmm.openmm import LangevinMiddleIntegrator
1112
from openmm.unit import kelvin, picoseconds
@@ -70,22 +71,30 @@ def test_create_simulation(interchange):
7071
def test_update_interchange(interchange):
7172
interchange = copy.deepcopy(interchange)
7273
maker = BaseOpenMMMaker(wrap_traj=True)
74+
7375
sim = maker._create_simulation(interchange) # noqa: SLF001
74-
start_positions = interchange.positions
75-
start_velocities = interchange.velocities
76-
start_box = interchange.box
76+
77+
state = XmlSerializer.deserialize(interchange.state)
78+
start_positions = state.getPositions(asNumpy=True)
79+
start_velocities = state.getVelocities(asNumpy=True)
80+
start_box = state.getPeriodicBoxVectors()
7781

7882
# Run the simulation for one step
79-
sim.step(1)
83+
sim.step(2)
8084

8185
maker._update_interchange(interchange, sim, None) # noqa: SLF001
8286

83-
assert interchange.positions.shape == start_positions.shape
84-
assert interchange.velocities.shape == (1170, 3)
87+
new_state = XmlSerializer.deserialize(interchange.state)
88+
new_positions = new_state.getPositions(asNumpy=True)
89+
new_velocities = new_state.getVelocities(asNumpy=True)
90+
new_box = new_state.getPeriodicBoxVectors()
91+
92+
assert new_positions.shape == start_positions.shape
93+
assert new_velocities.shape == start_velocities.shape
8594

86-
assert np.any(interchange.positions != start_positions)
87-
assert np.any(interchange.velocities != start_velocities)
88-
assert np.all(interchange.box == start_box)
95+
assert not np.all(new_positions == start_positions)
96+
assert not np.all(new_velocities == start_velocities)
97+
assert np.all(new_box == start_box)
8998

9099

91100
def test_create_task_doc(interchange, tmp_path):

0 commit comments

Comments
 (0)