Skip to content

Commit 6938996

Browse files
authored
Merge pull request #45 from ModelTC/save_fix
save_fix_json
2 parents 34db26c + 6e01ac2 commit 6938996

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

llmc/compression/quantization/quarot.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,6 @@ def preprocess(self):
3232
self.model.get_embed_layers()[0].weight,
3333
):
3434
logger.info('Tie weight! Copy embed_layer for head_layer!')
35-
path = os.path.join(self.config.model.path, 'config.json')
36-
with open(path, 'r') as f:
37-
config = json.load(f)
38-
config['tie_word_embeddings'] = False
39-
with open(path, 'w') as f:
40-
json.dump(config, f, indent=4)
4135
del self.model.get_head_layers()[0].weight
4236
w = self.model.get_embed_layers()[0].weight.clone()
4337
self.model.get_head_layers()[0].weight = nn.Parameter(w)
@@ -124,3 +118,14 @@ def subset_transform(self, block, subset):
124118
prev_op[0], had_dim=self.head_dim, output=True
125119
)
126120
apply_exact_had_to_linear(layers[0], had_dim=-1, output=False)
121+
122+
@torch.no_grad()
def save_model(self, path):
    """Save the model through the parent class, then rewrite the exported
    ``config.json`` so ``tie_word_embeddings`` is disabled.

    The preprocessing step copies the embedding weights into a separate
    head layer (it logs "Tie weight! Copy embed_layer for head_layer!"),
    so the saved config must no longer declare the weights as tied.

    Args:
        path: Directory the parent class saves the model into; it is
            expected to contain a ``config.json`` after the super call.
    """
    super().save_model(path)
    config_file = os.path.join(path, 'config.json')
    with open(config_file, 'r') as f:
        saved_config = json.load(f)
    # Only touch the flag when the key actually exists in the config.
    if 'tie_word_embeddings' in saved_config:
        saved_config['tie_word_embeddings'] = False
    with open(config_file, 'w') as f:
        json.dump(saved_config, f, indent=4)

0 commit comments

Comments
 (0)