Skip to content

Commit a0a1283

Browse files
DefaultRadii
1 parent e0b2986 commit a0a1283

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

theforce/calculator/active.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from theforce.regression.gppotential import PosteriorPotential, PosteriorPotentialFromFolder
33
from theforce.descriptor.atoms import TorchAtoms, AtomsData, LocalsData
44
from theforce.similarity.sesoap import SeSoapKernel
5-
from theforce.math.sesoap import SpecialRadii
5+
from theforce.math.sesoap import DefaultRadii
66
from theforce.util.tensors import padded
77
from theforce.util.util import date, timestamp
88
from theforce.io.sgprio import SgprIO
@@ -18,7 +18,7 @@
1818

1919

2020
def default_kernel(cutoff=6.):
21-
return SeSoapKernel(3, 3, 4, cutoff, radii=SpecialRadii({1: 0.5}))
21+
return SeSoapKernel(3, 3, 4, cutoff, radii=DefaultRadii())
2222

2323

2424
class FilterDeltas(Filter):
@@ -215,6 +215,7 @@ def __init__(self, covariance=None, calculator=None, process_group=None, meta=No
215215
self.stdout = True
216216
self.step = 0
217217
self.log('active calculator says Hello!', mode='w')
218+
self.log(f'kernel: {self.model.descriptors}')
218219
self.log_settings()
219220
self.log('model size: {} {}'.format(*self.size))
220221
self.pckl = pckl

theforce/cl/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from theforce.util.util import get_default_args
33
from theforce.util.parallel import mpi_init
44
from theforce.calculator.socketcalc import SocketCalculator
5-
from theforce.calculator.active import ActiveCalculator, kcal_mol, inf
5+
from theforce.calculator.active import ActiveCalculator, kcal_mol, inf, SeSoapKernel, DefaultRadii
66
import os
77

88

theforce/math/sesoap.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,25 @@ def state_args(self):
8383
return f'{dct}, {self.others}'
8484

8585

86+
class DefaultRadii(Radii):
87+
88+
def __init__(self, default=1., special={1: 0.5}):
89+
self.default = default
90+
self.special = special
91+
92+
def get(self, number):
93+
try:
94+
return self.special[number]
95+
except KeyError:
96+
return self.default
97+
98+
@property
99+
def state_args(self):
100+
default = float(self.default)
101+
special = {z: float(r) for z, r in self.special.items()}
102+
return f'{default}, {special}'
103+
104+
86105
class SeSoap(Module):
87106

88107
def __init__(self, lmax, nmax, radial, radii=1., flatten=True, normalize=True):

0 commit comments

Comments
 (0)