Skip to content

Commit 98b8ec2

Browse files
Merge pull request #37 from ModelTC/ttt
remove redundant code
2 parents a657809 + 6fb8773 commit 98b8ec2

File tree

1 file changed

+1
-10
lines changed

1 file changed

+1
-10
lines changed

llmc/compression/quantization/quarot.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,6 @@ def block_transform(self, block):
92 92
logger.info(f'block:{block}')
93 93
logger.info(f'End transform the {self.block_idx+1}-th block')
94 94

95-
def bake_mean_into_linear(self, linear):
96-
linear_dtype = linear.weight.dtype
97-
W_ = linear.weight.data.double()
98-
linear.weight.data = W_ - W_.mean(dim=-2, keepdim=True)
99-
linear.weight.data = linear.weight.data.to(linear_dtype)
100-
if linear.bias is not None:
101-
b_ = linear.bias.data.double()
102-
linear.bias.data = b_ - b_.mean()
103-
linear.bias.data = linear.bias.data.to(linear_dtype)
104 95

105 96
@torch.no_grad()
106 97
def subset_transform(self, block, subset):
@@ -117,7 +108,7 @@ def subset_transform(self, block, subset):
117 108
self.rotate_pre_layers(layers, self.Q)
118 109
else:
119 110
if self.config['model']['type'] in ['Opt', 'StableLm']:
120-
self.bake_mean_into_linear(layers[0])
111+
self.bake_mean_into_fc(layers[0])
121 112

122 113
if 'is_mlp' in subset and subset['is_mlp']:
123 114
self.rotate_post_layers(

0 commit comments

Comments (0)