Skip to content

Commit 66cec7a

Browse files
QuantumChemisthrushikesh-s
authored andcommitted
Add possibility to use your own M3GNet potential (materialsproject#911)
* allow the possibility to use your own M3GNet potential allow the possibility to use your own M3GNet potential, instead of the pretrained model only. * added a unit test * added a unit test * test_dir not needed * change kwargs passing
1 parent 13bdff1 commit 66cec7a

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/atomate2/forcefields/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,8 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N
421421
import matgl
422422
from matgl.ext.ase import PESCalculator
423423

424-
potential = matgl.load_model("M3GNet-MP-2021.2.8-PES")
424+
path = kwargs.get("path", "M3GNet-MP-2021.2.8-PES")
425+
potential = matgl.load_model(path)
425426
calculator = PESCalculator(potential, **kwargs)
426427

427428
elif calculator_name == MLFF.MACE:

tests/forcefields/test_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,25 @@ def test_fix_symmetry(fix_symmetry):
163163
assert symmetry_init["number"] == symmetry_final["number"] == 229
164164
else:
165165
assert symmetry_init["number"] != symmetry_final["number"] == 99
166+
167+
168+
def test_m3gnet_pot():
169+
import matgl
170+
from matgl.ext.ase import PESCalculator
171+
172+
kwargs_calc = {"path": "M3GNet-MP-2021.2.8-DIRECT-PES", "stress_weight": 2.0}
173+
kwargs_default = {"stress_weight": 2.0}
174+
175+
m3gnet_calculator = ase_calculator(calculator_meta="MLFF.M3GNet", **kwargs_calc)
176+
177+
# uses "M3GNet-MP-2021.2.8-PES" per default
178+
m3gnet_default = ase_calculator(calculator_meta="MLFF.M3GNet", **kwargs_default)
179+
180+
potential = matgl.load_model("M3GNet-MP-2021.2.8-DIRECT-PES")
181+
m3gnet_pes_calc = PESCalculator(potential=potential, stress_weight=2.0)
182+
183+
assert str(m3gnet_pes_calc.potential) == str(m3gnet_calculator.potential)
184+
# casting necessary because <class 'matgl.apps.pes.Potential'> can't be compared
185+
assert str(m3gnet_pes_calc.potential) != str(m3gnet_default.potential)
186+
assert m3gnet_pes_calc.stress_weight == m3gnet_calculator.stress_weight
187+
assert m3gnet_pes_calc.stress_weight == m3gnet_default.stress_weight

0 commit comments

Comments
 (0)