1
+ from pathlib import Path
2
+
3
+ from ase .calculators .calculator import Calculator
1
4
from jobflow import Flow , run_locally
2
5
from numpy .testing import assert_allclose
3
6
from pymatgen .core import Structure
@@ -17,18 +20,29 @@ def test_phonon_get_supercell_size(clean_dir, si_structure: Structure):
17
20
assert_allclose (responses [job .uuid ][1 ].output , [[6 , - 2 , 0 ], [0 , 6 , 0 ], [- 3 , - 2 , 5 ]])
18
21
19
22
20
- def test_phonon_maker_initialization_with_all_mlff (si_structure ):
23
+ def test_phonon_maker_initialization_with_all_mlff (
24
+ si_structure : Structure , test_dir : Path
25
+ ):
21
26
"""Test PhononMaker can be initialized with all MLFF static and relax makers."""
22
27
28
+ chk_pt_dir = test_dir / "forcefields"
23
29
for mlff in MLFF :
30
+ calc_kwargs = {
31
+ MLFF .Nequip : {"model_path" : f"{ chk_pt_dir } /nequip/nequip_ff_sr_ti_o3.pth" },
32
+ MLFF .SevenNet : { # TODO this currently raises NotImplementedError
33
+ "model" : f"{ chk_pt_dir } /sevennet/2024-07-11-SevenNet-0-serial.pt"
34
+ },
35
+ }.get (mlff , {})
24
36
static_maker = ForceFieldStaticMaker (
25
37
name = f"{ mlff } static" ,
26
38
force_field_name = str (mlff ),
39
+ calculator_kwargs = calc_kwargs ,
27
40
)
28
41
relax_maker = ForceFieldRelaxMaker (
29
42
name = f"{ mlff } relax" ,
30
43
force_field_name = str (mlff ),
31
44
relax_kwargs = {"fmax" : 0.00001 },
45
+ calculator_kwargs = calc_kwargs ,
32
46
)
33
47
34
48
try :
@@ -49,6 +63,13 @@ def test_phonon_maker_initialization_with_all_mlff(si_structure):
49
63
assert flow [4 ].name == "generate_phonon_displacements" , f"{ flow [4 ].name = } "
50
64
assert flow [5 ].name == "run_phonon_displacements" , f"{ flow [5 ].name = } "
51
65
66
+ # expected_calc = ase_calculator(mlff)
67
+ relax_calc = phonon_maker .bulk_relax_maker .calculator
68
+ if mlff == MLFF .Forcefield :
69
+ assert relax_calc is None , f"{ relax_calc = } "
70
+ else :
71
+ assert isinstance (relax_calc , Calculator ), f"{ type (relax_calc )= } "
52
72
except Exception as exc :
73
+ # TODO this requires py3.11
53
74
exc .add_note (f"Failed to initialize PhononMaker with { mlff = } makers" )
54
75
raise
0 commit comments