Skip to content

Commit 917f27e

Browse files
authored
Add pass-through arguments and allow setting parameters in PySCF wrapper (#44)
1 parent cf9d7a8 commit 917f27e

File tree

2 files changed

+125
-20
lines changed

2 files changed

+125
-20
lines changed

python/dftd3/pyscf.py

Lines changed: 79 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,13 @@
2121
"""
2222

2323
try:
24-
from pyscf import lib, gto
24+
from pyscf import gto, lib, mcscf, scf
25+
from pyscf.grad import rhf as rhf_grad
2526
except ModuleNotFoundError:
2627
raise ModuleNotFoundError("This submodule requires pyscf installed")
2728

2829
import numpy as np
29-
from typing import Tuple
30+
from typing import Dict, Optional, Tuple
3031

3132
from .interface import (
3233
DispersionModel,
@@ -66,6 +67,52 @@ class DFTD3Dispersion(lib.StreamObject):
6667
``"d3op"``
6768
Optimized power damping function
6869
70+
Custom parameters can be provided with the `param` dictionary.
71+
The `param` dict contains the damping parameters, at least s8, a1 and a2
72+
must be provided for rational damping, while s8 and rs6 are required in case
73+
of zero damping.
74+
75+
Parameters for (modified) rational damping are:
76+
77+
======================== =========== ============================================
78+
Tweakable parameter Default Description
79+
======================== =========== ============================================
80+
s6 1.0 Scaling of the dipole-dipole dispersion
81+
s8 None Scaling of the dipole-quadrupole dispersion
82+
s9 1.0 Scaling of the three-body dispersion energy
83+
a1 None Scaling of the critical radii
84+
a2 None Offset of the critical radii
85+
alp 14.0 Exponent of the zero damping (ATM only)
86+
======================== =========== ============================================
87+
88+
Parameters for (modified) zero damping are:
89+
90+
======================== =========== ===================================================
91+
Tweakable parameter Default Description
92+
======================== =========== ===================================================
93+
s6 1.0 Scaling of the dipole-dipole dispersion
94+
s8 None Scaling of the dipole-quadrupole dispersion
95+
s9 1.0 Scaling of the three-body dispersion energy
96+
rs6 None Scaling of the dipole-dipole damping
97+
rs8 1.0 Scaling of the dipole-quadrupole damping
98+
alp 14.0 Exponent of the zero damping
99+
bet None Offset for damping radius (modified zero damping)
100+
======================== =========== ===================================================
101+
102+
Parameters for optimized power damping are:
103+
104+
======================== =========== ============================================
105+
Tweakable parameter Default Description
106+
======================== =========== ============================================
107+
s6 1.0 Scaling of the dipole-dipole dispersion
108+
s8 None Scaling of the dipole-quadrupole dispersion
109+
s9 1.0 Scaling of the three-body dispersion energy
110+
a1 None Scaling of the critical radii
111+
a2 None Offset of the critical radii
112+
alp 14.0 Exponent of the zero damping (ATM only)
113+
bet None Power for the zero-damping component
114+
======================== =========== ============================================
115+
69116
The version of the damping can be changed after constructing the dispersion correction.
70117
With the `atm` boolean the three-body dispersion energy can be enabled, which is
71118
generally recommended.
@@ -107,14 +154,22 @@ class DFTD3Dispersion(lib.StreamObject):
107154
array(-0.00574289)
108155
"""
109156

110-
def __init__(self, mol, xc="hf", version="d3bj", atm=False):
157+
def __init__(
158+
self,
159+
mol: gto.Mole,
160+
xc: str = "hf",
161+
version: str = "d3bj",
162+
atm: bool = False,
163+
param: Optional[Dict[str, float]] = None,
164+
):
111165
self.mol = mol
112166
self.verbose = mol.verbose
113167
self.xc = xc
168+
self.param = param
114169
self.atm = atm
115170
self.version = version
116171

117-
def dump_flags(self, verbose=None):
172+
def dump_flags(self, verbose: Optional[bool] = None):
118173
"""
119174
Show options used for the DFT-D3 dispersion correction.
120175
"""
@@ -168,16 +223,19 @@ def kernel(self) -> Tuple[float, np.ndarray]:
168223
mol.atom_coords(),
169224
)
170225

171-
param = _damping_param[self.version](
172-
method=self.xc,
173-
atm=self.atm,
174-
)
226+
if self.param is not None:
227+
param = _damping_param[self.version](**self.param)
228+
else:
229+
param = _damping_param[self.version](
230+
method=self.xc,
231+
atm=self.atm,
232+
)
175233

176234
res = disp.get_dispersion(param=param, grad=True)
177235

178236
return res.get("energy"), res.get("gradient")
179237

180-
def reset(self, mol):
238+
def reset(self, mol: gto.Mole):
181239
"""Reset mol and clean up relevant attributes for scanner mode"""
182240
self.mol = mol
183241
return self
@@ -199,7 +257,7 @@ class _DFTD3Grad:
199257
pass
200258

201259

202-
def energy(mf):
260+
def energy(mf: scf.hf.SCF, **kwargs) -> scf.hf.SCF:
203261
"""
204262
Apply DFT-D3 corrections to SCF or MCSCF methods by returning an
205263
instance of a new class built from the original instances class.
@@ -208,8 +266,10 @@ def energy(mf):
208266
209267
Parameters
210268
----------
211-
mf
269+
mf: scf.hf.SCF
212270
The method to which DFT-D3 corrections will be applied.
271+
**kwargs
272+
Keyword arguments passed to the `DFTD3Dispersion` class.
213273
214274
Returns
215275
-------
@@ -237,17 +297,15 @@ def energy(mf):
237297
-110.93260361702605
238298
"""
239299

240-
from pyscf.scf import hf
241-
from pyscf.mcscf import casci
242-
243-
if not isinstance(mf, (hf.SCF, casci.CASCI)):
300+
if not isinstance(mf, (scf.hf.SCF, mcscf.casci.CASCI)):
244301
raise TypeError("mf must be an instance of SCF or CASCI")
245302

246303
with_dftd3 = DFTD3Dispersion(
247304
mf.mol,
248305
xc="hf"
249-
if isinstance(mf, casci.CASCI)
306+
if isinstance(mf, mcscf.casci.CASCI)
250307
else getattr(mf, "xc", "HF").upper().replace(" ", ""),
308+
**kwargs,
251309
)
252310

253311
if isinstance(mf, _DFTD3):
@@ -287,7 +345,7 @@ def nuc_grad_method(self):
287345
return DFTD3(mf, with_dftd3)
288346

289347

290-
def grad(scf_grad):
348+
def grad(scf_grad: rhf_grad.Gradients, **kwargs):
291349
"""
292350
Apply DFT-D3 corrections to SCF or MCSCF nuclear gradients methods
293351
by returning an instance of a new class built from the original class.
@@ -296,8 +354,10 @@ def grad(scf_grad):
296354
297355
Parameters
298356
----------
299-
mfgrad
357+
scf_grad: rhf_grad.Gradients
300358
The method to which DFT-D3 corrections will be applied.
359+
**kwargs
360+
Keyword arguments passed to the `DFTD3Dispersion` class.
301361
302362
Returns
303363
-------
@@ -330,14 +390,13 @@ def grad(scf_grad):
330390
5 H -0.0154527822 0.0229409425 -0.0215141991
331391
----------------------------------------------
332392
"""
333-
from pyscf.grad import rhf as rhf_grad
334393

335394
if not isinstance(scf_grad, rhf_grad.Gradients):
336395
raise TypeError("scf_grad must be an instance of Gradients")
337396

338397
# Ensure that the zeroth order results include DFTD3 corrections
339398
if not getattr(scf_grad.base, "with_dftd3", None):
340-
scf_grad.base = dftd3(scf_grad.base)
399+
scf_grad.base = energy(scf_grad.base, **kwargs)
341400

342401
class DFTD3Grad(_DFTD3Grad, scf_grad.__class__):
343402
def grad_nuc(self, mol=None, atmlst=None):

python/dftd3/test_pyscf.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,52 @@ def test_energy_r2scan_d3():
5656
assert d3.kernel()[0] == approx(-0.00578401192369041, abs=1.0e-7)
5757

5858

59+
@pytest.mark.skipif(pyscf is None, reason="requires pyscf")
60+
@pytest.mark.parametrize("atm", [True, False])
61+
def test_energy_bp_d3zero(atm):
62+
thr = 1e-9
63+
64+
mol = gto.M(
65+
atom=[
66+
("C", [-1.42754169820131, -1.50508961850828, -1.93430551124333]),
67+
("C", [+1.19860572924150, -1.66299114873979, -2.03189643761298]),
68+
("C", [+2.65876001301880, +0.37736955363609, -1.23426391650599]),
69+
("C", [+1.50963368042358, +2.57230374419743, -0.34128058818180]),
70+
("C", [-1.12092277855371, +2.71045691257517, -0.25246348639234]),
71+
("C", [-2.60071517756218, +0.67879949508239, -1.04550707592673]),
72+
("I", [-2.86169588073340, +5.99660765711210, +1.08394899986031]),
73+
("H", [+2.09930989272956, -3.36144811062374, -2.72237695164263]),
74+
("H", [+2.64405246349916, +4.15317840474646, +0.27856972788526]),
75+
("H", [+4.69864865613751, +0.26922271535391, -1.30274048619151]),
76+
("H", [-4.63786461351839, +0.79856258572808, -0.96906659938432]),
77+
("H", [-2.57447518692275, -3.08132039046931, -2.54875517521577]),
78+
("S", [-5.88211879210329, 11.88491819358157, +2.31866455902233]),
79+
("H", [-8.18022701418703, 10.95619984550779, +1.83940856333092]),
80+
("C", [-5.08172874482867, 12.66714386256482, -0.92419491629867]),
81+
("H", [-3.18311711399702, 13.44626574330220, -0.86977613647871]),
82+
("H", [-5.07177399637298, 10.99164969235585, -2.10739192258756]),
83+
("H", [-6.35955320518616, 14.08073002965080, -1.68204314084441]),
84+
],
85+
unit="bohr",
86+
)
87+
88+
d3 = disp.DFTD3Dispersion(
89+
mol,
90+
param={
91+
"s6": 1.0,
92+
"s8": 1.683,
93+
"rs6": 1.139,
94+
"rs8": 1.0,
95+
"alp": 14.0,
96+
"s9": 1.0 if atm else 0.0,
97+
},
98+
version="d3zero",
99+
)
100+
ref = -0.01410721853585842 if atm else -0.014100267345314462
101+
102+
assert approx(d3.kernel()[0], abs=thr) == ref
103+
104+
59105
@pytest.mark.skipif(pyscf is None, reason="requires pyscf")
60106
@pytest.mark.parametrize("xc", ["b3lyp", "b3lypg", "b3lyp5", "b3lyp3"])
61107
def test_energy_b3lyp_d3(xc: str):

0 commit comments

Comments
 (0)