Skip to content

Commit b04de1d

Browse files
Add PET-MAD (#575)
* Add PET-MAD dependency * Add PET-MAD architecture * Test PET-MAD architecture * Add PET-MAD to docs * Save PET-MAD model version * Remove Windows support for PET-MAD * Update docs for PET-MAD
1 parent 435229d commit b04de1d

File tree

10 files changed

+47
-3
lines changed

10 files changed

+47
-3
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ Current and planned features include:
9292
- GRACE
9393
- EquiformerV2
9494
- eSEN
95+
- PET-MAD
9596
- [x] Single point calculations
9697
- [x] Geometry optimisation
9798
- [x] Molecular Dynamics

docs/source/user_guide/get_started.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,11 @@ Currently supported MLIP ``extras`` are:
6060
- ``mattersim``: `MatterSim <https://github.com/microsoft/mattersim>`_
6161
- ``grace``: `GRACE <https://github.com/ICAMS/grace-tensorpotential>`_
6262
- ``fairchem``: `eqV2 DeNS/eSEN <https://github.com/FAIR-Chem/fairchem/tree/main/src/fairchem/core>`_
63+
- ``pet-mad``: `PET-MAD <https://github.com/lab-cosmo/pet-mad>`_
6364

6465
.. note::
6566

66-
``orb`` and ``mattersim`` are not currently compatible with Windows natively,
67+
``orb``, ``mattersim``, and ``pet-mad`` are not currently compatible with Windows natively,
6768
but can be installed and run via Windows Subsystem for Linux.
6869

6970

docs/source/user_guide/installation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,5 +158,5 @@ MLIPs with limited OS support
158158
-----------------------------
159159

160160
Several MLIP packages have limited support on Windows. We are currently unable to
161-
support ``orb``, ``mattersim``, ``alignn`` or ``matgl`` as ``extras`` on Windows, so they
161+
support ``orb``, ``mattersim``, ``pet-mad``, ``alignn`` or ``matgl`` as ``extras`` on Windows, so they
162162
must be installed manually.

janus_core/helpers/janus_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class Correlation(TypedDict, total=True):
131131
"grace",
132132
"esen",
133133
"equiformer",
134+
"pet_mad",
134135
]
135136
Devices = Literal["cpu", "cuda", "mps", "xpu"]
136137
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh", "nvt-csvr", "npt-mtk"]

janus_core/helpers/mlip_calculators.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,18 @@ def choose_calculator(
370370
**kwargs,
371371
)
372372

373+
case "pet_mad":
374+
from pet_mad import __version__
375+
from pet_mad._version import LATEST_VERSION
376+
from pet_mad.calculator import PETMADCalculator
377+
378+
calculator = PETMADCalculator(
379+
checkpoint_path=model, device=device, **kwargs
380+
)
381+
382+
if model is None:
383+
model = LATEST_VERSION
384+
373385
case _:
374386
raise ValueError(
375387
f"Unrecognized {arch=}. Suported architectures "

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ nequip = [
5959
orb = [
6060
"orb-models == 0.5.4; sys_platform != 'win32'",
6161
]
62+
pet-mad = [
63+
"pet-mad == 1.3.1; sys_platform != 'win32'"
64+
]
6265
plumed = [
6366
"plumed<3.0.0,>=2.9.0; sys_platform != 'win32'",
6467
]
@@ -75,6 +78,7 @@ all = [
7578
"janus-core[mace]",
7679
"janus-core[nequip]",
7780
"janus-core[orb]",
81+
"janus-core[pet-mad]",
7882
"janus-core[plumed]",
7983
"janus-core[sevennet]",
8084
"janus-core[visualise]",

tests/models/pet-mad-1.1.0.pt

12.6 MB
Binary file not shown.

tests/test_mlip_calculators.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@
7070
except ImportError:
7171
M3GNET_POTENTIAL = None
7272

73+
PET_MAD_CHECKPOINT = (
74+
"https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt"
75+
)
76+
7377

7478
@pytest.mark.parametrize(
7579
"arch, device, kwargs",
@@ -109,6 +113,9 @@
109113
("nequip", "cpu", {"model": NEQUIP_PATH}),
110114
("orb", "cpu", {}),
111115
("orb", "cpu", {"model": ORB_MODEL}),
116+
("pet_mad", "cpu", {}),
117+
("pet_mad", "cpu", {"model": PET_MAD_CHECKPOINT}),
118+
("pet_mad", "cpu", {"checkpoint_path": PET_MAD_CHECKPOINT}),
112119
("sevennet", "cpu", {"model": SEVENNET_PATH}),
113120
("sevennet", "cpu", {"path": SEVENNET_PATH}),
114121
("sevennet", "cpu", {"model_path": SEVENNET_PATH}),
@@ -167,6 +174,7 @@ def test_invalid_arch():
167174
("mattersim", "/invalid/path"),
168175
("nequip", "/invalid/path"),
169176
("orb", "/invalid/path"),
177+
("pet_mad", "/invalid/path"),
170178
("sevenn", "/invalid/path"),
171179
("alignn", "invalid/path"),
172180
("m3gnet", "/invalid/path"),
@@ -228,6 +236,11 @@ def test_invalid_model(arch, model):
228236
{"arch": "nequip", "model": NEQUIP_PATH, "path": NEQUIP_PATH},
229237
{"arch": "orb", "model_path": ORB_MODEL, "model": ORB_MODEL},
230238
{"arch": "orb", "model": ORB_MODEL, "path": ORB_MODEL},
239+
{
240+
"arch": "pet_mad",
241+
"model": PET_MAD_CHECKPOINT,
242+
"checkpoint_path": PET_MAD_CHECKPOINT,
243+
},
231244
{"arch": "sevennet", "model_path": SEVENNET_PATH, "model": SEVENNET_PATH},
232245
{"arch": "sevennet", "model": SEVENNET_PATH, "path": SEVENNET_PATH},
233246
],

tests/test_single_point.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323
ESEN_LABEL = "eSEN-30M-MP"
2424
MACE_PATH = MODEL_PATH / "mace_mp_small.model"
2525
NEQUIP_PATH = MODEL_PATH / "toluene.pth"
26+
PET_MAD_CHECKPOINT = (
27+
"https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt"
28+
)
2629
SEVENNET_PATH = MODEL_PATH / "sevennet_0.pth"
2730

28-
2931
test_data = [
3032
("benzene.xyz", -76.0605725422795, "energy", "energy", {}, None),
3133
(
@@ -131,6 +133,14 @@ def test_potential_energy(struct, expected, properties, prop_key, calc_kwargs, i
131133
),
132134
("orb", "cpu", -27.08186149597168, "NaCl.cif", {}),
133135
("orb", "cpu", -27.089094161987305, "NaCl.cif", {"model": "orb-v2"}),
136+
("pet_mad", "cpu", -27.47624969482422, "NaCl.cif", {}),
137+
(
138+
"pet_mad",
139+
"cpu",
140+
-27.47624969482422,
141+
"NaCl.cif",
142+
{"model": PET_MAD_CHECKPOINT},
143+
),
134144
(
135145
"sevennet",
136146
"cpu",

tests/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ def skip_extras(arch: str):
149149
pytest.importorskip("nequip")
150150
case "orb":
151151
pytest.importorskip("orb_models")
152+
case "pet_mad":
153+
pytest.importorskip("pet_mad")
152154
case "sevennet":
153155
pytest.importorskip("sevenn")
154156
case "alignn":

0 commit comments

Comments
 (0)