Skip to content

Commit 6496f65

Browse files
committed
test relax_calcs are ase calculator instances
upload smallest sevennet checkpoint i could find, fails to load for some reason
1 parent b88b34d commit 6496f65

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

tests/common/jobs/test_phonon.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from pathlib import Path
2+
3+
from ase.calculators.calculator import Calculator
14
from jobflow import Flow, run_locally
25
from numpy.testing import assert_allclose
36
from pymatgen.core import Structure
@@ -17,18 +20,29 @@ def test_phonon_get_supercell_size(clean_dir, si_structure: Structure):
1720
assert_allclose(responses[job.uuid][1].output, [[6, -2, 0], [0, 6, 0], [-3, -2, 5]])
1821

1922

20-
def test_phonon_maker_initialization_with_all_mlff(si_structure):
23+
def test_phonon_maker_initialization_with_all_mlff(
24+
si_structure: Structure, test_dir: Path
25+
):
2126
"""Test PhononMaker can be initialized with all MLFF static and relax makers."""
2227

28+
chk_pt_dir = test_dir / "forcefields"
2329
for mlff in MLFF:
30+
calc_kwargs = {
31+
MLFF.Nequip: {"model_path": f"{chk_pt_dir}/nequip/nequip_ff_sr_ti_o3.pth"},
32+
MLFF.SevenNet: { # TODO this currently raises NotImplementedError
33+
"model": f"{chk_pt_dir}/sevennet/2024-07-11-SevenNet-0-serial.pt"
34+
},
35+
}.get(mlff, {})
2436
static_maker = ForceFieldStaticMaker(
2537
name=f"{mlff} static",
2638
force_field_name=str(mlff),
39+
calculator_kwargs=calc_kwargs,
2740
)
2841
relax_maker = ForceFieldRelaxMaker(
2942
name=f"{mlff} relax",
3043
force_field_name=str(mlff),
3144
relax_kwargs={"fmax": 0.00001},
45+
calculator_kwargs=calc_kwargs,
3246
)
3347

3448
try:
@@ -49,6 +63,13 @@ def test_phonon_maker_initialization_with_all_mlff(si_structure):
4963
assert flow[4].name == "generate_phonon_displacements", f"{flow[4].name=}"
5064
assert flow[5].name == "run_phonon_displacements", f"{flow[5].name=}"
5165

66+
# expected_calc = ase_calculator(mlff)
67+
relax_calc = phonon_maker.bulk_relax_maker.calculator
68+
if mlff == MLFF.Forcefield:
69+
assert relax_calc is None, f"{relax_calc=}"
70+
else:
71+
assert isinstance(relax_calc, Calculator), f"{type(relax_calc)=}"
5272
except Exception as exc:
73+
# TODO this requires py3.11
5374
exc.add_note(f"Failed to initialize PhononMaker with {mlff=} makers")
5475
raise
Binary file not shown.

0 commit comments

Comments
 (0)