Skip to content

Commit c2fbf60

Browse files
authored
[GLM4.1V and GLM4.5V] Add vision transformer num_dummy_head support: max tp=4 -> max tp=8 (sgl-project#9059)
1 parent 98b44e9 commit c2fbf60

File tree

9 files changed

+150
-102
lines changed

9 files changed

+150
-102
lines changed

benchmark/mmmu/bench_hf.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,13 @@ def eval_mmmu(args):
141141
print(f"response: {response}")
142142
process_result(response, sample, answer_dict, out_samples)
143143

144-
args.output_path = f"{args.model_path}_val_hf.json"
144+
args.output_path = f"{args.model_path}_answer_hf.json"
145145
save_json(args.output_path, out_samples)
146-
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
146+
eval_result(
147+
model_answer_path=args.output_path,
148+
answer_dict=answer_dict,
149+
eval_output_path=f"{args.model_path}_val_hf.json",
150+
)
147151

148152

149153
if __name__ == "__main__":

benchmark/mmmu/bench_sglang.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,13 @@ async def eval_mmmu(args) -> None:
187187
print("Profiler stopped")
188188

189189
print(f"Benchmark time: {time.perf_counter() - start}")
190-
args.output_path = f"./val_sglang.json"
190+
args.output_path = "./answer_sglang.json"
191191
save_json(args.output_path, out_samples)
192-
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
192+
eval_result(
193+
model_answer_path=args.output_path,
194+
answer_dict=answer_dict,
195+
eval_output_path="./val_sglang.json",
196+
)
193197

194198

195199
def parse_args():

benchmark/mmmu/eval_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,9 @@ def process_result(response, sample, answer_dict, out_samples):
544544
}
545545

546546

547-
def eval_result(model_answer_path, answer_dict):
547+
def eval_result(model_answer_path, answer_dict, eval_output_path=None):
548+
if eval_output_path is None:
549+
eval_output_path = model_answer_path
548550
print("Evaluating...")
549551
output_dict = json.load(open(model_answer_path))
550552
# answer_dict = json.load(open(answer_path))
@@ -639,7 +641,7 @@ def eval_result(model_answer_path, answer_dict):
639641
"acc": overall_acc,
640642
}
641643
pprint.pprint(printable_results)
642-
out = model_answer_path
644+
out = eval_output_path
643645
with open(out, "w", encoding="utf-8") as outfile:
644646
json.dump(printable_results, outfile)
645647
print(f"eval out saved to {out}")
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Utility functions for vision attention layers."""
2+
3+
import torch
4+
5+
from sglang.srt.layers.dp_attention import get_attention_tp_size
6+
7+
8+
def update_vit_attn_dummy_heads_config(config):
9+
"""Update HF config to ensure vision attention num_attention_heads is divisible by tp_size"""
10+
tp_size = get_attention_tp_size()
11+
num_heads = getattr(
12+
config.vision_config,
13+
"num_heads",
14+
getattr(config.vision_config, "num_attention_heads", None),
15+
)
16+
head_dim = config.vision_config.hidden_size // num_heads
17+
num_dummy_heads = 0
18+
19+
if num_heads % tp_size != 0:
20+
num_dummy_heads = ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads
21+
22+
setattr(config.vision_config, "head_dim", head_dim)
23+
setattr(config.vision_config, "num_dummy_heads", num_dummy_heads)
24+
25+
26+
def pad_vit_attn_dummy_heads(config, name: str, loaded_weight: torch.Tensor):
27+
"""Pad attention qkv weights for dummy heads"""
28+
num_dummy_heads = config.vision_config.num_dummy_heads
29+
if num_dummy_heads == 0:
30+
return loaded_weight
31+
head_dim = config.vision_config.head_dim
32+
33+
if "attn.qkv_proj" in name:
34+
wq, wk, wv = loaded_weight.chunk(3, dim=0)
35+
if name.endswith(".weight"):
36+
dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
37+
elif name.endswith(".bias"):
38+
dummy_shape = [num_dummy_heads, head_dim]
39+
else:
40+
raise RuntimeError(f"Unsupported weight with name={name}")
41+
pad_func = lambda x: torch.cat(
42+
[x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
43+
).flatten(0, 1)
44+
wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
45+
loaded_weight = torch.cat([wq, wk, wv], dim=0)
46+
elif any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
47+
if name.endswith(".weight"):
48+
dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
49+
elif name.endswith(".bias"):
50+
dummy_shape = [num_dummy_heads, head_dim]
51+
else:
52+
raise RuntimeError(f"Unsupported weight with name={name}")
53+
padded_weight = loaded_weight.new_zeros(dummy_shape)
54+
loaded_weight = torch.cat(
55+
[loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
56+
).flatten(0, 1)
57+
elif "attn.proj.weight" in name:
58+
padded_weight = loaded_weight.new_zeros(
59+
loaded_weight.shape[0], head_dim * num_dummy_heads
60+
)
61+
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
62+
elif "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
63+
padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
64+
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
65+
return loaded_weight

python/sglang/srt/models/glm4v.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from sglang.srt.hf_transformers_utils import get_processor
1111
from sglang.srt.layers.activation import SiluAndMul
12+
from sglang.srt.layers.attention import vision_utils
1213
from sglang.srt.layers.layernorm import RMSNorm
1314
from sglang.srt.layers.linear import (
1415
ColumnParallelLinear,
@@ -91,6 +92,7 @@ def __init__(
9192
norm_layer=norm_layer,
9293
quant_config=quant_config,
9394
prefix=prefix,
95+
num_dummy_heads=config.num_dummy_heads,
9496
)
9597
self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
9698
self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -469,7 +471,7 @@ def __init__(
469471
nn.Module.__init__(self)
470472

471473
self.config = config
472-
474+
vision_utils.update_vit_attn_dummy_heads_config(self.config)
473475
self.model = Glm4Model(
474476
config,
475477
quant_config,
@@ -537,6 +539,51 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
537539
video_embeds = torch.split(video_embeds, split_sizes)
538540
return torch.cat(video_embeds)
539541

542+
def _update_hf_config(self):
543+
"""update hf config to ensure vision attention num_attention_heads is divisible by tp_size"""
544+
tp_size = get_attention_tp_size()
545+
num_heads = self.config.vision_config.num_heads
546+
head_dim = self.config.vision_config.hidden_size // num_heads
547+
num_dummy_heads = 0
548+
549+
if num_heads % tp_size != 0:
550+
num_dummy_heads = (
551+
(num_heads + tp_size - 1) // tp_size
552+
) * tp_size - num_heads
553+
554+
setattr(self.config.vision_config, "head_dim", head_dim)
555+
setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
556+
557+
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
558+
"""pad attn qkv weights for dummy heads"""
559+
num_dummy_heads = self.config.vision_config.num_dummy_heads
560+
if num_dummy_heads == 0:
561+
return loaded_weight
562+
head_dim = self.config.vision_config.head_dim
563+
564+
if "attn.qkv_proj" in name:
565+
wq, wk, wv = loaded_weight.chunk(3, dim=0)
566+
if name.endswith(".weight"):
567+
dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
568+
elif name.endswith(".bias"):
569+
dummy_shape = [num_dummy_heads, head_dim]
570+
else:
571+
raise RuntimeError(f"Unsupported weight with name={name}")
572+
pad_func = lambda x: torch.cat(
573+
[x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
574+
).flatten(0, 1)
575+
wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
576+
loaded_weight = torch.cat([wq, wk, wv], dim=0)
577+
elif "attn.proj.weight" in name:
578+
padded_weight = loaded_weight.new_zeros(
579+
loaded_weight.shape[0], head_dim * num_dummy_heads
580+
)
581+
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
582+
elif "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
583+
padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
584+
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
585+
return loaded_weight
586+
540587
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
541588
stacked_params_mapping = [
542589
# (param_name, shard_name, shard_id)
@@ -583,6 +630,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
583630
raise
584631

585632
weight_loader = getattr(param, "weight_loader", default_weight_loader)
633+
if "visual" in name:
634+
loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
635+
self.config, name, loaded_weight
636+
)
586637
weight_loader(param, loaded_weight)
587638

588639

python/sglang/srt/models/glm4v_moe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
get_tensor_model_parallel_world_size,
1212
)
1313
from sglang.srt.hf_transformers_utils import get_processor
14+
from sglang.srt.layers.attention import vision_utils
1415
from sglang.srt.layers.logits_processor import LogitsProcessor
1516
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
1617
from sglang.srt.layers.pooler import Pooler, PoolingType
@@ -40,6 +41,7 @@ def __init__(
4041

4142
config.moe_layer_freq = 1
4243
self.config = config
44+
vision_utils.update_vit_attn_dummy_heads_config(self.config)
4345
self.tp_size = get_tensor_model_parallel_world_size()
4446
self.quant_config = quant_config
4547
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
@@ -385,6 +387,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
385387
weight_loader = getattr(
386388
param, "weight_loader", default_weight_loader
387389
)
390+
if "visual" in name:
391+
loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
392+
self.config, name, loaded_weight
393+
)
388394
weight_loader(param, loaded_weight)
389395

390396

python/sglang/srt/models/interns1.py

Lines changed: 5 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import nn
55
from transformers import PretrainedConfig
66

7-
from sglang.srt.distributed import parallel_state
7+
from sglang.srt.layers.attention import vision_utils
88
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
99
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
1010
from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -35,7 +35,7 @@ def __init__(
3535
super().__init__()
3636
self.config = config
3737
self.quant_config = quant_config
38-
self._update_hf_config()
38+
vision_utils.update_vit_attn_dummy_heads_config(self.config)
3939
image_size = (
4040
getattr(config, "force_image_size", None) or config.vision_config.image_size
4141
)
@@ -87,21 +87,6 @@ def __init__(
8787
nn.Linear(llm_hidden_size, llm_hidden_size),
8888
)
8989

90-
def _update_hf_config(self):
91-
"""update hf config to support tp"""
92-
world_size = parallel_state.get_tensor_model_parallel_world_size()
93-
num_heads = self.config.vision_config.num_attention_heads
94-
head_dim = self.config.vision_config.hidden_size // num_heads
95-
num_dummy_heads = 0
96-
97-
if num_heads % world_size != 0:
98-
num_dummy_heads = (
99-
(num_heads + world_size) // world_size
100-
) * world_size - num_heads
101-
102-
setattr(self.config.vision_config, "head_dim", head_dim)
103-
setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
104-
10590
def pixel_shuffle(self, x, scale_factor=0.5):
10691
n, w, h, c = x.size()
10792
# N, W, H, C --> N, W, H * scale, C // scale
@@ -184,34 +169,6 @@ def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
184169

185170
return helper.pad_input_tokens(input_ids, mm_inputs)
186171

187-
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
188-
"""pad attn qkv weights for dummy heads"""
189-
num_dummy_heads = self.config.vision_config.num_dummy_heads
190-
if num_dummy_heads == 0:
191-
return loaded_weight
192-
head_dim = self.config.vision_config.head_dim
193-
194-
if any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
195-
if name.endswith(".weight"):
196-
dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
197-
elif name.endswith(".bias"):
198-
dummy_shape = [num_dummy_heads, head_dim]
199-
else:
200-
raise RuntimeError(f"Unsupported weight with name={name}")
201-
padded_weight = loaded_weight.new_zeros(dummy_shape)
202-
loaded_weight = torch.cat(
203-
[loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
204-
).flatten(0, 1)
205-
if "attn.proj.weight" in name:
206-
padded_weight = loaded_weight.new_zeros(
207-
loaded_weight.shape[0], head_dim * num_dummy_heads
208-
)
209-
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
210-
if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
211-
padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
212-
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
213-
return loaded_weight
214-
215172
def _mapping_interns1_name(self, name):
216173
names_map = {
217174
"lm_head.weight": "language_model.lm_head.weight",
@@ -270,7 +227,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
270227
continue
271228
name = self._mapping_interns1_name(name)
272229
if "vision_model" in name:
273-
loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
230+
loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
231+
self.config, name, loaded_weight
232+
)
274233

275234
for param_name, weight_name, shard_id in stacked_params_mapping:
276235
if weight_name not in name:

0 commit comments

Comments
 (0)