Skip to content

Commit 40f71f6

Browse files
authored
feat: Add delete overlap option to pth2json and improve model saving logic (#268)
1 parent 5b97981 commit 40f71f6

File tree

4 files changed

+28
-6
lines changed

4 files changed

+28
-6
lines changed

dptb/entrypoints/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,12 @@ def main_parser() -> argparse.ArgumentParser:
202202
default=None,
203203
help="The pth ckpt to be transfered to json.",
204204
)
205+
parser_pth2json.add_argument(
206+
"-dels",
207+
"--deleteoverlap",
208+
help="Transfer to no overlap version.",
209+
action="store_true"
210+
)
205211
parser_pth2json.add_argument(
206212
"-o",
207213
"--outdir",

dptb/entrypoints/pth2json.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,24 @@ def pth2json(
3838

3939
json_dict = nnsk.to_json()
4040

41+
if kwargs.get("deleteoverlap", False):
42+
log.info("The deleteoverlap option is set. The model will be converted to orthogonal version.")
43+
# convert the model to orthogonal version
44+
log.info("Converting the model to orthogonal version.")
45+
# remove the overlap key in the json_dict
46+
if "overlap" in json_dict['model_params']:
47+
del json_dict['model_params']["overlap"]
48+
json_dict['common_options']['overlap'] = False
49+
50+
if json_dict['model_options']['nnsk'].get('freeze',False) == ['overlap']:
51+
json_dict['model_options']['nnsk']['freeze'] = False
52+
elif isinstance(json_dict['model_options']['nnsk'].get('freeze',False), list):
53+
json_dict['model_options']['nnsk']['freeze'] = [x for x in json_dict['model_options']['nnsk']['freeze'] if x != 'overlap']
54+
55+
# turn off the push option in the ckpt json
56+
if isinstance(json_dict['model_options']['nnsk'].get('push',False), dict):
57+
json_dict['model_options']['nnsk']['push'] = False
58+
4159
# dump the json file
4260
json_file = Path(outdir) / "ckpt.json"
4361
with open(json_file, "w") as f:

dptb/plugins/saver.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,6 @@ def iteration(self, **kwargs):
6363
name = self.trainer.model.name+suffix
6464
self.latest_quene.append(name)
6565

66-
if len(self.latest_quene) > max_ckpt:
67-
delete_name = self.latest_quene.pop(0)
68-
delete_path = os.path.join(self.checkpoint_path, delete_name+".pth")
69-
os.remove(delete_path)
70-
7166
if len(self.latest_quene) > max_ckpt:
7267
delete_name = self.latest_quene.pop(0)
7368
delete_path = os.path.join(self.checkpoint_path, delete_name+".pth")

dptb/utils/argcheck.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1679,7 +1679,10 @@ def collect_cutoffs(jdata):
16791679
r_max, er_max, oer_max = get_cutoffs_from_model_options(model_options)
16801680

16811681
if model_options.get("nnsk", None) is not None:
1682-
if model_options["nnsk"]["push"]:
1682+
if model_options["nnsk"]["push"] and \
1683+
abs(model_options["nnsk"]["push"]['rs_thr']) + \
1684+
abs(model_options["nnsk"]["push"]['rc_thr']) + \
1685+
abs(model_options["nnsk"]["push"]['w_thr']) > 1e-8:
16831686
assert jdata.get("data_options",None) is not None, "data_options should be provided in jdata for nnsk push"
16841687
assert jdata['data_options'].get("r_max") is not None, "r_max should be provided in data_options for nnsk push"
16851688
log.info('YOU ARE USING NNSK PUSH MODEL, r_max will be used from data_options. Be careful! check the value in data options and model options. r_max or rs/rc !')

0 commit comments

Comments
 (0)