Skip to content

Commit caa903d

Browse files
refactor(data preprocess): remove the cut off options from info.json (#200)
* refactor(data preprocess): remove the cut off options from info.json and collect the values from input.json * update LMDB info.json. not need anymore. * refactor(default_dataset): refactor the _TrajData for ase data. Previous the ase data will be transferred into text file and then loaded by the _TrajData. now i refactor the function. both text and ase data are treated equally. will works as a class funtion to initial the _TrajData class. * add print logo in main and format some of the logger.info * update argcheck collect_cutoffs. add new function with get_cutoffs_from_model_options . * Fix(get_cutoffs_from_model_options) : fix rcut in powerlaw and varTang96. For powerlaw and varTang96, the rs is not exactly the hard cutoff. so when extract the r_max for data. we have to use rs + 5 * w; but for other method just use rs. * update band post process. * update test * update test * update build and get_cutoffs_from_model_options to support the rmax to be dict. * refactor(build dataset): change build_dataset from function to a class instance and add from_model class function. note, compared to the previous build_dataset, this one is more flexible. previous build_dataset is a function. now i define a class DataBuilder and re-defined __call__ function. then build_dataset is an instance of DataBuilder class. so i can use build_dataset.from_model() to build dataset from model. at the same time the previous way to use build_dataset is still available. like build_dataset(...). * add checkcutoff in dataset builder. * update AtomicData_options to make it compatible with older versions * Update argcheck.py --------- Co-authored-by: Yinzhanghao Zhou <64253517+floatingCatty@users.noreply.github.com>
1 parent c5ca916 commit caa903d

37 files changed

+1513
-566
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ dptb/tests/**/*.pth
66
dptb/tests/**/*.npy
77
dptb/tests/**/*.traj
88
dptb/tests/**/out*/*
9+
dptb/tests/**/out*/*
10+
dptb/tests/**/*lmdb
11+
dptb/tests/**/*h5
912
examples/_*
1013
*.dat
1114
*log*

dptb/__main__.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,39 @@
1-
from dptb.entrypoints.main import main
1+
from dptb.entrypoints.main import main as entry_main
2+
import logging
3+
import pyfiglet
4+
from dptb import __version__
25

6+
logging.basicConfig(level=logging.INFO, format='%(message)s')
7+
log = logging.getLogger(__name__)
8+
9+
def print_logo():
10+
f = pyfiglet.Figlet(font='dos_rebel') # 您可以选择您喜欢的字体
11+
logo = f.renderText("DeePTB")
12+
log.info(" ")
13+
log.info(" ")
14+
log.info("#"*81)
15+
log.info("#" + " "*79 + "#")
16+
log.info("#" + " "*79 + "#")
17+
for line in logo.split('\n'):
18+
if line.strip(): # 避免记录空行
19+
log.info('# '+line+ ' #')
20+
log.info("#" + " "*79 + "#")
21+
version_info = f"Version: {__version__}"
22+
padding = (79 - len(version_info)) // 2
23+
nspace = 79-padding
24+
format_str = "#" + "{}"+"{:<"+f"{nspace}" + "}"+ "#"
25+
log.info(format_str.format(" "*padding, version_info))
26+
log.info("#" + " "*79 + "#")
27+
log.info("#"*81)
28+
log.info(" ")
29+
log.info(" ")
30+
def main() -> None:
31+
"""
32+
The main entry point for the dptb package.
33+
"""
34+
print_logo()
35+
entry_main()
336

437
if __name__ == '__main__':
5-
main()
38+
#print_logo()
39+
main()

dptb/data/AtomicData.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def from_points(
496496
def from_ase(
497497
cls,
498498
atoms,
499-
r_max,
499+
r_max: Union[float, int, dict],
500500
er_max: Optional[float] = None,
501501
oer_max: Optional[float] = None,
502502
key_mapping: Optional[Dict[str, str]] = {},

0 commit comments

Comments
 (0)