@@ -1452,34 +1452,91 @@ def normalize_lmdbsetinfo(data):
1452
1452
1453
1453
return data
1454
1454
1455
- def collect_cutoffs (jdata ):
1456
- # collect r_max infos from model options.
1455
def get_cutoffs_from_model_options(model_options):
    """
    Extract the cutoff values from the provided model options.

    Reads `r_max`, `er_max` and `oer_max` from the `embedding`, `nnsk` and
    `dftbsk` sections of `model_options`:

    - `embedding`: `r_max` is taken from `embedding["r_max"]` when present;
      otherwise `er_max` is taken from `embedding["rc"]`.
    - `nnsk`: `r_max` is taken from `nnsk["hopping"]["rs"]` and `oer_max`
      from `nnsk["onsite"]["rs"]`, each only when that value is not None.
    - `dftbsk` (only consulted when `nnsk` is absent): `r_max` is taken from
      `dftbsk["r_max"]`.

    Parameters:
        model_options (dict): Model configuration options. May contain the
            keys `embedding`, `nnsk` and `dftbsk`. When `nnsk` is present it
            must contain the `hopping` and `onsite` sub-dictionaries.

    Returns:
        tuple: `(r_max, er_max, oer_max)`. Any value that is not set by the
            given options is returned as None.

    Raises:
        ValueError: If `embedding` is given but contains neither `r_max`
            nor `rc`.
        AssertionError: If `r_max` was already set by `embedding` while an
            `nnsk` or `dftbsk` section is also given, or if none of
            `embedding`, `nnsk` and `dftbsk` is provided.

    Logs:
        An error message when `embedding` lacks both `r_max` and `rc`.
    """
    r_max, er_max, oer_max = None, None, None

    embedding = model_options.get("embedding", None)
    if embedding is not None:
        if embedding.get("r_max", None) is not None:
            r_max = embedding["r_max"]
        elif embedding.get("rc", None) is not None:
            er_max = embedding["rc"]
        else:
            log.error("r_max or rc should be provided in model_options for embedding!")
            raise ValueError("r_max or rc should be provided in model_options for embedding!")

    # The checks below raise AssertionError explicitly instead of using the
    # `assert` statement: `assert` is removed when Python runs with -O, which
    # would let a conflicting r_max be silently overwritten.
    if model_options.get("nnsk", None) is not None:
        if r_max is not None:
            raise AssertionError(
                "r_max should not be provided in outside the nnsk for training nnsk model."
            )

        nnsk = model_options["nnsk"]
        if nnsk["hopping"].get("rs", None) is not None:
            r_max = nnsk["hopping"]["rs"]

        if nnsk["onsite"].get("rs", None) is not None:
            oer_max = nnsk["onsite"]["rs"]

    elif model_options.get("dftbsk", None) is not None:
        if r_max is not None:
            raise AssertionError(
                "r_max should not be provided in outside the dftbsk for training dftbsk model."
            )
        r_max = model_options["dftbsk"]["r_max"]

    elif embedding is None:
        # Neither nnsk nor dftbsk: the model must be embedding-only
        # (env or E3), so the embedding section is mandatory.
        raise AssertionError(
            "embedding should be provided in model_options when neither nnsk nor dftbsk is used."
        )

    return r_max, er_max, oer_max
1505
+ def collect_cutoffs (jdata ):
1506
+ """
1507
+ Collect cutoff values from the provided JSON data.
1508
+
1509
+ This function extracts the cutoff values `r_max`, `er_max`, and `oer_max` from the `model_options`
1510
+ in the provided JSON data. If the `nnsk` push model is used, it ensures that the necessary
1511
+ cutoff values are provided in `data_options` and overrides the values from `model_options`
1512
+ accordingly.
1513
+
1514
+ Parameters:
1515
+ jdata (dict): A dictionary containing model and data options. It must include `model_options`
1516
+ and optionally `data_options` if `nnsk` push model is used.
1517
+
1518
+ Returns:
1519
+ dict: A dictionary containing the cutoff options with keys `r_max`, `er_max`, and `oer_max`.
1520
+
1521
+ Raises:
1522
+ AssertionError: If required keys are missing in `jdata` or if `r_max` is not provided when
1523
+ using the `nnsk` push model.
1524
+
1525
+ Logs:
1526
+ Various informational messages about the cutoff values and their sources.
1527
+ """
1528
+
1529
+ model_options = jdata ["model_options" ]
1530
+ r_max , er_max , oer_max = get_cutoffs_from_model_options (model_options )
1531
+
1532
+ if model_options .get ("nnsk" , None ) is not None :
1533
+ if model_options ["nnsk" ]["push" ]:
1534
+ assert jdata .get ("data_options" ,None ) is not None , "data_options should be provided in jdata for nnsk push"
1478
1535
assert jdata ['data_options' ].get ("r_max" ) is not None , "r_max should be provided in data_options for nnsk push"
1479
1536
log .info ('YOU ARE USING NNSK PUSH MODEL, r_max will be used from data_options. Be careful! check the value in data options and model options. r_max or rs/rc !' )
1480
1537
r_max = jdata ['data_options' ]['r_max' ]
1481
-
1482
- if jdata [ " model_options" ] ["nnsk" ]["onsite" ]["method" ] in ["strain" , "NRL" ]:
1538
+
1539
+ if model_options ["nnsk" ]["onsite" ]["method" ] in ["strain" , "NRL" ]:
1483
1540
assert jdata ['data_options' ].get ("oer_max" ) is not None , "oer_max should be provided in data_options for nnsk push with strain onsite mode"
1484
1541
log .info ('YOU ARE USING NNSK PUSH MODEL with `strain` onsite mode, oer_max will be used from data_options. Be careful! check the value in data options and model options. rs/rc !' )
1485
1542
oer_max = jdata ['data_options' ]['oer_max' ]
@@ -1489,16 +1546,7 @@ def collect_cutoffs(jdata):
1489
1546
else :
1490
1547
if jdata ['data_options' ].get ("r_max" ) is not None :
1491
1548
log .info ("When not nnsk/push. the cutoffs will take from the model options: r_max rs and rc values. this seting in data_options will be ignored." )
1492
-
1493
- elif jdata ["model_options" ].get ("dftbsk" , None ) is not None :
1494
- assert r_max is None , "r_max should not be provided in outside the dftbsk for training dftbsk model."
1495
- r_max = jdata ["model_options" ]["dftbsk" ]["r_max" ]
1496
-
1497
- else :
1498
- # not nnsk not dftbsk, must be only env or E3. the embedding should be provided.
1499
- assert jdata ["model_options" ].get ("embedding" ,None ) is not None
1500
1549
1501
-
1502
1550
assert r_max is not None
1503
1551
cutoff_options = ({"r_max" : r_max , "er_max" : er_max , "oer_max" : oer_max })
1504
1552
0 commit comments