
Commit 3ac8c4c

Merge devel into master (#2402)
2 parents 6cdc5bf + afa27d7

File tree

20 files changed, +351 -33 lines changed

.github/workflows/build_wheel.yml

Lines changed: 6 additions & 6 deletions
@@ -33,7 +33,7 @@ jobs:
           platform_id: manylinux_aarch64
           dp_variant: cpu
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
         with:
           submodules: true
       # https://github.com/pypa/setuptools_scm/issues/480

@@ -42,13 +42,13 @@ jobs:
         name: Setup QEMU
         if: matrix.platform_id == 'manylinux_aarch64'
       - name: Build wheels
-        uses: pypa/cibuildwheel@v2.11.3
+        uses: pypa/cibuildwheel@v2.12.1
         env:
           CIBW_BUILD_VERBOSITY: 1
           CIBW_ARCHS: all
           CIBW_BUILD: cp${{ matrix.python }}-${{ matrix.platform_id }}
           DP_VARIANT: ${{ matrix.dp_variant }}
-      - uses: actions/upload-artifact@v2
+      - uses: actions/upload-artifact@v3
         with:
           path: ./wheelhouse/*.whl
   build_sdist:

@@ -79,7 +79,7 @@ jobs:
         with:
           name: artifact
           path: dist
-      - uses: pypa/gh-action-pypi-publish@v4
+      - uses: pypa/gh-action-pypi-publish@release/v1
         with:
           user: __token__
           password: ${{ secrets.pypi_password }}

@@ -103,12 +103,12 @@ jobs:

       - name: Extract metadata (tags, labels) for Docker
         id: meta
-        uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38
+        uses: docker/metadata-action@507c2f2dc502c992ad446e3d7a5dfbe311567a96
         with:
           images: ghcr.io/deepmodeling/deepmd-kit

       - name: Build and push Docker image
-        uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc
+        uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671
         with:
           context: source/install/docker
           push: ${{ github.repository_owner == 'deepmodeling' && github.event_name == 'push' }}

.github/workflows/package_c.yml

Lines changed: 2 additions & 2 deletions
@@ -9,14 +9,14 @@ jobs:
     name: Build C library
     runs-on: ubuntu-22.04
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Package C library
        run: ./source/install/docker_package_c.sh
       - name: Test C library
        run: ./source/install/docker_test_package_c.sh
       # for download and debug
       - name: Upload artifact
-        uses: actions/upload-artifact@v2
+        uses: actions/upload-artifact@v3
         with:
           path: ./libdeepmd_c.tar.gz
       - name: Release

deepmd/descriptor/se_a.py

Lines changed: 2 additions & 2 deletions
@@ -96,8 +96,8 @@ class DescrptSeA(DescrptSe):
     .. math::
         (\mathcal{G}^i)_j = \mathcal{N}(s(r_{ji}))

-    :math:`\mathcal{G}^i_< \in \mathbb{R}^{N \times M_2}` takes first :math:`M_2`$` columns of
-    :math:`\mathcal{G}^i`$`. The equation of embedding network :math:`\mathcal{N}` can be found at
+    :math:`\mathcal{G}^i_< \in \mathbb{R}^{N \times M_2}` takes first :math:`M_2` columns of
+    :math:`\mathcal{G}^i`. The equation of embedding network :math:`\mathcal{N}` can be found at
     :meth:`deepmd.utils.network.embedding_net`.

     Parameters

deepmd/descriptor/se_a_mask.py

Lines changed: 2 additions & 2 deletions
@@ -70,8 +70,8 @@ class DescrptSeAMask(DescrptSeA):
     .. math::
         (\mathcal{G}^i)_j = \mathcal{N}(s(r_{ji}))

-    :math:`\mathcal{G}^i_< \in \mathbb{R}^{N \times M_2}` takes first :math:`M_2`$` columns of
-    :math:`\mathcal{G}^i`$`. The equation of embedding network :math:`\mathcal{N}` can be found at
+    :math:`\mathcal{G}^i_< \in \mathbb{R}^{N \times M_2}` takes first :math:`M_2` columns of
+    :math:`\mathcal{G}^i`. The equation of embedding network :math:`\mathcal{N}` can be found at
     :meth:`deepmd.utils.network.embedding_net`.
     Specially for descriptor se_a_mask is a concise implementation of se_a.
     The difference is that se_a_mask only considered a non-pbc system.

deepmd/descriptor/se_atten.py

Lines changed: 10 additions & 1 deletion
@@ -1,3 +1,4 @@
+import warnings
 from typing import (
     List,
     Optional,

@@ -67,6 +68,8 @@ class DescrptSeAtten(DescrptSeA):
     exclude_types : List[List[int]]
             The excluded pairs of types which have no interaction with each other.
             For example, `[[0, 1]]` means no interaction between type 0 and type 1.
+    set_davg_zero
+            Set the shift of embedding net input to zero.
     activation_function
             The activation function in the embedding net. Supported options are |ACTIVATION_FN|
     precision

@@ -97,6 +100,7 @@ def __init__(
         trainable: bool = True,
         seed: Optional[int] = None,
         type_one_side: bool = True,
+        set_davg_zero: bool = True,
         exclude_types: List[List[int]] = [],
         activation_function: str = "tanh",
         precision: str = "default",

@@ -107,6 +111,11 @@ def __init__(
         attn_mask: bool = False,
         multi_task: bool = False,
     ) -> None:
+        if not set_davg_zero:
+            warnings.warn(
+                "Set 'set_davg_zero' False in descriptor 'se_atten' "
+                "may cause unexpected incontinuity during model inference!"
+            )
         DescrptSeA.__init__(
             self,
             rcut,

@@ -119,7 +128,7 @@ def __init__(
             seed=seed,
             type_one_side=type_one_side,
             exclude_types=exclude_types,
-            set_davg_zero=True,
+            set_davg_zero=set_davg_zero,
             activation_function=activation_function,
             precision=precision,
             uniform_seed=uniform_seed,
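With this change, `set_davg_zero` becomes a user-facing option of the `se_atten` descriptor instead of being hard-coded to `True`. A minimal sketch of a descriptor section that toggles it, written as a Python dict (every key other than `set_davg_zero` carries illustrative placeholder values, not values taken from this commit):

```python
# Illustrative se_atten descriptor section; only "set_davg_zero" is the knob added here.
descriptor = {
    "type": "se_atten",
    "rcut": 6.0,             # placeholder cutoff
    "rcut_smth": 0.5,        # placeholder smoothing radius
    "sel": 120,              # placeholder neighbor size
    "neuron": [25, 50, 100],
    "attn_layer": 2,
    "set_davg_zero": True,   # default; False now emits a warning about possible discontinuities
}
```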

deepmd/fit/ener.py

Lines changed: 10 additions & 9 deletions
@@ -59,9 +59,9 @@ class EnerFitting(Fitting):
             \mathbf{y}=\mathcal{L}(\mathbf{x};\mathbf{w},\mathbf{b})=
                 \boldsymbol{\phi}(\mathbf{x}^T\mathbf{w}+\mathbf{b})

-    where :math:`\mathbf{x} \in \mathbb{R}^{N_1}`$` is the input vector and :math:`\mathbf{y} \in \mathbb{R}^{N_2}`
+    where :math:`\mathbf{x} \in \mathbb{R}^{N_1}` is the input vector and :math:`\mathbf{y} \in \mathbb{R}^{N_2}`
     is the output vector. :math:`\mathbf{w} \in \mathbb{R}^{N_1 \times N_2}` and
-    :math:`\mathbf{b} \in \mathbb{R}^{N_2}`$` are weights and biases, respectively,
+    :math:`\mathbf{b} \in \mathbb{R}^{N_2}` are weights and biases, respectively,
     both of which are trainable if `trainable[i]` is `True`. :math:`\boldsymbol{\phi}`
     is the activation function.

@@ -71,9 +71,9 @@ class EnerFitting(Fitting):
             \mathbf{y}=\mathcal{L}^{(n)}(\mathbf{x};\mathbf{w},\mathbf{b})=
                 \mathbf{x}^T\mathbf{w}+\mathbf{b}

-    where :math:`\mathbf{x} \in \mathbb{R}^{N_{n-1}}`$` is the input vector and :math:`\mathbf{y} \in \mathbb{R}`
+    where :math:`\mathbf{x} \in \mathbb{R}^{N_{n-1}}` is the input vector and :math:`\mathbf{y} \in \mathbb{R}`
     is the output scalar. :math:`\mathbf{w} \in \mathbb{R}^{N_{n-1}}` and
-    :math:`\mathbf{b} \in \mathbb{R}`$` are weights and bias, respectively,
+    :math:`\mathbf{b} \in \mathbb{R}` are weights and bias, respectively,
     both of which are trainable if `trainable[n]` is `True`.

     Parameters

@@ -549,13 +549,14 @@ def build(
             aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])

         atype_nall = tf.reshape(atype, [-1, natoms[1]])
-        atype_filter = tf.cast(atype_nall >= 0, GLOBAL_TF_FLOAT_PRECISION)
+        self.atype_nloc = tf.slice(
+            atype_nall, [0, 0], [-1, natoms[0]]
+        )  ## lammps will make error
+        atype_filter = tf.cast(self.atype_nloc >= 0, GLOBAL_TF_FLOAT_PRECISION)
+        self.atype_nloc = tf.reshape(self.atype_nloc, [-1])
         # prevent embedding_lookup error,
         # but the filter will be applied anyway
-        atype_nall = tf.clip_by_value(atype_nall, 0, self.ntypes - 1)
-        self.atype_nloc = tf.reshape(
-            tf.slice(atype_nall, [0, 0], [-1, natoms[0]]), [-1]
-        )  ## lammps will make error
+        self.atype_nloc = tf.clip_by_value(self.atype_nloc, 0, self.ntypes - 1)
         if type_embedding is not None:
             atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc)
         else:
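The `build` hunk above reorders the local-atom bookkeeping: the type tensor is sliced down to the local atoms first, the virtual-atom filter is built from the unclipped values, and only then are the types clipped so the later `embedding_lookup` stays in range. A standalone NumPy sketch of that order of operations (shapes and values are made up; the real code works on TensorFlow tensors):

```python
import numpy as np

# 2 frames; 4 local atoms (nloc) plus 2 ghost atoms give nall = 6 (made-up values).
nloc, ntypes = 4, 2
atype_nall = np.array(
    [[0, 1, 1, -1, 0, 1],   # -1 marks a virtual/padded atom
     [0, 0, 1, 1, 1, 0]]
)

atype_nloc = atype_nall[:, :nloc]               # slice down to the local atoms first
atype_filter = (atype_nloc >= 0).astype(float)  # build the mask before clipping
atype_nloc = np.clip(atype_nloc.reshape(-1), 0, ntypes - 1)  # clip so lookups stay in range;
                                                             # the filter drops those entries anyway
```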

deepmd/loggers/loggers.py

Lines changed: 2 additions & 0 deletions
@@ -229,6 +229,8 @@ def set_log_handles(

     ch.setLevel(level)
     ch.addFilter(_AppFilter())
+    # clean old handlers before adding new one
+    root_log.handlers.clear()
     root_log.addHandler(ch)

     # * add file handler ***************************************************************
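Clearing `root_log.handlers` before attaching the console handler keeps repeated calls to `set_log_handles` from stacking duplicate handlers and emitting every record several times. A self-contained sketch of the pattern with the standard `logging` module (not deepmd's own function):

```python
import logging

root_log = logging.getLogger()

def setup_console_logging(level: int = logging.INFO) -> None:
    ch = logging.StreamHandler()
    ch.setLevel(level)
    root_log.handlers.clear()  # without this, every call adds one more handler
    root_log.addHandler(ch)

setup_console_logging()
setup_console_logging()  # e.g. re-initialized by a second entry point
root_log.warning("printed once, not twice")
```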

deepmd/utils/argcheck.py

Lines changed: 6 additions & 1 deletion
@@ -325,6 +325,7 @@ def descrpt_se_atten_args():
     doc_precision = f"The precision of the embedding net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision."
     doc_trainable = "If the parameters in the embedding net is trainable"
     doc_seed = "Random seed for parameter initialization"
+    doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `se_atten` descriptor or `atom_ener` in the energy fitting is used"
     doc_exclude_types = "The excluded pairs of types which have no interaction with each other. For example, `[[0, 1]]` means no interaction between type 0 and type 1."
     doc_attn = "The length of hidden vectors in attention layers"
     doc_attn_layer = "The number of attention layers"

@@ -361,6 +362,9 @@ def descrpt_se_atten_args():
         Argument(
             "exclude_types", list, optional=True, default=[], doc=doc_exclude_types
         ),
+        Argument(
+            "set_davg_zero", bool, optional=True, default=True, doc=doc_set_davg_zero
+        ),
         Argument("attn", int, optional=True, default=128, doc=doc_attn),
         Argument("attn_layer", int, optional=True, default=2, doc=doc_attn_layer),
         Argument("attn_dotr", bool, optional=True, default=True, doc=doc_attn_dotr),

@@ -972,7 +976,8 @@ def training_data_args():  # ! added by Ziyao: new specification style for data
     - list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\
     - int: all {link_sys} use the same batch size.\n\n\
     - string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
-    - string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.'
+    - string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
+    - string "mixed:N": the batch data will be sampled from all systems and merged into a mixed system with the batch size N. Only support the se_atten descriptor.'
     doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
     - "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
     - "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\

deepmd/utils/data.py

Lines changed: 9 additions & 1 deletion
@@ -91,6 +91,10 @@ def __init__(
             self.type_idx_map = np.array(
                 sorter[np.searchsorted(type_map, self.type_map, sorter=sorter)]
             )
+            # padding for virtual atom
+            self.type_idx_map = np.append(
+                self.type_idx_map, np.array([-1], dtype=np.int32)
+            )
             self.type_map = type_map
         if type_map is None and self.type_map is None and self.mixed_type:
             raise RuntimeError("mixed_type format must have type_map!")

@@ -489,8 +493,12 @@ def _load_set(self, set_name: DPPath):
                 [(real_type == i).sum(axis=-1) for i in range(self.get_ntypes())],
                 dtype=np.int32,
             ).T
+            ghost_nums = np.array(
+                [(real_type == -1).sum(axis=-1)],
+                dtype=np.int32,
+            ).T
             assert (
-                atom_type_nums.sum(axis=-1) == natoms
+                atom_type_nums.sum(axis=-1) + ghost_nums.sum(axis=-1) == natoms
             ).all(), "some types in 'real_atom_types.npy' of set {} are not contained in {} types!".format(
                 set_name, self.get_ntypes()
             )
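Together these two hunks let frames carry virtual atoms of type `-1`: the type-index map gets a trailing `-1` slot, and the sanity check in `_load_set` counts those ghosts alongside the real types. A small NumPy sketch of that accounting with made-up data:

```python
import numpy as np

# One mixed-type frame with 5 slots: 3 real atoms (types 0/1) and 2 virtual atoms (-1).
real_type = np.array([[0, 1, 0, -1, -1]])
ntypes = 2
natoms = real_type.shape[1]

atom_type_nums = np.array(
    [(real_type == i).sum(axis=-1) for i in range(ntypes)], dtype=np.int32
).T  # per-frame count of each real type -> [[2, 1]]
ghost_nums = np.array(
    [(real_type == -1).sum(axis=-1)], dtype=np.int32
).T  # per-frame count of virtual atoms -> [[2]]

# Real plus virtual atoms must account for every slot in the frame.
assert (atom_type_nums.sum(axis=-1) + ghost_nums.sum(axis=-1) == natoms).all()
```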

deepmd/utils/data_system.py

Lines changed: 104 additions & 2 deletions
@@ -112,6 +112,7 @@ def __init__(
         # batch size
         self.batch_size = batch_size
         is_auto_bs = False
+        self.mixed_systems = False
         if isinstance(self.batch_size, int):
             self.batch_size = self.batch_size * np.ones(self.nsystems, dtype=int)
         elif isinstance(self.batch_size, str):

@@ -121,9 +122,16 @@
                 rule = 32
                 if len(words) == 2:
                     rule = int(words[1])
+                self.batch_size = self._make_auto_bs(rule)
+            elif "mixed" == words[0]:
+                self.mixed_systems = True
+                if len(words) == 2:
+                    rule = int(words[1])
+                else:
+                    raise RuntimeError("batch size must be specified for mixed systems")
+                self.batch_size = rule * np.ones(self.nsystems, dtype=int)
             else:
                 raise RuntimeError("unknown batch_size rule " + words[0])
-            self.batch_size = self._make_auto_bs(rule)
         elif isinstance(self.batch_size, list):
             pass
         else:

@@ -361,7 +369,7 @@ def _get_sys_probs(self, sys_probs, auto_prob_style):  # depreciated
         prob = self._process_sys_probs(sys_probs)
         return prob

-    def get_batch(self, sys_idx: Optional[int] = None):
+    def get_batch(self, sys_idx: Optional[int] = None) -> dict:
         # batch generation style altered by Ziyao Li:
         # one should specify the "sys_prob" and "auto_prob_style" params
         # via set_sys_prob() function. The sys_probs this function uses is

@@ -375,9 +383,36 @@
             The index of system from which the batch is get.
             If sys_idx is not None, `sys_probs` and `auto_prob_style` are ignored
             If sys_idx is None, automatically determine the system according to `sys_probs` or `auto_prob_style`, see the following.
+            This option does not work for mixed systems.
+
+        Returns
+        -------
+        dict
+            The batch data
         """
         if not hasattr(self, "default_mesh"):
             self._make_default_mesh()
+        if not self.mixed_systems:
+            b_data = self.get_batch_standard(sys_idx)
+        else:
+            b_data = self.get_batch_mixed()
+        return b_data
+
+    def get_batch_standard(self, sys_idx: Optional[int] = None) -> dict:
+        """Get a batch of data from the data systems in the standard way.
+
+        Parameters
+        ----------
+        sys_idx : int
+            The index of system from which the batch is get.
+            If sys_idx is not None, `sys_probs` and `auto_prob_style` are ignored
+            If sys_idx is None, automatically determine the system according to `sys_probs` or `auto_prob_style`, see the following.
+
+        Returns
+        -------
+        dict
+            The batch data
+        """
         if sys_idx is not None:
             self.pick_idx = sys_idx
         else:

@@ -390,6 +425,73 @@
             b_data["default_mesh"] = self.default_mesh[self.pick_idx]
         return b_data

+    def get_batch_mixed(self) -> dict:
+        """Get a batch of data from the data systems in the mixed way.
+
+        Returns
+        -------
+        dict
+            The batch data
+        """
+        # mixed systems have a global batch size
+        batch_size = self.batch_size[0]
+        batch_data = []
+        for _ in range(batch_size):
+            self.pick_idx = dp_random.choice(np.arange(self.nsystems), p=self.sys_probs)
+            bb_data = self.data_systems[self.pick_idx].get_batch(1)
+            bb_data["natoms_vec"] = self.natoms_vec[self.pick_idx]
+            bb_data["default_mesh"] = self.default_mesh[self.pick_idx]
+            batch_data.append(bb_data)
+        b_data = self._merge_batch_data(batch_data)
+        return b_data
+
+    def _merge_batch_data(self, batch_data: List[dict]) -> dict:
+        """Merge batch data from different systems.
+
+        Parameters
+        ----------
+        batch_data : list of dict
+            A list of batch data from different systems.
+
+        Returns
+        -------
+        dict
+            The merged batch data.
+        """
+        b_data = {}
+        max_natoms = max(bb["natoms_vec"][0] for bb in batch_data)
+        # natoms_vec
+        natoms_vec = np.zeros(2 + self.get_ntypes(), dtype=int)
+        natoms_vec[0:3] = max_natoms
+        b_data["natoms_vec"] = natoms_vec
+        # real_natoms_vec
+        real_natoms_vec = np.vstack([bb["natoms_vec"] for bb in batch_data])
+        b_data["real_natoms_vec"] = real_natoms_vec
+        # type
+        type_vec = np.full((len(batch_data), max_natoms), -1, dtype=int)
+        for ii, bb in enumerate(batch_data):
+            type_vec[ii, : bb["type"].shape[1]] = bb["type"][0]
+        b_data["type"] = type_vec
+        # default_mesh
+        default_mesh = np.mean([bb["default_mesh"] for bb in batch_data], axis=0)
+        b_data["default_mesh"] = default_mesh
+        # other data
+        data_dict = self.get_data_dict(0)
+        for kk, vv in data_dict.items():
+            if kk not in batch_data[0]:
+                continue
+            b_data["find_" + kk] = batch_data[0]["find_" + kk]
+            if not vv["atomic"]:
+                b_data[kk] = np.concatenate([bb[kk] for bb in batch_data], axis=0)
+            else:
+                b_data[kk] = np.zeros(
+                    (len(batch_data), max_natoms * vv["ndof"] * vv["repeat"]),
+                    dtype=batch_data[0][kk].dtype,
+                )
+                for ii, bb in enumerate(batch_data):
+                    b_data[kk][ii, : bb[kk].shape[1]] = bb[kk][0]
+        return b_data
+
     # ! altered by Marián Rynik
     def get_test(self, sys_idx: Optional[int] = None, n_test: int = -1):  # depreciated
         """Get test data from the the data systems.
