
Commit 8524a13

Tutorial 2025 (deepmodeling#241)
* feat(nnsk): add basisref support and onsite energy display
  Introduce a new `basisref` parameter in the `to_json` method to handle uniform_noref mode for onsite energies. This allows for referencing a uniform basis set when calculating onsite energies. Additionally, add a `show_onsites` function to display onsite energies for a given basis set, improving debugging and analysis capabilities.
* docs: restructure installation instructions for clarity
  Reorganize the installation steps in the documentation to improve readability and logical flow. Separate the installation of `torch` and `torch-scatter` into distinct steps and ensure consistent formatting across both the quick start guide and README.
* refactor(nnsk): simplify uniform basis reference handling
  The changes streamline the handling of uniform basis references by consolidating the logic into a single block. This reduces redundancy and improves code maintainability by ensuring consistent behavior across different parameter types (hopping, overlap, onsite, and SOC).
* chore: add base polynomial model files for empirical sk baseline model
  Add base_poly2.pth and base_poly4.pth files to the emp_base directory to support empirical base modeling.
* chore: add base model files for DFTB polynomial models
  Add binary model files `base_poly2.pth` and `base_poly4.pth` to the DFTB neural network directory. These files are essential for initializing polynomial-based models in the DFTB framework.
* feat(config): add model parameter to update input template dynamically
  Introduce a new `model` parameter in the `get_full_config` function to dynamically update the input template based on the provided model. This change allows for more flexible configuration management by leveraging the `gen_inputs` utility function. Additionally, simplify optimizer configurations and adjust the `eout_weight` default value for consistency.
* feat(emp_sk): add command to generate empirical SK parameters
  Introduce a new command 'esk' to generate initial empirical SK parameters. This includes the addition of the `emp_sk.py` module, which handles the conversion of model parameters to empirical SK format. The command supports both 'poly2' and 'poly4' base models and saves the output in JSON format. Also, fix a minor bug in `gen_inputs.py` where `model_options` was incorrectly accessed as a method.
* refactor: update argument types and add help text for clarity
  - Modify argument types in `argcheck.py` to include `None` for `r_max`, `oer_max`, and `er_max`
  - Add help text for `basemodel` in `main.py` to clarify its options
  - Adjust `eout_weight` in `config_skenv.py` for better training balance
  - Introduce `atomic_radius` argument in `nnsk` for model flexibility
  - Add `uniform_noref` option to `onsite` method for additional configuration
  - Update `rs` argument type in `hopping` to accept `dict` for more complex scenarios
* fix(gen_inputs): handle device type and freeze overlap param in nnsk
  Ensure correct device type assignment by checking both string and torch.device instances. Additionally, freeze the overlap parameter in the nnsk model when overlap is detected to prevent unintended modifications.
* refactor(utils): change the vasp_kpath function parameter type to list[str]
  Change the type of the `pathstr` parameter of the vasp_kpath function from str to list[str] to improve the flexibility and readability of the code.
* feat(run): support band plot for given structure using empirical sk parameters and support generating band.json from structure files, add input validation and default Fermi energy setting
* fix(emp_sk): correct the wording of the basis-setting error message to give a clearer error prompt
  fix(band): add logging to confirm that the provided Fermi energy matches the estimated value
  fix(auto_band_config): fix a spelling error in the log message
* fix(test): correct the argument passing in the test cases to ensure invalid input is handled correctly
* feat(pyproject): add optional seekpath group and its dependencies
* fix(auto_band_config): use get_path_orig_cell to replace seekpath.get_path, since the get_path function works on the standard primitive unit cell.
* fix(run): correct spelling of 'poly2' and 'poly4' in init_model checks
* fix(emp_sk): ensure output directory exists before saving JSON model
* fix(run): streamline init_model handling for 'poly2' and 'poly4' cases
* fix(band): improve Fermi energy handling and update ylabel for clarity
* fix(band): update plotting behavior to close figure when not using GUI
* self.r_map to device
* rename examples/silicon to examples/silicon/tutorial_v2.1
* Rename silicon example files for tutorial v2.1
* Remove unused model files and update usage notebook
* example: add ABACUS and VASP raw data files for GaAs unit cell structure.
* feat: add example configurations and structures for base model
  This commit introduces new JSON configuration files and VASP structure files for silicon, GaAs, and hBN materials. These files are essential for setting up and testing the base model with different materials and basis sets.
* feat(example): add GaAs example
  1. use dftio to convert raw data (VASP and ABACUS) to the DeePTB format
  2. use the base model to train the GaAs model
* add raw data for GaAs_io_sk example
* docs: add configuration files for GaAs training example
  Add `gaas.json` and `band.json` files to the `examples/GaAs_io_sk/train` directory. These files define the basis set and band structure calculation parameters for the GaAs training example.
* feat(train): add reference checkpoints and input config for GaAs_io_sk
  This commit introduces new reference checkpoints (nnsk_tr1.pth, nnsk_tr2.pth) and a training input configuration file (input.json) for the GaAs_io_sk example. These files are essential for setting up and running the training process with the specified model and data options.
* feat: add tutorial_v2.2 configuration files for silicon
  Introduce new JSON configuration files for the silicon tutorial version 2.2. These files include input configurations, band structure settings, and training options to support the tutorial's execution and reproducibility.
* fix: update file paths in band.json and band_2.json
  Correct the reference file paths in band.json and band_2.json to point to the correct data directory.
* chore: update data paths in tutorial JSON files
  Modify the "root" field in input JSON files to point to "../data/" instead of "./data/" to ensure correct data directory referencing.
* feat: enhance format_common_options documentation with basis definition examples
* fix: add error handling for invalid basis in OrbitalMapper
* feat: add testing for auto_band_config.py and build empirical sk model
* fix: update test cases to handle missing seekpath and use predefined common options
* fix: add seekpath dependency to pyproject.toml
* fix: remove optional seekpath group from pyproject.toml
1 parent 817ff6c commit 8524a13

126 files changed: +152571 -119 lines changed


README.md

Lines changed: 14 additions & 16 deletions
@@ -66,42 +66,40 @@ Installing **DeePTB** is straightforward. We recommend using a virtual environme

 Highly recommended to install DeePTB from source to get the latest features and bug fixes.
 1. **Setup Python environment**:
-
 Using conda (recommended, python >=3.9, <=3.12 ), e.g.,
 ```bash
 conda create -n dptb_venv python=3.10
 conda activate dptb_venv
 ```
 or using venv (make sure python >=3.9,<=3.12)
+
 ```bash
 python -m venv dptb_venv
 source dptb_venv/bin/activate
+```

 2. **Clone DeePTB and Navigate to the root directory**:
 ```bash
 git clone https://github.com/deepmodeling/DeePTB.git
 cd DeePTB
 ```
-3. **Install `torch` and `torch-scatter`** (two ways):
-- **Recommended**: Install torch and torch-scatter using the following commands:

+3. **Install `torch`**:
+```bash
+pip install "torch>=2.0.0,<=2.5.0"
+```
+4. **Install `torch-scatter`** (two ways):
+- **Recommended**: Install torch and torch-scatter using the following commands:
 ```bash
 python docs/auto_install_torch_scatter.py
 ```
-
 - **Manual**: Install torch and torch-scatter manually:
-1. install torch:
-```bash
-pip install "torch>=2.0.0,<=2.5.0"
-```
-
-2. install torch-scatter:
-```bash
-pip install torch-scatter -f https://data.pyg.org/whl/torch-${version}+${CUDA}.html
-```
-where `${version}` is the version of torch, e.g., 2.5.0, and `${CUDA}` is the CUDA version, e.g., cpu, cu118, cu121, cu124. See [torch_scatter doc](https://github.com/rusty1s/pytorch_scatter) for more details.
-
-4. **Install DeePTB**:
+```bash
+pip install torch-scatter -f https://data.pyg.org/whl/torch-${version}+${CUDA}.html
+```
+where `${version}` is the version of torch, e.g., 2.5.0, and `${CUDA}` is the CUDA version, e.g., cpu, cu118, cu121, cu124. See [torch_scatter doc](https://github.com/rusty1s/pytorch_scatter) for more details.
+
+5. **Install DeePTB**:
 ```bash
 pip install .
 ```
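
The manual `torch-scatter` step in the README hunk above relies on filling two placeholders, `${version}` and `${CUDA}`, in the wheel index URL. As a minimal sketch of that substitution (the helper names `torch_scatter_index_url` and `torch_scatter_pip_command` are hypothetical and not part of DeePTB or this commit), the URL could be assembled like this:

```python
def torch_scatter_index_url(torch_version: str, cuda_tag: str) -> str:
    """Fill the ${version} and ${CUDA} placeholders of the wheel index URL."""
    return f"https://data.pyg.org/whl/torch-{torch_version}+{cuda_tag}.html"


def torch_scatter_pip_command(torch_version: str, cuda_tag: str) -> list:
    """Return the manual install command as an argument list (hypothetical helper)."""
    return ["pip", "install", "torch-scatter", "-f",
            torch_scatter_index_url(torch_version, cuda_tag)]


if __name__ == "__main__":
    # Example values taken from the README text: torch 2.5.0 on CPU.
    print(" ".join(torch_scatter_pip_command("2.5.0", "cpu")))
```

The sketch only builds the command string; it does not detect the installed torch version or CUDA build, which is what the recommended `docs/auto_install_torch_scatter.py` script is presumably for.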

docs/quick_start/easy_install.md

Lines changed: 8 additions & 12 deletions
@@ -29,32 +29,28 @@ Highly recommended to install DeePTB from source to get the latest features and
 ```bash
 python -m venv dptb_venv
 source dptb_venv/bin/activate
-
+```
 2. **Clone DeePTB and Navigate to the root directory**:
 ```bash
 git clone https://github.com/deepmodeling/DeePTB.git
 cd DeePTB
 ```
-3. **Install `torch` and `torch-scatter`** (two ways):
-- **Recommended**: Install torch and torch-scatter using the following commands:
-
+3. **Install `torch`**:
 ```bash
-python docs/auto_install_torch_scatter.py
+pip install "torch>=2.0.0,<=2.5.0"
 ```
-
-- **Manual**: Install torch and torch-scatter manually:
-1. install torch:
+4. **Install `torch-scatter`** (two ways):
+- **Recommended**: Install torch and torch-scatter using the following commands:
 ```bash
-pip install "torch>=2.0.0,<=2.5.0"
+python docs/auto_install_torch_scatter.py
 ```
-
-2. install torch-scatter:
+- **Manual**: Install torch and torch-scatter manually:
 ```bash
 pip install torch-scatter -f https://data.pyg.org/whl/torch-${version}+${CUDA}.html
 ```
 where `${version}` is the version of torch, e.g., 2.5.0, and `${CUDA}` is the CUDA version, e.g., cpu, cu118, cu121, cu124. See [torch_scatter doc](https://github.com/rusty1s/pytorch_scatter) for more details.

-4. **Install DeePTB**:
+5. **Install DeePTB**:
 ```bash
 pip install .
 ```

dptb/data/transforms.py

Lines changed: 3 additions & 1 deletion
@@ -489,7 +489,9 @@ def __init__(

             for ko in orbtype_count.keys():
                 orbtype_count[ko] = max(orbtype_count[ko])
-
+        else:
+            raise ValueError(f"Invalid basis {self.basis} found. now only support string or list basis.")
+
         self.orbtype_count = orbtype_count
         full_basis_norb = 0
         for ko in orbtype_count.keys():
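
The added `else` branch makes an unsupported basis type fail loudly instead of falling through silently. A simplified, hypothetical illustration of the same guard follows; `check_basis` is not a DeePTB function, and the real `OrbitalMapper.__init__` does more than this type check:

```python
def check_basis(basis: dict) -> None:
    """Raise ValueError unless every per-element basis entry is a string or a list.

    Hypothetical stand-alone version of the guard added in the hunk above.
    """
    for element, entry in basis.items():
        if not isinstance(entry, (str, list)):
            raise ValueError(
                f"Invalid basis {basis} found. now only support string or list basis."
            )


if __name__ == "__main__":
    check_basis({"C": ["s", "p", "d"]})  # accepted: list entry
    try:
        check_basis({"C": 3})            # rejected: neither string nor list
    except ValueError as err:
        print(err)
```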

dptb/entrypoints/config.py

Lines changed: 11 additions & 11 deletions
@@ -2,15 +2,13 @@
 import json
 from pathlib import Path
 import os
-from dptb.utils.config_sk import TrainFullConfigSK, TestFullConfigSK
-from dptb.utils.config_skenv import TrainFullConfigSKEnv, TestFullConfigSKEnv
-from dptb.utils.config_e3 import TrainFullConfigE3, TestFullConfigE3
 import logging
+from dptb.utils.gen_inputs import gen_inputs

 __all__ = ["get_full_config", "config"]
 log = logging.getLogger(__name__)

-def get_full_config(train, test, e3tb, sktb, sktbenv):
+def get_full_config(model, train, test, e3tb, sktb, sktbenv):
     """
     This function determines the appropriate full config based on the provided parameters.

@@ -31,16 +29,17 @@ def get_full_config(train, test, e3tb, sktb, sktbenv):
     name = ''
     if train:
         name += 'train'
+
         # Use train configs based on e3tb, sktb, sktbenv
         if e3tb:
             name += '_E3'
-            full_config = TrainFullConfigE3
+            full_config = gen_inputs(mode='e3', task='train', model=model)
         elif sktb:
             name += '_SK'
-            full_config = TrainFullConfigSK
+            full_config = gen_inputs(mode='sk', task='train', model=model)
         elif sktbenv:
             name += '_SKEnv'
-            full_config = TrainFullConfigSKEnv
+            full_config = gen_inputs(mode='skenv', task='train', model=model)
         else:
             logging.error("Unknown config type in training mode")
             raise ValueError("Unknown config type in training mode")
@@ -49,13 +48,13 @@ def get_full_config(train, test, e3tb, sktb, sktbenv):
         name += 'test'
         if e3tb:
             name += '_E3'
-            full_config = TestFullConfigE3
+            full_config = gen_inputs(mode='e3', task='test', model=model)
         elif sktb:
             name += '_SK'
-            full_config = TestFullConfigSK
+            full_config = gen_inputs(mode='sk', task='test', model=model)
         elif sktbenv:
             name += '_SKEnv'
-            full_config = TestFullConfigSKEnv
+            full_config = gen_inputs(mode='skenv', task='test', model=model)
         else:
             logging.error("Unknown config type in testing mode")
             raise ValueError("Unknown config type in testing mode")
@@ -72,6 +71,7 @@ def config(
         e3tb: bool = False,
         sktb: bool = False,
         sktbenv: bool = False,
+        model: str = None,
         log_level: int = logging.INFO,
         log_path: Optional[str] = None,
         **kwargs
@@ -115,7 +115,7 @@ def config(
         train = True

     # Error handling and logic moved to get_full_config
-    name, full_config = get_full_config(train, test, e3tb, sktb, sktbenv)
+    name, full_config = get_full_config(model, train, test, e3tb, sktb, sktbenv)
     # Ensure PATH ends with .json
     if not PATH.endswith(".json"):
         PATH = os.path.join(PATH, "input_templete.json")
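
The hunks above replace six static template objects with calls to `gen_inputs(mode=..., task=..., model=...)`, selected by the boolean flags. The following sketch reproduces only that selection logic so it can run without DeePTB installed. `gen_inputs_stub` is a stand-in (an assumption, not the real `dptb.utils.gen_inputs.gen_inputs`), and `get_full_config_sketch` is a hypothetical, condensed rewrite of the if/elif chain, not the committed function:

```python
def gen_inputs_stub(mode, task, model=None):
    """Stand-in for gen_inputs: just echoes the request as a dict."""
    return {"mode": mode, "task": task, "model": model}


# Name suffix used by the committed code for each mode.
_SUFFIX = {"e3": "_E3", "sk": "_SK", "skenv": "_SKEnv"}


def get_full_config_sketch(model, train, test, e3tb, sktb, sktbenv,
                           gen_inputs=gen_inputs_stub):
    """Pick (name, full_config) from the flags, mirroring the hunk above."""
    if train:
        task = "train"
    elif test:
        task = "test"
    else:
        # Assumption: the sketch rejects the case where neither flag is set.
        raise ValueError("Neither train nor test was requested")

    if e3tb:
        mode = "e3"
    elif sktb:
        mode = "sk"
    elif sktbenv:
        mode = "skenv"
    else:
        raise ValueError(f"Unknown config type in {task}ing mode")

    return task + _SUFFIX[mode], gen_inputs(mode=mode, task=task, model=model)


if __name__ == "__main__":
    name, cfg = get_full_config_sketch("model.pth", True, False, False, True, False)
    print(name, cfg)
```

Passing `gen_inputs` as a parameter keeps the sketch testable; in the commit itself the function is imported at module level and `model` comes from the new `-m/--model` option of the `config` subcommand.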

dptb/entrypoints/emp_sk.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+import torch
+import numpy as np
+from dptb.nn.build import build_model
+import json
+import logging
+from dptb.nn.sktb.onsiteDB import onsite_energy_database
+import re
+import os
+from dptb.utils.gen_inputs import gen_inputs
+import json
+log = logging.getLogger(__name__)
+
+def to_empsk(
+        INPUT,
+        output='./',
+        basemodel='poly2',
+        **kwargs):
+    """
+    Convert the model to empirical SK parameters.
+    """
+    if INPUT is None:
+        raise ValueError('INPUT is None.')
+    with open(INPUT, 'r') as f:
+        input = json.load(f)
+    common_options = input['common_options']
+    EmpSK(common_options, basemodel=basemodel).to_json(outdir=output)
+
+class EmpSK(object):
+    """
+    Empirical SK parameters.
+    """
+    def __init__(self, common_options, basemodel='poly2'):
+        """
+        Args:
+            common_options: common options for the model. especially contain the basis information.
+            basemodel: base model type for the empirical SK parameters either 'poly2' or 'poly4'.
+        """
+        self.common_options,self.basisref = self.format_common_options(common_options)
+        if basemodel == 'poly2':
+            model_ckpt = os.path.join(os.path.dirname(__file__), '..', 'nn', 'dftb', "base_poly2.pth")
+        elif basemodel == 'poly4':
+            model_ckpt = os.path.join(os.path.dirname(__file__), '..', 'nn', 'dftb', "base_poly4.pth")
+        else:
+            raise ValueError(f'basemodel {basemodel} is not supported.')
+
+        self.model = build_model(model_ckpt, common_options=common_options, no_check=True)
+
+    def to_json(self, outdir='./'):
+        """
+        Convert the model to json format.
+        """
+        # Check whether the output directory exists
+        if not os.path.exists(outdir):
+            os.makedirs(outdir, exist_ok=True)
+        json_dict = self.model.to_json(basisref=self.basisref)
+        with open(os.path.join(outdir,'sktb.json'), 'w') as f:
+            json.dump(json_dict, f, indent=4)
+
+        # save input template
+        # input_template = gen_inputs(model=self.model, task='train', mode=mode)
+
+        #with open(os.path.join(outdir,'input_template.json'), 'w') as f:
+        #    json.dump(input_template, f, indent=4)
+        log.info(f'Empirical SK parameters are saved in {os.path.join(outdir,"sktb.json")}')
+        log.info('If you want to further train the model, please use `dptb config` command to generate input template.')
+        return json_dict
+
+    def format_common_options(self, common_options):
+        """
+        Format the common options for the model. and construct the mapping between two kind of basis definition.
+        The two kind of basis definition are:
+            1. common_options = {'basis': {'C': ['s','p','d']}}
+            2. common_options = {'basis': {'C': ['2s','2p','d*']}}
+
+        Args:
+            common_options: common options for the model. especially contain the basis information.
+            e.g. common_options = {'basis': {'C': ['s','p','d']}} or common_options = {'basis': {'C': ['2s','2p','d*']}}
+
+        Returns:
+            common_options: common options for the model.
+            basisref: basis reference for the model.
+        """
+        # check basis in common_options
+        if 'basis' not in common_options:
+            raise ValueError('basis information is not given in common_options.')
+        # check basis type
+        assert isinstance(common_options['basis'], dict), 'basis information is not a dictionary.'
+        basis = common_options['basis']
+        sys_ele = "".join(list(basis.keys()))
+        log.info(f'Extracting empirical SK parameters for {sys_ele}')
+
+        use_basis_ref = False
+        basisref = {}
+        for ie in basis.keys():
+            basisref[ie] = {}
+            assert isinstance(basis[ie], list), f'basis information for {ie} is not a list.'
+            for ieorb in basis[ie]:
+                assert isinstance(ieorb, str), f'basis information for {ie} is not a string.'
+                if len(ieorb) == 1:
+                    assert use_basis_ref is False, 'Invalid basis setting: cannot mix s, p, d with ns, np, d*.'
+                    continue
+                else:
+                    use_basis_ref = True
+                    assert ieorb in onsite_energy_database[ie], f'basis information for {ie} is not in onsite_energy_database : {onsite_energy_database[ie].keys()}.'
+                    orbsymb = re.findall(r'[A-Za-z]', ieorb)[0]
+                    basisref[ie][orbsymb] = ieorb
+
+        if use_basis_ref:
+            std_basis = {}
+            for ie in basis.keys():
+                std_basis[ie] = []
+                for ieorb in basis[ie]:
+                    std_basis[ie].append(re.findall(r'[A-Za-z]', ieorb)[0])
+            common_options['basis'].update(std_basis)
+
+            return common_options, basisref
+        else:
+            return common_options, None
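
`format_common_options` maps a labeled basis such as `['2s','2p','d*']` onto the plain symbols `['s','p','d']` and records which labeled orbital each symbol came from. The sketch below isolates that mapping so it can be tried without DeePTB. `split_basis` is a hypothetical helper, it skips the `onsite_energy_database` lookup, and, unlike the committed loop, it rejects mixed definitions regardless of the order in which the orbitals appear:

```python
import re


def split_basis(basis: dict):
    """Return (standard_basis, basisref) for a per-element orbital list.

    basisref is None when the basis already uses plain symbols (s, p, d).
    """
    std_basis, basisref, labeled_flags = {}, {}, []
    for element, orbitals in basis.items():
        std_basis[element] = []
        basisref[element] = {}
        for orb in orbitals:
            # First letter in the label is the orbital symbol: '2s' -> 's', 'd*' -> 'd'.
            symbol = re.findall(r"[A-Za-z]", orb)[0]
            std_basis[element].append(symbol)
            is_labeled = len(orb) > 1
            labeled_flags.append(is_labeled)
            if is_labeled:
                basisref[element][symbol] = orb

    if any(labeled_flags) and not all(labeled_flags):
        raise ValueError("Invalid basis setting: cannot mix s, p, d with ns, np, d*.")
    if not any(labeled_flags):
        return std_basis, None
    return std_basis, basisref


if __name__ == "__main__":
    print(split_basis({"C": ["2s", "2p", "d*"]}))
    print(split_basis({"C": ["s", "p", "d"]}))
```

In the commit, the resulting `basisref` is what `EmpSK.to_json` forwards to `self.model.to_json(basisref=...)`.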

dptb/entrypoints/main.py

Lines changed: 38 additions & 1 deletion
@@ -13,6 +13,8 @@
 from dptb.utils.loggers import set_log_handles
 from dptb.utils.config_check import check_config_train
 from dptb.entrypoints.collectskf import skf2pth, skf2nnsk
+from dptb.entrypoints.emp_sk import to_empsk
+
 from dptb import __version__


@@ -86,6 +88,14 @@ def main_parser() -> argparse.ArgumentParser:
         default="./input_templete.json"
     )

+    parser_config.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        default=None,
+        help="load model to update input template."
+    )
+
     parser_config.add_argument(
         "-tr",
         "--train",
@@ -394,7 +404,31 @@ def main_parser() -> argparse.ArgumentParser:
         help="The output files in training.",
     )

-
+    parser_esk = subparsers.add_parser(
+        "esk",
+        parents=[parser_log],
+        help="Generate initial empirical SK parameters.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser_esk.add_argument(
+        "INPUT", help="the input parameter file in json or yaml format",
+        type=str,
+        default=None
+    )
+    parser_esk.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        default="./",
+        help="The output files in training."
+    )
+    parser_esk.add_argument(
+        "-m",
+        "--basemodel",
+        type=str,
+        default="poly2",
+        help="The base model type can be poly2 or poly4."
+    )
     return parser

 def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
@@ -458,3 +492,6 @@ def main():

     elif args.command == 'skf2nn':
         skf2nnsk(**dict_args)
+
+    elif args.command == 'esk':
+        to_empsk(**dict_args)
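
To see how the new `esk` subcommand parses its arguments, the parser definition from the hunk above can be rebuilt in isolation. This is a minimal sketch: `build_esk_parser` is a hypothetical name, the shared `parser_log` parent is omitted, and `dest="command"` is inferred from the `args.command` checks in `main()`:

```python
import argparse


def build_esk_parser() -> argparse.ArgumentParser:
    """Stand-alone copy of the `esk` subcommand definition (without parser_log)."""
    parser = argparse.ArgumentParser(prog="dptb")
    subparsers = parser.add_subparsers(dest="command")
    parser_esk = subparsers.add_parser(
        "esk",
        help="Generate initial empirical SK parameters.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser_esk.add_argument(
        "INPUT", type=str, help="the input parameter file in json or yaml format"
    )
    parser_esk.add_argument("-o", "--output", type=str, default="./",
                            help="The output files in training.")
    parser_esk.add_argument("-m", "--basemodel", type=str, default="poly2",
                            help="The base model type can be poly2 or poly4.")
    return parser


if __name__ == "__main__":
    # "input.json" and "out" are placeholder paths chosen for illustration.
    args = build_esk_parser().parse_args(["esk", "input.json", "-o", "out", "-m", "poly4"])
    # main() forwards the parsed namespace as keyword arguments: to_empsk(**dict_args)
    print(vars(args))
```

With these definitions, an invocation of the form `dptb esk input.json -o out -m poly4` would reach `to_empsk` with `INPUT`, `output` and `basemodel` set, while extra keys such as `command` are absorbed by its `**kwargs`.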
