
Commit aed2595

Support some VLM models
1 parent 5daf093 commit aed2595

9 files changed: +160 -13 lines changed

configs/quantization/Awq/awq_w4a16_fakequant_eval_general.yml

Lines changed: 0 additions & 1 deletion

@@ -34,4 +34,3 @@ quant:
 save:
     save_trans: False
     save_path: ./save
-    tokenizer_file_substring: ["token"]

configs/quantization/Awq/awq_w_only_mix_bits_1.yml

Lines changed: 0 additions & 1 deletion

@@ -43,4 +43,3 @@ quant:
 save:
     save_trans: False
     save_path: ./save
-    tokenizer_file_substring: ["token"]

configs/quantization/Awq/awq_w_only_mix_bits_2.yml

Lines changed: 0 additions & 1 deletion

@@ -46,4 +46,3 @@ quant:
 save:
     save_trans: False
     save_path: ./save
-    tokenizer_file_substring: ["token"]
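Note: dropping tokenizer_file_substring from these three AWQ configs is deliberate; copy_tokenizer in base_blockwise_quantization.py (below) now falls back to the broader default ['token', 'merges', 'vocab'], which covers everything the old explicit ["token"] setting did.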

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 6 additions & 5 deletions

@@ -365,7 +365,7 @@ def apply_shift(self, shifts, prev_op, layers):
     def scale_fc_fc(self, fc1, fc2, scales):
         scales = scales.to(fc1.weight.device)
         if fc1.out_features == fc2.in_features * 3:
-            num_heads = self.model.get_model_config().to_dict().get('n_head', None)
+            num_heads = self.model.get_num_attention_heads()
             fc1.weight.t_()
             org_shape = fc1.weight.shape
             fc1.weight.data = fc1.weight.data.reshape(org_shape[0] * num_heads, 3, -1)

@@ -798,7 +798,8 @@ def deploy(self, quant_format):

     @torch.no_grad()
     def copy_tokenizer(self, path):
-        for substring in self.config.save.get('tokenizer_file_substring', ['token']):
+        for substring in self.config.save.get('tokenizer_file_substring',
+                                              ['token', 'merges', 'vocab']):
             copy_files(self.config.model.path, path, substring)
         logger.info('copy tokenizer done --')

@@ -818,9 +819,9 @@ def save_model(self, path):
             return
         if self.online_rotate:
             self.contiguous_params()
-        if self.config.model.type == 'Llava':
-            self.model.llava_model.language_model = self.model.get_model()
-            self.model.llava_model.save_pretrained(path)
+        if self.config.model.type in ['Llava', 'InternVL2']:
+            self.model.vlm_model.language_model = self.model.get_model()
+            self.model.vlm_model.save_pretrained(path)
         logger.info('save model done --')
         self.copy_tokenizer(path)
         copy_files(self.config.model.path, path, 'preprocessor_config')
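The broader copy_tokenizer default matters for the newly added models: BPE tokenizers such as Qwen's ship vocab and merges files whose names contain no 'token' substring, so the old default would silently skip them. For reference, a minimal sketch of what substring-based copying might look like; this is a hypothetical stand-in, not the actual llmc.utils.copy_files:

import os
import shutil

def copy_files_by_substring(src_dir, dst_dir, substring):
    # Copy every regular file in src_dir whose name contains `substring`.
    os.makedirs(dst_dir, exist_ok=True)
    for name in os.listdir(src_dir):
        src = os.path.join(src_dir, name)
        if substring in name and os.path.isfile(src):
            shutil.copy2(src, os.path.join(dst_dir, name))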

llmc/models/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -2,14 +2,17 @@
 from .falcon import Falcon
 from .gemma2 import Gemma2
 from .internlm2 import InternLM2
+from .internvl2 import InternVL2
 from .llama import Llama
 from .llava import Llava
 from .minicpm import MiniCPM
 from .mistral import Mistral
 from .mixtral import Mixtral
 from .opt import Opt
 from .phi import Phi
+from .qwen import Qwen
 from .qwen2 import Qwen2
+from .qwenvl import QwenVL
 from .smollm import SmolLM
 from .stablelm import StableLm
 from .starcoder import Starcoder
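With these imports in place, the three new wrappers are registered under their class names. A minimal usage sketch, assuming MODEL_REGISTRY supports dict-style lookup by the config's model.type string (the checkpoint path is a placeholder):

import torch

from llmc.models import *  # noqa: F401,F403 (runs the @MODEL_REGISTRY registrations)
from llmc.utils.registry_factory import MODEL_REGISTRY

model = MODEL_REGISTRY['InternVL2']('/path/to/InternVL2-8B', torch.float16)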

llmc/models/internvl2.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+from loguru import logger
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from llmc.utils.registry_factory import MODEL_REGISTRY
+
+from .internlm2 import InternLM2
+
+
+@MODEL_REGISTRY
+class InternVL2(InternLM2):
+    def __init__(self, model_path, torch_dtype):
+        super().__init__(model_path, torch_dtype)
+
+    def build_model(self):
+        self.vlm_model_config = AutoConfig.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        if hasattr(self.vlm_model_config, 'use_cache'):
+            self.vlm_model_config.use_cache = False
+        logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
+        self.vlm_model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            config=self.vlm_model_config,
+            trust_remote_code=True,
+            torch_dtype=self.torch_dtype,
+            low_cpu_mem_usage=True,
+        )
+        self.model = self.vlm_model.language_model
+        self.model_config = self.vlm_model_config.llm_config
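Design note: only the language half of the checkpoint is handed to the quantizer. self.model points at vlm_model.language_model and self.model_config at llm_config, so the InternLM2 block and subset logic applies unchanged while the vision tower stays untouched. A hypothetical smoke test (placeholder path; build_model may already be invoked by BaseModel.__init__):

import torch

from llmc.models.internvl2 import InternVL2

wrapper = InternVL2('/path/to/InternVL2-2B', torch.float16)  # placeholder path
wrapper.build_model()  # call explicitly if the base class does not
assert wrapper.model is wrapper.vlm_model.language_model
assert wrapper.model_config is wrapper.vlm_model_config.llm_config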

llmc/models/llava.py

Lines changed: 7 additions & 5 deletions

@@ -20,14 +20,16 @@ def __init__(self, model_path, torch_dtype):
         super().__init__(model_path, torch_dtype)

     def build_model(self):
-        self.model_config = AutoConfig.from_pretrained(
+        self.vlm_model_config = AutoConfig.from_pretrained(
             self.model_path, trust_remote_code=True
         )
-        self.model_config.text_config.use_cache = False
-        self.llava_model = LlavaForConditionalGeneration.from_pretrained(
+        self.vlm_model_config.text_config.use_cache = False
+        logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
+        self.vlm_model = LlavaForConditionalGeneration.from_pretrained(
             self.model_path,
-            config=self.model_config,
+            config=self.vlm_model_config,
             torch_dtype=self.torch_dtype,
             low_cpu_mem_usage=True,
         )
-        self.model = self.llava_model.language_model
+        self.model = self.vlm_model.language_model
+        self.model_config = self.vlm_model_config.text_config
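Renaming llava_model to vlm_model aligns Llava with the InternVL2 wrapper, which is what allows save_model to handle both types in a single branch; setting self.model_config to the text_config at the end keeps downstream consumers (head, embedding, and attention-head lookups) reading the LLM sub-config rather than the composite VLM config.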

llmc/models/qwen.py

Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
+from llmc.utils.registry_factory import MODEL_REGISTRY
+
+from .base_model import BaseModel
+
+
+@MODEL_REGISTRY
+class Qwen(BaseModel):
+    def __init__(self, model_path, torch_dtype):
+        super().__init__(model_path, torch_dtype)
+
+    def find_blocks(self):
+        self.blocks = self.model.transformer.h
+
+    def find_embed_layers(self):
+        self.wte = self.model.transformer.wte
+        self.rotary_emb = self.model.transformer.rotary_emb
+
+    def find_block_name(self):
+        self.block_name_prefix = 'transformer.h'
+
+    def get_embed_layers(self):
+        return [self.wte, self.rotary_emb]
+
+    def get_head_layers(self):
+        return [self.model.lm_head]
+
+    def get_pre_head_layernorm_layers(self):
+        return [self.model.transformer.ln_f]
+
+    def get_layers_except_blocks(self):
+        return [self.wte,
+                self.rotary_emb,
+                self.model.transformer.ln_f,
+                self.model.lm_head]
+
+    def has_bias(self):
+        return False
+
+    def get_layernorms_in_block(self, block):
+        return {
+            'ln_1': block.ln_1,
+            'ln_2': block.ln_2,
+        }
+
+    def get_num_attention_heads(self):
+        return self.model_config.num_attention_heads
+
+    def get_subsets_in_block(self, block):
+        return [
+            {
+                'layers': {
+                    'attn.c_attn': block.attn.c_attn
+                },
+                'prev_op': [block.ln_1],
+                'input': ['attn.c_attn'],
+                'inspect': block.attn,
+                'has_kwargs': True,
+            },
+            {
+                'layers': {'attn.c_proj': block.attn.c_proj},
+                'prev_op': [block.attn.c_attn],
+                'input': ['attn.c_proj'],
+                'inspect': block.attn.c_proj,
+                'has_kwargs': False,
+            },
+            {
+                'layers': {
+                    'mlp.w1': block.mlp.w1,
+                    'mlp.w2': block.mlp.w2,
+                },
+                'prev_op': [block.ln_2],
+                'input': ['mlp.w1'],
+                'inspect': block.mlp,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+            {
+                'layers': {'mlp.c_proj': block.mlp.c_proj},
+                'prev_op': [block.mlp.w1],
+                'input': ['mlp.c_proj'],
+                'inspect': block.mlp.c_proj,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+        ]
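Each subset entry tells the blockwise engine which linear layers to quantize together, which upstream op (prev_op) a smoothing or AWQ scale can be folded into, which input activations to record, and which module to re-run (inspect) during calibration. A rough sketch of a consumer loop, greatly simplified from the real blockwise pass (hook registration and kwargs handling omitted):

def walk_subsets(model, block):
    # Illustrative only; mirrors the subset schema defined above.
    for subset in model.get_subsets_in_block(block):
        layers = subset['layers']    # name -> linear module to quantize
        prev_op = subset['prev_op']  # ops whose outputs feed `layers`
        inspect = subset['inspect']  # module to run when capturing inputs
        # An AWQ-style pass would search per-channel scales for `layers`
        # here and fold the inverse scale into `prev_op`.
        print(list(layers), [type(op).__name__ for op in prev_op],
              subset.get('is_mlp', False))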

llmc/models/qwenvl.py

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+from loguru import logger
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from llmc.utils.registry_factory import MODEL_REGISTRY
+
+from .qwen import Qwen
+
+
+@MODEL_REGISTRY
+class QwenVL(Qwen):
+    def __init__(self, model_path, torch_dtype):
+        super().__init__(model_path, torch_dtype)
+
+    def build_model(self):
+        self.vlm_model_config = AutoConfig.from_pretrained(
+            self.model_path, trust_remote_code=True
+        )
+        if hasattr(self.vlm_model_config, 'use_cache'):
+            self.vlm_model_config.use_cache = False
+        logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
+        self.vlm_model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            config=self.vlm_model_config,
+            trust_remote_code=True,
+            torch_dtype=self.torch_dtype,
+            low_cpu_mem_usage=True,
+        )
+        self.model = self.vlm_model
+        self.model_config = self.vlm_model_config
+        self.vision_model = self.vlm_model.transformer.visual
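Unlike LLaVA and InternVL2, Qwen-VL embeds its visual encoder inside the same transformer module as the LLM, so self.model is the whole vlm_model and the Qwen helpers above (transformer.h, transformer.wte, and so on) keep working unchanged; the vision tower is additionally exposed as self.vision_model for callers that need to skip it or handle it separately.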
