Skip to content

Commit 9363e4b

Browse files
Add calculator option for dispersion correction (#573)
* Add option to add dispersion to all calculators * Add d3 optional extra * Test D3 dispersion * Add D3 to all extra * Add m3gnet D3 test * Add D3 tests to CI * Add test skip for missing torch_dftd * Add d3 to extras list * Add MLIP parameters to sum calculator * Move D3 import into function * Test mace_mp dispersion is consistent * Improve dispersion defaults * Modify arch saved when adding D3 * Test adding D3 calculator * Test mace_mp with dispersion kwargs * Add docs for dispersion * Change dispersion label * Update arch for dispersion in tests * Remove unnecessary dispersion check * Update torch-dftd dependency
1 parent eef755a commit 9363e4b

File tree

11 files changed

+240
-12
lines changed

11 files changed

+240
-12
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ jobs:
4242

4343
- name: Install updated e3nn dependencies
4444
run: |
45-
uv sync --extra mattersim --extra fairchem
45+
uv sync --extra mattersim --extra fairchem --extra d3
4646
uv pip install --reinstall pynvml
4747
uv pip install fairchem-core[torch-extras] --no-build-isolation
4848
@@ -55,7 +55,7 @@ jobs:
5555

5656
- name: Install dgl dependencies
5757
run: |
58-
uv sync --extra mace --extra m3gnet --extra alignn
58+
uv sync --extra mace --extra m3gnet --extra alignn --extra d3
5959
uv pip install --reinstall pynvml
6060
6161
- name: Run test suite for dgl dependencies

.github/workflows/mac.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ jobs:
4242

4343
- name: Install updated e3nn dependencies
4444
run: |
45-
uv sync --extra mattersim --extra fairchem
45+
uv sync --extra mattersim --extra fairchem --extra d3
4646
uv pip install --reinstall pynvml
4747
uv pip install "fairchem-core[torch-extras]" --no-build-isolation
4848
@@ -55,7 +55,7 @@ jobs:
5555

5656
- name: Install dgl dependencies
5757
run: |
58-
uv sync --extra mace --extra m3gnet --extra alignn
58+
uv sync --extra mace --extra m3gnet --extra alignn --extra d3
5959
uv pip install --reinstall pynvml
6060
6161
- name: Run test suite for dgl dependencies

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,4 +243,5 @@
243243
("py:class", "Progress"),
244244
("py:class", "ProgressBar"),
245245
("py:class", "typer.models.CallbackParam"),
246+
("py:class", "SumCalculator"),
246247
]

docs/source/user_guide/get_started.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ Currently supported MLIP ``extras`` are:
7171

7272
Additional features can also be enabled as ``extras``:
7373

74+
- ``d3``: `DFTD3 <https://github.com/pfnet-research/torch-dftd>`_
7475
- ``visualise``: `WEAS Widget <https://github.com/superstar54/weas-widget>`_
7576
- ``plumed``: `PLUMED <https://www.plumed.org>`_
7677

docs/source/user_guide/python.rst

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,60 @@ will return
9898
unless the same ``arch`` is chosen, in which case these values will also be overwritten.
9999

100100

101+
D3 Dispersion
102+
=============
103+
104+
A PyTorch implementation of DFTD2 and DFTD3, using the `TorchDFTD3Calculator <https://github.com/pfnet-research/torch-dftd>`_,
105+
can be used to add dispersion corrections to MLIP predictions.
106+
107+
The required Python pacakge is included with ``mace_mp``, but can also be installed as its own extra:
108+
109+
.. code-block:: bash
110+
111+
pip install janus-core[d3]
112+
113+
114+
Once installed, dispersion can be added through ``calc_kwargs`` through the ``dispersion`` keyword,
115+
with ``dispersion_kwargs`` used to pass any further keywords to the ``TorchDFTD3Calculator``:
116+
117+
.. code-block:: python
118+
119+
from ase import units
120+
121+
from janus_core.calculations.single_point import SinglePoint
122+
123+
single_point = SinglePoint(
124+
struct="tests/data/NaCl.cif",
125+
arch="mace_mp",
126+
model="tests/models/mace_mp_small.model",
127+
calc_kwargs={"dispersion": True, "dispersion_kwargs": {"cutoff": 95.0 * units.Bohr}}
128+
)
129+
130+
.. note::
131+
In most cases, defaults for ``dispersion_kwargs`` are those set within ``TorchDFTD3Calculator``,
132+
but in the case of ``mace_mp``, we mirror the corresponding defaults from the
133+
``mace.calculators.mace_mp`` function.
134+
135+
136+
The ``TorchDFTD3Calculator`` can also be added to any existing calculator if required:
137+
138+
.. note::
139+
Keyword arguments for ``TorchDFTD3Calculator`` should be passed directly here,
140+
as shown with ``cutoff``. This will not have access to ``mace_mp`` default values,
141+
so will always use defaults from ``TorchDFTD3Calculator``.
142+
143+
144+
.. code-block:: python
145+
146+
from ase import units
147+
148+
from janus_core.helpers.mlip_calculators import add_dispersion, choose_calculator
149+
150+
mace_calc = choose_calculator("mace_mp")
151+
calc = add_dispersion(mace_calc, device="cpu", cutoff=95 * units.Bohr)
152+
153+
154+
101155
Additional Calculators
102156
======================
103157

janus_core/helpers/mlip_calculators.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from pathlib import Path
1313
from typing import TYPE_CHECKING, Any, get_args
1414

15+
from ase import units
1516
from ase.calculators.mixing import SumCalculator
17+
from torch import get_default_dtype
1618

1719
from janus_core.helpers.janus_types import Architectures, Devices, PathLike
1820
from janus_core.helpers.utils import none_to_dict
@@ -84,10 +86,59 @@ def _set_no_weights_only_load():
8486
environ.setdefault("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1")
8587

8688

89+
def add_dispersion(
90+
calc: Calculator,
91+
device: Devices = "cpu",
92+
dtype: torch.dtype | None = None,
93+
**kwargs,
94+
) -> SumCalculator:
95+
"""
96+
Add D3 dispersion calculator to existing calculator.
97+
98+
Parameters
99+
----------
100+
calc
101+
Calculator to add D3 correction to.
102+
device
103+
Device to run calculator on. Default is "cpu".
104+
dtype
105+
Calculation precision. Default is current torch dtype.
106+
**kwargs
107+
Additional keyword arguments passed to `TorchDFTD3Calculator`.
108+
109+
Returns
110+
-------
111+
SumCalculator
112+
Configured calculator with D3 dispersion correction added.
113+
"""
114+
try:
115+
from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
116+
except ImportError as err:
117+
raise ImportError("Please install the d3 extra.") from err
118+
119+
dtype = dtype if dtype else get_default_dtype()
120+
121+
d3_calc = TorchDFTD3Calculator(
122+
device=device,
123+
dtype=dtype,
124+
**kwargs,
125+
)
126+
sum_calc = SumCalculator([calc, d3_calc])
127+
128+
# Copy calculator parameters to make more accessible
129+
sum_calc.parameters = calc.parameters
130+
if "arch" in sum_calc.parameters:
131+
sum_calc.parameters["arch"] = sum_calc.parameters["arch"] + "_d3"
132+
133+
return sum_calc
134+
135+
87136
def choose_calculator(
88137
arch: Architectures,
89138
device: Devices = "cpu",
90139
model: PathLike | None = None,
140+
dispersion: bool = False,
141+
dispersion_kwargs: dict[str, Any] | None = None,
91142
**kwargs,
92143
) -> Calculator:
93144
"""
@@ -101,6 +152,11 @@ def choose_calculator(
101152
Device to run calculator on. Default is "cpu".
102153
model
103154
MLIP model label, path to model, or loaded model. Default is `None`.
155+
dispersion
156+
Whether to add D3 dispersion.
157+
dispersion_kwargs
158+
Additional keyword arguments for `TorchDFTD3Calculator`. Defaults for mace_mp
159+
are taken from mace_mp's defaults.
104160
**kwargs
105161
Additional keyword arguments passed to the selected calculator.
106162
@@ -116,6 +172,8 @@ def choose_calculator(
116172
ValueError
117173
Invalid architecture specified.
118174
"""
175+
dispersion_kwargs = dispersion_kwargs if dispersion_kwargs else {}
176+
119177
model = _set_model(model, kwargs)
120178

121179
if device not in get_args(Devices):
@@ -147,6 +205,13 @@ def choose_calculator(
147205
model = model if model else "small"
148206
kwargs.setdefault("default_dtype", "float64")
149207

208+
# Set mace_mp dispersion defaults
209+
dispersion_kwargs.setdefault("damping", kwargs.pop("damping", "bj"))
210+
dispersion_kwargs.setdefault("xc", kwargs.pop("dispersion_xc", "pbe"))
211+
dispersion_kwargs.setdefault(
212+
"cutoff", kwargs.pop("dispersion_cutoff", 40.0 * units.Bohr)
213+
)
214+
150215
calculator = mace_mp(model=model, device=device, **kwargs)
151216

152217
case "mace_off":
@@ -420,6 +485,9 @@ def choose_calculator(
420485
calculator.parameters["arch"] = arch
421486
calculator.parameters["model"] = str(model)
422487

488+
if dispersion:
489+
return add_dispersion(calc=calculator, device=device, **dispersion_kwargs)
490+
423491
return calculator
424492

425493

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,15 @@ chgnet = [
4646
dpa3 = [
4747
"deepmd-kit == 3.1.0",
4848
]
49+
d3 = [
50+
"torch-dftd==0.5.1",
51+
]
4952
grace = [
5053
"tensorpotential == 0.5.1",
5154
]
5255
mace = [
5356
"mace-torch==0.3.13",
54-
"torch-dftd==0.4.0",
57+
"janus-core[d3]",
5558
]
5659
nequip = [
5760
"nequip == 0.6.1",
@@ -73,8 +76,8 @@ visualise = [
7376
]
7477
all = [
7578
"janus-core[chgnet]",
76-
"janus-core[dpa3]",
7779
"janus-core[grace]",
80+
"janus-core[d3]",
7881
"janus-core[mace]",
7982
"janus-core[nequip]",
8083
"janus-core[orb]",

tests/test_descriptors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_dispersion():
119119
descriptors_disp.run()
120120

121121
assert (
122-
descriptors_disp.struct.info["mace_mp_descriptor"]
122+
descriptors_disp.struct.info["mace_mp_d3_descriptor"]
123123
== descriptors.struct.info["mace_mp_descriptor"]
124124
)
125125

tests/test_mlip_calculators.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import pytest
1010

11-
from janus_core.helpers.mlip_calculators import choose_calculator
11+
from janus_core.helpers.mlip_calculators import add_dispersion, choose_calculator
1212
from tests.utils import skip_extras
1313

1414
MODEL_PATH = Path(__file__).parent / "models"
@@ -279,3 +279,24 @@ def test_invalid_device(arch):
279279
"""Test error raised if invalid device is specified."""
280280
with pytest.raises(ValueError):
281281
choose_calculator(arch=arch, device="invalid")
282+
283+
284+
def test_d3():
285+
"""Test adding D3 dispersion calculator automatically."""
286+
skip_extras("mace_mp")
287+
288+
calculator = choose_calculator(arch="mace_mp", dispersion=True)
289+
assert calculator.parameters["version"] is not None
290+
assert calculator.parameters["model"] is not None
291+
assert calculator.parameters["arch"] == "mace_mp_d3"
292+
293+
294+
def test_d3_manual():
295+
"""Test adding D3 dispersion calculator manually."""
296+
skip_extras("mace_mp")
297+
298+
calculator = choose_calculator(arch="mace_mp")
299+
calculator = add_dispersion(calculator)
300+
assert calculator.parameters["version"] is not None
301+
assert calculator.parameters["model"] is not None
302+
assert calculator.parameters["arch"] == "mace_mp_d3"

tests/test_single_point.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import shutil
77
from urllib.error import HTTPError, URLError
88

9+
from ase import units
10+
from ase.calculators.mixing import SumCalculator
911
from ase.calculators.singlepoint import SinglePointCalculator
1012
from ase.io import read
1113
from numpy import isfinite
@@ -549,3 +551,81 @@ def test_missing_arch(struct):
549551

550552
with pytest.raises(ValueError, match="A calculator must be attached"):
551553
SinglePoint(struct=struct)
554+
555+
556+
@pytest.mark.parametrize(
557+
"arch, kwargs, pred",
558+
[
559+
("m3gnet", {}, -0.08281749),
560+
(
561+
"mace_mp",
562+
{"damping": "zero", "xc": "pbe", "cutoff": 95 * units.Bohr},
563+
-0.08281749,
564+
),
565+
("mace_off", {}, -0.08281747),
566+
("mattersim", {}, -0.08281749),
567+
("sevennet", {}, -0.08281749),
568+
],
569+
)
570+
def test_dispersion(arch, kwargs, pred):
571+
"""Test dispersion correction."""
572+
skip_extras(arch)
573+
pytest.importorskip("torch_dftd")
574+
575+
data_path = DATA_PATH / "benzene.xyz"
576+
sp_no_d3 = SinglePoint(
577+
struct=data_path,
578+
arch=arch,
579+
properties="energy",
580+
calc_kwargs={"dispersion": False},
581+
)
582+
assert not isinstance(sp_no_d3.struct.calc, SumCalculator)
583+
no_d3_results = sp_no_d3.run()
584+
585+
sp_d3 = SinglePoint(
586+
struct=data_path,
587+
arch=arch,
588+
properties="energy",
589+
calc_kwargs={"dispersion": True, "dispersion_kwargs": {**kwargs}},
590+
)
591+
assert isinstance(sp_d3.struct.calc, SumCalculator)
592+
d3_results = sp_d3.run()
593+
594+
assert (d3_results["energy"] - no_d3_results["energy"]) == pytest.approx(pred)
595+
596+
597+
def test_mace_mp_dispersion():
598+
"""Test mace_mp dispersion correction matches default."""
599+
skip_extras("mace_mp")
600+
pytest.importorskip("torch_dftd")
601+
602+
from mace.calculators import mace_mp
603+
604+
data_path = DATA_PATH / "benzene.xyz"
605+
606+
no_d3_energy = SinglePoint(
607+
struct=data_path,
608+
arch="mace_mp",
609+
properties="energy",
610+
calc_kwargs={"dispersion": False},
611+
).run()["energy"]
612+
613+
d3_energy = SinglePoint(
614+
struct=data_path,
615+
arch="mace_mp",
616+
properties="energy",
617+
calc_kwargs={"dispersion": True},
618+
).run()["energy"]
619+
620+
struct = read(data_path)
621+
struct.calc = mace_mp(model="small", dispersion=True)
622+
623+
mace_d3_energy = SinglePoint(
624+
struct=struct,
625+
properties="energy",
626+
calc_kwargs={"dispersion": False},
627+
).run()["energy"]
628+
629+
# Different default to other architectures
630+
assert d3_energy - no_d3_energy == pytest.approx(-0.29815768)
631+
assert d3_energy == pytest.approx(mace_d3_energy)

0 commit comments

Comments
 (0)