@@ -32,12 +32,6 @@ def preprocess(self):
             self.model.get_embed_layers()[0].weight,
         ):
             logger.info('Tie weight! Copy embed_layer for head_layer!')
-            path = os.path.join(self.config.model.path, 'config.json')
-            with open(path, 'r') as f:
-                config = json.load(f)
-            config['tie_word_embeddings'] = False
-            with open(path, 'w') as f:
-                json.dump(config, f, indent=4)
             del self.model.get_head_layers()[0].weight
             w = self.model.get_embed_layers()[0].weight.clone()
             self.model.get_head_layers()[0].weight = nn.Parameter(w)
@@ -124,3 +118,14 @@ def subset_transform(self, block, subset):
                 prev_op[0], had_dim=self.head_dim, output=True
             )
             apply_exact_had_to_linear(layers[0], had_dim=-1, output=False)
+
+    @torch.no_grad()
+    def save_model(self, path):
+        super().save_model(path)
+        path = os.path.join(path, 'config.json')
+        with open(path, 'r') as f:
+            config = json.load(f)
+        if 'tie_word_embeddings' in config:
+            config['tie_word_embeddings'] = False
+        with open(path, 'w') as f:
+            json.dump(config, f, indent=4)
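Rationale for moving the config patch into save_model: if tie_word_embeddings stayed True in the exported config.json, transformers would re-tie lm_head to the input embeddings at load time, undoing the untying done in preprocess; patching at save time also leaves the original model's config.json untouched. A minimal post-export sanity check, as a sketch (the checkpoint path is a placeholder, not part of this change):

# Sketch of a post-export check (hypothetical path): with tie_word_embeddings=False
# in config.json, transformers keeps lm_head as a separate tensor instead of
# re-tying it to the input embeddings when the model is loaded.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('/path/to/exported_model')
print(model.config.tie_word_embeddings)      # expected: False
embed = model.get_input_embeddings().weight
head = model.get_output_embeddings().weight
print(embed.data_ptr() != head.data_ptr())   # expected: True (weights are untied)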