Skip to content

Commit 6563dda

Browse files
committed
update argcheck collect_cutoffs. add new function with get_cutoffs_from_model_options .
1 parent 3305d33 commit 6563dda

File tree

1 file changed

+75
-27
lines changed

1 file changed

+75
-27
lines changed

dptb/utils/argcheck.py

Lines changed: 75 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,34 +1452,91 @@ def normalize_lmdbsetinfo(data):
14521452

14531453
return data
14541454

1455-
def collect_cutoffs(jdata):
1456-
# collect r_max infos from model options.
1455+
def get_cutoffs_from_model_options(model_options):
1456+
"""
1457+
Extract cutoff values from the provided model options.
1458+
1459+
This function retrieves the cutoff values `r_max`, `er_max`, and `oer_max` from the `model_options`
1460+
dictionary. It handles different model types such as `embedding`, `nnsk`, and `dftbsk`, ensuring
1461+
that the appropriate cutoff values are provided and valid.
1462+
1463+
Parameters:
1464+
model_options (dict): A dictionary containing model configuration options. It may include keys
1465+
like `embedding`, `nnsk`, and `dftbsk` with their respective cutoff values.
1466+
1467+
Returns:
1468+
tuple: A tuple containing the cutoff values (`r_max`, `er_max`, `oer_max`).
1469+
1470+
Raises:
1471+
ValueError: If neither `r_max` nor `rc` is provided in `model_options` for embedding.
1472+
AssertionError: If `r_max` is provided outside the `nnsk` or `dftbsk` context when those models are used.
1473+
1474+
Logs:
1475+
Error messages if required cutoff values are missing or incorrectly provided.
1476+
"""
14571477
r_max, er_max, oer_max = None, None, None
1458-
if jdata["model_options"].get("embedding",None) is not None:
1459-
if jdata["model_options"]["embedding"].get("r_max",None) is not None:
1460-
r_max = jdata["model_options"]["embedding"]["r_max"]
1461-
elif jdata["model_options"]["embedding"].get("rc",None) is not None:
1462-
er_max = jdata["model_options"]["embedding"]["rc"]
1478+
if model_options.get("embedding",None) is not None:
1479+
if model_options["embedding"].get("r_max",None) is not None:
1480+
r_max = model_options["embedding"]["r_max"]
1481+
elif model_options["embedding"].get("rc",None) is not None:
1482+
er_max = model_options["embedding"]["rc"]
14631483
else:
14641484
log.error("r_max or rc should be provided in model_options for embedding!")
14651485
raise ValueError("r_max or rc should be provided in model_options for embedding!")
1466-
1467-
if jdata["model_options"].get("nnsk", None) is not None:
1486+
1487+
if model_options.get("nnsk", None) is not None:
14681488
assert r_max is None, "r_max should not be provided in outside the nnsk for training nnsk model."
14691489

1470-
if jdata["model_options"]["nnsk"]["hopping"].get("rs",None) is not None:
1471-
r_max = jdata["model_options"]["nnsk"]["hopping"]["rs"]
1490+
if model_options["nnsk"]["hopping"].get("rs",None) is not None:
1491+
r_max = model_options["nnsk"]["hopping"]["rs"]
14721492

1473-
if jdata["model_options"]["nnsk"]["onsite"].get("rs",None) is not None:
1474-
oer_max = jdata["model_options"]["nnsk"]["onsite"]["rs"]
1475-
1476-
## for specific case: PUSH. r_max will be used from data_options.
1477-
if jdata["model_options"]["nnsk"]["push"]:
1493+
if model_options["nnsk"]["onsite"].get("rs",None) is not None:
1494+
oer_max = model_options["nnsk"]["onsite"]["rs"]
1495+
1496+
elif model_options.get("dftbsk", None) is not None:
1497+
assert r_max is None, "r_max should not be provided in outside the dftbsk for training dftbsk model."
1498+
r_max = model_options["dftbsk"]["r_max"]
1499+
1500+
else:
1501+
# not nnsk not dftbsk, must be only env or E3. the embedding should be provided.
1502+
assert model_options.get("embedding",None) is not None
1503+
1504+
return r_max, er_max, oer_max
1505+
def collect_cutoffs(jdata):
1506+
"""
1507+
Collect cutoff values from the provided JSON data.
1508+
1509+
This function extracts the cutoff values `r_max`, `er_max`, and `oer_max` from the `model_options`
1510+
in the provided JSON data. If the `nnsk` push model is used, it ensures that the necessary
1511+
cutoff values are provided in `data_options` and overrides the values from `model_options`
1512+
accordingly.
1513+
1514+
Parameters:
1515+
jdata (dict): A dictionary containing model and data options. It must include `model_options`
1516+
and optionally `data_options` if `nnsk` push model is used.
1517+
1518+
Returns:
1519+
dict: A dictionary containing the cutoff options with keys `r_max`, `er_max`, and `oer_max`.
1520+
1521+
Raises:
1522+
AssertionError: If required keys are missing in `jdata` or if `r_max` is not provided when
1523+
using the `nnsk` push model.
1524+
1525+
Logs:
1526+
Various informational messages about the cutoff values and their sources.
1527+
"""
1528+
1529+
model_options = jdata["model_options"]
1530+
r_max, er_max, oer_max = get_cutoffs_from_model_options(model_options)
1531+
1532+
if model_options.get("nnsk", None) is not None:
1533+
if model_options["nnsk"]["push"]:
1534+
assert jdata.get("data_options",None) is not None, "data_options should be provided in jdata for nnsk push"
14781535
assert jdata['data_options'].get("r_max") is not None, "r_max should be provided in data_options for nnsk push"
14791536
log.info('YOU ARE USING NNSK PUSH MODEL, r_max will be used from data_options. Be careful! check the value in data options and model options. r_max or rs/rc !')
14801537
r_max = jdata['data_options']['r_max']
1481-
1482-
if jdata["model_options"]["nnsk"]["onsite"]["method"] in ["strain", "NRL"]:
1538+
1539+
if model_options["nnsk"]["onsite"]["method"] in ["strain", "NRL"]:
14831540
assert jdata['data_options'].get("oer_max") is not None, "oer_max should be provided in data_options for nnsk push with strain onsite mode"
14841541
log.info('YOU ARE USING NNSK PUSH MODEL with `strain` onsite mode, oer_max will be used from data_options. Be careful! check the value in data options and model options. rs/rc !')
14851542
oer_max = jdata['data_options']['oer_max']
@@ -1489,16 +1546,7 @@ def collect_cutoffs(jdata):
14891546
else:
14901547
if jdata['data_options'].get("r_max") is not None:
14911548
log.info("When not nnsk/push. the cutoffs will take from the model options: r_max rs and rc values. this seting in data_options will be ignored.")
1492-
1493-
elif jdata["model_options"].get("dftbsk", None) is not None:
1494-
assert r_max is None, "r_max should not be provided in outside the dftbsk for training dftbsk model."
1495-
r_max = jdata["model_options"]["dftbsk"]["r_max"]
1496-
1497-
else:
1498-
# not nnsk not dftbsk, must be only env or E3. the embedding should be provided.
1499-
assert jdata["model_options"].get("embedding",None) is not None
15001549

1501-
15021550
assert r_max is not None
15031551
cutoff_options = ({"r_max": r_max, "er_max": er_max, "oer_max": oer_max})
15041552

0 commit comments

Comments
 (0)