File tree Expand file tree Collapse file tree 4 files changed +28
-6
lines changed Expand file tree Collapse file tree 4 files changed +28
-6
lines changed Original file line number Diff line number Diff line change @@ -202,6 +202,12 @@ def main_parser() -> argparse.ArgumentParser:
202
202
default = None ,
203
203
help = "The pth ckpt to be transfered to json." ,
204
204
)
205
+ parser_pth2json .add_argument (
206
+ "-dels" ,
207
+ "--deleteoverlap" ,
208
+ help = "Transfer to no overlap version." ,
209
+ action = "store_true"
210
+ )
205
211
parser_pth2json .add_argument (
206
212
"-o" ,
207
213
"--outdir" ,
Original file line number Diff line number Diff line change @@ -38,6 +38,24 @@ def pth2json(
38
38
39
39
json_dict = nnsk .to_json ()
40
40
41
+ if kwargs .get ("deleteoverlap" , False ):
42
+ log .info ("The deleteoverlap option is set. The model will be converted to orthogonal version." )
43
+ # convert the model to orthogonal version
44
+ log .info ("Converting the model to orthogonal version." )
45
+ # remove the overlap key in the json_dict
46
+ if "overlap" in json_dict ['model_params' ]:
47
+ del json_dict ['model_params' ]["overlap" ]
48
+ json_dict ['common_options' ]['overlap' ] = False
49
+
50
+ if json_dict ['model_options' ]['nnsk' ].get ('freeze' ,False ) == ['overlap' ]:
51
+ json_dict ['model_options' ]['nnsk' ]['freeze' ] = False
52
+ elif isinstance (json_dict ['model_options' ]['nnsk' ].get ('freeze' ,False ), list ):
53
+ json_dict ['model_options' ]['nnsk' ]['freeze' ] = [x for x in json_dict ['model_options' ]['nnsk' ]['freeze' ] if x != 'overlap' ]
54
+
55
+ # turn off the push option in the ckpt json
56
+ if isinstance (json_dict ['model_options' ]['nnsk' ].get ('push' ,False ), dict ):
57
+ json_dict ['model_options' ]['nnsk' ]['push' ] = False
58
+
41
59
# dump the json file
42
60
json_file = Path (outdir ) / "ckpt.json"
43
61
with open (json_file , "w" ) as f :
Original file line number Diff line number Diff line change @@ -63,11 +63,6 @@ def iteration(self, **kwargs):
63
63
name = self .trainer .model .name + suffix
64
64
self .latest_quene .append (name )
65
65
66
- if len (self .latest_quene ) > max_ckpt :
67
- delete_name = self .latest_quene .pop (0 )
68
- delete_path = os .path .join (self .checkpoint_path , delete_name + ".pth" )
69
- os .remove (delete_path )
70
-
71
66
if len (self .latest_quene ) > max_ckpt :
72
67
delete_name = self .latest_quene .pop (0 )
73
68
delete_path = os .path .join (self .checkpoint_path , delete_name + ".pth" )
Original file line number Diff line number Diff line change @@ -1679,7 +1679,10 @@ def collect_cutoffs(jdata):
1679
1679
r_max , er_max , oer_max = get_cutoffs_from_model_options (model_options )
1680
1680
1681
1681
if model_options .get ("nnsk" , None ) is not None :
1682
- if model_options ["nnsk" ]["push" ]:
1682
+ if model_options ["nnsk" ]["push" ] and \
1683
+ abs (model_options ["nnsk" ]["push" ]['rs_thr' ]) + \
1684
+ abs (model_options ["nnsk" ]["push" ]['rc_thr' ]) + \
1685
+ abs (model_options ["nnsk" ]["push" ]['w_thr' ]) > 1e-8 :
1683
1686
assert jdata .get ("data_options" ,None ) is not None , "data_options should be provided in jdata for nnsk push"
1684
1687
assert jdata ['data_options' ].get ("r_max" ) is not None , "r_max should be provided in data_options for nnsk push"
1685
1688
log .info ('YOU ARE USING NNSK PUSH MODEL, r_max will be used from data_options. Be careful! check the value in data options and model options. r_max or rs/rc !' )
You can’t perform that action at this time.
0 commit comments