Skip to content

Commit 5f5ed84

Browse files
Add UMA (#574)
* Add UMA dependency * Add UMA calculator * Add tests for UMA * Add UMA tests to workflows * Fix CI installations * Update CI install and delete cached MLIPs * Change default UMA model * Swap tested UMA models * Reduce memory used by UMA tests * Add invalid path UMA test and reorder * Skip tests if unauthorised for model * ADD UMA to README list * Update docs for UMA/fairchem --------- Co-authored-by: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com>
1 parent b04de1d commit 5f5ed84

File tree

11 files changed

+191
-59
lines changed

11 files changed

+191
-59
lines changed

.github/workflows/ci.yml

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
env:
3939
# show timings of tests
4040
PYTEST_ADDOPTS: "--durations=0"
41-
run: uv run pytest --cov janus_core --cov-append .
41+
run: uv run --no-sync pytest --cov janus_core --cov-append .
4242

4343
- name: Install updated e3nn dependencies
4444
run: |
@@ -51,7 +51,7 @@ jobs:
5151
# show timings of tests
5252
PYTEST_ADDOPTS: "--durations=0"
5353
HF_TOKEN: ${{ secrets.HF_TOKEN }}
54-
run: uv run pytest tests/test_{mlip_calculators,single_point}.py
54+
run: uv run --no-sync pytest tests/test_{mlip_calculators,single_point}.py
5555

5656
- name: Install dgl dependencies
5757
run: |
@@ -62,7 +62,25 @@ jobs:
6262
env:
6363
# show timings of tests
6464
PYTEST_ADDOPTS: "--durations=0"
65-
run: uv run pytest tests/test_{mlip_calculators,single_point,eos}.py
65+
run: uv run --no-sync pytest tests/test_{mlip_calculators,single_point,eos}.py
66+
67+
- name: Create space in cache
68+
run: |
69+
rm -rf ~/.cache/*
70+
uv cache clean
71+
72+
- name: Install UMA
73+
run: |
74+
uv sync --extra uma
75+
uv pip install --reinstall pynvml
76+
uv pip install fairchem-core[torch-extras] --no-build-isolation
77+
78+
- name: Run test suite for UMA
79+
env:
80+
# show timings of tests
81+
PYTEST_ADDOPTS: "--durations=0"
82+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
83+
run: uv run --no-sync pytest tests/test_{mlip_calculators,single_point}.py
6684

6785
- name: Report coverage to Coveralls
6886
uses: coverallsapp/github-action@v2

.github/workflows/mac.yml

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
env:
3939
# show timings of tests
4040
PYTEST_ADDOPTS: "--durations=0"
41-
run: uv run pytest
41+
run: uv run --no-sync pytest
4242

4343
- name: Install updated e3nn dependencies
4444
run: |
@@ -51,7 +51,7 @@ jobs:
5151
# show timings of tests
5252
PYTEST_ADDOPTS: "--durations=0"
5353
HF_TOKEN: ${{ secrets.HF_TOKEN }}
54-
run: uv run pytest tests/test_{mlip_calculators,single_point}.py
54+
run: uv run --no-sync pytest tests/test_{mlip_calculators,single_point}.py
5555

5656
- name: Install dgl dependencies
5757
run: |
@@ -62,4 +62,17 @@ jobs:
6262
env:
6363
# show timings of tests
6464
PYTEST_ADDOPTS: "--durations=0"
65-
run: uv run pytest tests/test_{mlip_calculators,single_point,eos}.py
65+
run: uv run --no-sync pytest tests/test_{mlip_calculators,single_point,eos}.py
66+
67+
- name: Install uma
68+
run: |
69+
uv sync --extra uma
70+
uv pip install --reinstall pynvml
71+
uv pip install fairchem-core[torch-extras] --no-build-isolation
72+
73+
- name: Run test suite for UMA
74+
env:
75+
# show timings of tests
76+
PYTEST_ADDOPTS: "--durations=0"
77+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
78+
run: uv run --no-sync pytest tests/test_{mlip_calculators,single_point}.py

.github/workflows/windows.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
env:
2828
# show timings of tests
2929
PYTEST_ADDOPTS: "--durations=0"
30-
run: uv run pytest
30+
run: uv run --no-sync pytest
3131

3232
- name: Install updated e3nn dependencies
3333
run: |
@@ -40,4 +40,4 @@ jobs:
4040
# show timings of tests
4141
PYTEST_ADDOPTS: "--durations=0"
4242
HF_TOKEN: ${{ secrets.HF_TOKEN }}
43-
run: uv run pytest tests/test_mlip_calculators.py tests/test_single_point.py
43+
run: uv run --no-sync pytest tests/test_mlip_calculators.py tests/test_single_point.py

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ Current and planned features include:
9292
- GRACE
9393
- EquiformerV2
9494
- eSEN
95+
- UMA
9596
- PET-MAD
9697
- [x] Single point calculations
9798
- [x] Geometry optimisation

docs/source/user_guide/get_started.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ Currently supported MLIP ``extras`` are:
5959
- ``orb``: `Orb <https://github.com/orbital-materials/orb-models>`_
6060
- ``mattersim``: `MatterSim <https://github.com/microsoft/mattersim>`_
6161
- ``grace``: `GRACE <https://github.com/ICAMS/grace-tensorpotential>`_
62-
- ``fairchem``: `eqV2 DeNS/eSEN <https://github.com/FAIR-Chem/fairchem/tree/main/src/fairchem/core>`_
62+
- ``fairchem``: `eqV2 DeNS/eSEN <https://github.com/facebookresearch/fairchem/tree/fairchem_core-1.10.0/src/fairchem/core>`_
63+
- ``uma``: `UMA <https://github.com/FAIR-Chem/fairchem/tree/main/src/fairchem/core>`_
6364
- ``pet-mad``: `PET-MAD <https://github.com/lab-cosmo/pet-mad>`_
6465

6566
.. note::

janus_core/helpers/janus_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ class Correlation(TypedDict, total=True):
132132
"esen",
133133
"equiformer",
134134
"pet_mad",
135+
"uma",
135136
]
136137
Devices = Literal["cpu", "cuda", "mps", "xpu"]
137138
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh", "nvt-csvr", "npt-mtk"]

janus_core/helpers/mlip_calculators.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def _set_model(
5353
"path",
5454
"model_name",
5555
"checkpoint_path",
56+
"predict_unit",
5657
}
5758
present = kwargs.keys() & model_kwargs
5859

@@ -382,6 +383,31 @@ def choose_calculator(
382383
if model is None:
383384
model = LATEST_VERSION
384385

386+
case "uma":
387+
from fairchem.core import FAIRChemCalculator, __version__, pretrained_mlip
388+
from fairchem.core.units.mlip_unit import MLIPPredictUnit
389+
390+
match model:
391+
case MLIPPredictUnit():
392+
predict_unit = model
393+
model = "loaded_Module"
394+
case Path() | str():
395+
predict_unit = pretrained_mlip.get_predict_unit(
396+
model_name=model, device=device
397+
)
398+
case None:
399+
model = "uma-m-1p1"
400+
predict_unit = pretrained_mlip.get_predict_unit(
401+
model_name=model, device=device
402+
)
403+
404+
kwargs.setdefault("task_name", "omat")
405+
406+
calculator = FAIRChemCalculator(
407+
predict_unit=predict_unit,
408+
**kwargs,
409+
)
410+
385411
case _:
386412
raise ValueError(
387413
f"Unrecognized {arch=}. Suported architectures "

pyproject.toml

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ readme = "README.md"
2828
dependencies = [
2929
"ase<4.0,>=3.25",
3030
"codecarbon<3.0.0,>=2.8.4",
31-
"numpy<2.0.0,>=1.26.4",
31+
"numpy<3.0.0,>=1.26.4",
3232
"phonopy<3.0.0,>=2.23.1",
3333
"pymatgen>=2025.1.24",
3434
"pyyaml<7.0.0,>=6.0.1",
@@ -93,6 +93,10 @@ mattersim = [
9393
"mattersim == 1.1.2; sys_platform != 'win32'",
9494
]
9595

96+
uma = [
97+
"fairchem-core == 2.3.0",
98+
]
99+
96100
# MLIPs with dgl dependency
97101
alignn = [
98102
"alignn == 2024.5.27; sys_platform != 'win32'",
@@ -281,4 +285,36 @@ conflicts = [
281285
{ extra = "fairchem" },
282286
{ extra = "all" },
283287
],
288+
[
289+
{ extra = "uma" },
290+
{ extra = "alignn" },
291+
],
292+
[
293+
{ extra = "uma" },
294+
{ extra = "fairchem" },
295+
],
296+
[
297+
{ extra = "uma" },
298+
{ extra = "grace" },
299+
],
300+
[
301+
{ extra = "uma" },
302+
{ extra = "mace" },
303+
],
304+
[
305+
{ extra = "uma" },
306+
{ extra = "mattersim" },
307+
],
308+
[
309+
{ extra = "uma" },
310+
{ extra = "m3gnet" },
311+
],
312+
[
313+
{ extra = "uma" },
314+
{ extra = "sevennet" },
315+
],
316+
[
317+
{ extra = "uma" },
318+
{ extra = "all" },
319+
],
284320
]

tests/test_mlip_calculators.py

Lines changed: 60 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
EQUIFORMER_LABEL = "EquiformerV2-83M-S2EF-OC20-2M"
3232
ESEN_LABEL = "eSEN-30M-MP"
3333

34+
3435
try:
3536
from fairchem.core.models.model_registry import model_name_to_local_file
3637

@@ -59,6 +60,23 @@
5960

6061
SEVENNET_PATH = MODEL_PATH / "sevennet_0.pth"
6162

63+
UMA_LABEL = "uma-s-1"
64+
65+
try:
66+
from fairchem.core import pretrained_mlip
67+
from huggingface_hub.errors import GatedRepoError
68+
69+
try:
70+
UMA_PREDICT_UNIT = pretrained_mlip.get_predict_unit(
71+
model_name=UMA_LABEL, device="cpu"
72+
)
73+
except GatedRepoError:
74+
UMA_PREDICT_UNIT = None
75+
76+
except ImportError:
77+
UMA_PREDICT_UNIT = None
78+
79+
6280
ALIGNN_PATH = MODEL_PATH / "v5.27.2024"
6381
M3GNET_DIR_PATH = MODEL_PATH / "M3GNet-MP-2021.2.8-DIRECT-PES"
6482
M3GNET_MODEL_PATH = M3GNET_DIR_PATH / "model.pt"
@@ -78,17 +96,11 @@
7896
@pytest.mark.parametrize(
7997
"arch, device, kwargs",
8098
[
81-
("mace", "cpu", {"model": MACE_MP_PATH}),
82-
("mace", "cpu", {"model_paths": MACE_MP_PATH}),
83-
("mace", "cpu", {"model_path": MACE_MP_PATH}),
84-
("mace_off", "cpu", {}),
85-
("mace_off", "cpu", {"model": "small"}),
86-
("mace_off", "cpu", {"model_path": MACE_OFF_PATH}),
87-
("mace_off", "cpu", {"model": MACE_OFF_PATH}),
88-
("mace_mp", "cpu", {}),
89-
("mace_mp", "cpu", {"model": "small"}),
90-
("mace_mp", "cpu", {"model_path": MACE_MP_PATH}),
91-
("mace_mp", "cpu", {"model": MACE_MP_PATH}),
99+
("alignn", "cpu", {}),
100+
("alignn", "cpu", {"model_path": ALIGNN_PATH}),
101+
("alignn", "cpu", {"model_path": ALIGNN_PATH / "best_model.pt"}),
102+
("alignn", "cpu", {"model": "alignnff_wt10"}),
103+
("alignn", "cpu", {"path": ALIGNN_PATH}),
92104
("chgnet", "cpu", {}),
93105
("chgnet", "cpu", {"model": "0.2.0"}),
94106
("chgnet", "cpu", {"model_path": CHGNET_PATH}),
@@ -107,8 +119,26 @@
107119
("esen", "cpu", {"model_name": ESEN_LABEL}),
108120
("esen", "cpu", {"model_name": ESEN_PATH}),
109121
("esen", "cpu", {"checkpoint_path": ESEN_PATH}),
122+
("grace", "cpu", {}),
123+
("grace", "cpu", {"model_path": "GRACE-1L-MP-r6"}),
124+
("mace", "cpu", {"model": MACE_MP_PATH}),
125+
("mace", "cpu", {"model_paths": MACE_MP_PATH}),
126+
("mace", "cpu", {"model_path": MACE_MP_PATH}),
127+
("mace_mp", "cpu", {}),
128+
("mace_mp", "cpu", {"model": "small"}),
129+
("mace_mp", "cpu", {"model_path": MACE_MP_PATH}),
130+
("mace_mp", "cpu", {"model": MACE_MP_PATH}),
131+
("mace_off", "cpu", {}),
132+
("mace_off", "cpu", {"model": "small"}),
133+
("mace_off", "cpu", {"model_path": MACE_OFF_PATH}),
134+
("mace_off", "cpu", {"model": MACE_OFF_PATH}),
110135
("mattersim", "cpu", {}),
111136
("mattersim", "cpu", {"model_path": "mattersim-v1.0.0-1m"}),
137+
("m3gnet", "cpu", {}),
138+
("m3gnet", "cpu", {"model_path": M3GNET_DIR_PATH}),
139+
("m3gnet", "cpu", {"model_path": M3GNET_MODEL_PATH}),
140+
("m3gnet", "cpu", {"potential": M3GNET_DIR_PATH}),
141+
("m3gnet", "cpu", {"potential": M3GNET_POTENTIAL}),
112142
("nequip", "cpu", {"model_path": NEQUIP_PATH}),
113143
("nequip", "cpu", {"model": NEQUIP_PATH}),
114144
("orb", "cpu", {}),
@@ -121,18 +151,11 @@
121151
("sevennet", "cpu", {"model_path": SEVENNET_PATH}),
122152
("sevennet", "cpu", {}),
123153
("sevennet", "cpu", {"model": "sevennet-0"}),
124-
("alignn", "cpu", {}),
125-
("alignn", "cpu", {"model_path": ALIGNN_PATH}),
126-
("alignn", "cpu", {"model_path": ALIGNN_PATH / "best_model.pt"}),
127-
("alignn", "cpu", {"model": "alignnff_wt10"}),
128-
("alignn", "cpu", {"path": ALIGNN_PATH}),
129-
("m3gnet", "cpu", {}),
130-
("m3gnet", "cpu", {"model_path": M3GNET_DIR_PATH}),
131-
("m3gnet", "cpu", {"model_path": M3GNET_MODEL_PATH}),
132-
("m3gnet", "cpu", {"potential": M3GNET_DIR_PATH}),
133-
("m3gnet", "cpu", {"potential": M3GNET_POTENTIAL}),
134-
("grace", "cpu", {}),
135-
("grace", "cpu", {"model_path": "GRACE-1L-MP-r6"}),
154+
("uma", "cpu", {"model": UMA_LABEL}),
155+
("uma", "cpu", {"model_path": UMA_LABEL}),
156+
("uma", "cpu", {"model_name": UMA_LABEL}),
157+
("uma", "cpu", {"model": UMA_PREDICT_UNIT}),
158+
("uma", "cpu", {"predict_unit": UMA_PREDICT_UNIT}),
136159
],
137160
)
138161
def test_mlips(arch, device, kwargs):
@@ -164,20 +187,22 @@ def test_invalid_arch():
164187
@pytest.mark.parametrize(
165188
"arch, model",
166189
[
167-
("mace", "/invalid/path"),
168-
("mace_off", "/invalid/path"),
169-
("mace_mp", "/invalid/path"),
190+
("alignn", "invalid/path"),
170191
("chgnet", "/invalid/path"),
171192
("dpa3", "/invalid/path"),
172-
("fairchem", "/invalid/path"),
193+
("equiformer", "/invalid/path"),
194+
("esen", "/invalid/path"),
173195
("grace", "/invalid/path"),
196+
("mace", "/invalid/path"),
197+
("mace_mp", "/invalid/path"),
198+
("mace_off", "/invalid/path"),
174199
("mattersim", "/invalid/path"),
200+
("m3gnet", "/invalid/path"),
175201
("nequip", "/invalid/path"),
176202
("orb", "/invalid/path"),
177203
("pet_mad", "/invalid/path"),
178204
("sevenn", "/invalid/path"),
179-
("alignn", "invalid/path"),
180-
("m3gnet", "/invalid/path"),
205+
("uma", "/invalid/path"),
181206
],
182207
)
183208
def test_invalid_model(arch, model):
@@ -190,10 +215,6 @@ def test_invalid_model(arch, model):
190215
@pytest.mark.parametrize(
191216
"kwargs",
192217
[
193-
{"arch": "mace", "model": MACE_MP_PATH, "model_paths": MACE_MP_PATH},
194-
{"arch": "mace", "model": MACE_MP_PATH, "model_paths": MACE_MP_PATH},
195-
{"arch": "mace", "model_path": MACE_MP_PATH, "model": MACE_MP_PATH},
196-
{"arch": "mace", "model": MACE_MP_PATH, "potential": MACE_MP_PATH},
197218
{
198219
"arch": "alignn",
199220
"model_path": ALIGNN_PATH / "best_model.pt",
@@ -207,11 +228,6 @@ def test_invalid_model(arch, model):
207228
{"arch": "chgnet", "model": CHGNET_PATH, "path": CHGNET_PATH},
208229
{"arch": "dpa3", "model_path": DPA3_PATH, "model": DPA3_PATH},
209230
{"arch": "dpa3", "model": DPA3_PATH, "path": DPA3_PATH},
210-
{
211-
"arch": "esen",
212-
"model_path": ESEN_LABEL,
213-
"model": ESEN_LABEL,
214-
},
215231
{
216232
"arch": "equiformer",
217233
"model_path": EQUIFORMER_LABEL,
@@ -222,6 +238,11 @@ def test_invalid_model(arch, model):
222238
"model_path": "GRACE-1L-MP-r6",
223239
"model": "GRACE-1L-MP-r6",
224240
},
241+
{"arch": "esen", "model_path": ESEN_LABEL, "model": ESEN_LABEL},
242+
{"arch": "mace", "model": MACE_MP_PATH, "model_paths": MACE_MP_PATH},
243+
{"arch": "mace", "model": MACE_MP_PATH, "model_paths": MACE_MP_PATH},
244+
{"arch": "mace", "model_path": MACE_MP_PATH, "model": MACE_MP_PATH},
245+
{"arch": "mace", "model": MACE_MP_PATH, "potential": MACE_MP_PATH},
225246
{
226247
"arch": "mattersim",
227248
"model": "mattersim-v1.0.0-1m",
@@ -243,6 +264,7 @@ def test_invalid_model(arch, model):
243264
},
244265
{"arch": "sevennet", "model_path": SEVENNET_PATH, "model": SEVENNET_PATH},
245266
{"arch": "sevennet", "model": SEVENNET_PATH, "path": SEVENNET_PATH},
267+
{"arch": "uma", "model_path": UMA_LABEL, "model": UMA_LABEL},
246268
],
247269
)
248270
def test_model_model_paths(kwargs):

0 commit comments

Comments
 (0)