Skip to content

Commit b47e13b

Browse files
fzyzcjy authored and finger92 committed
Support 2x8xH100 for Llama 4 (sgl-project#5159)
1 parent 0be7b5a commit b47e13b

File tree

1 file changed

+77
-19
lines changed

1 file changed

+77
-19
lines changed

python/sglang/srt/models/llama4.py

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@
2727
get_tensor_model_parallel_world_size,
2828
tensor_model_parallel_all_reduce,
2929
)
30+
from sglang.srt.layers.dp_attention import (
31+
dp_gather_partial,
32+
dp_scatter,
33+
get_attention_dp_size,
34+
get_attention_tp_rank,
35+
get_attention_tp_size,
36+
)
3037
from sglang.srt.layers.layernorm import RMSNorm
3138
from sglang.srt.layers.linear import (
3239
QKVParallelLinear,
@@ -38,6 +45,7 @@
3845
from sglang.srt.layers.radix_attention import RadixAttention
3946
from sglang.srt.layers.rotary_embedding import get_rope
4047
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
48+
from sglang.srt.managers.schedule_batch import global_server_args_dict
4149
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
4250
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
4351
from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
@@ -143,20 +151,24 @@ def __init__(
143151
self.hidden_size = hidden_size
144152
self.use_rope = int((layer_id + 1) % 4 != 0)
145153
self.use_qk_norm = config.use_qk_norm and self.use_rope
146-
tp_size = get_tensor_model_parallel_world_size()
154+
155+
self.dp_size = get_attention_dp_size()
156+
attn_tp_rank = get_attention_tp_rank()
157+
attn_tp_size = get_attention_tp_size()
158+
147159
self.total_num_heads = num_heads
148-
assert self.total_num_heads % tp_size == 0
149-
self.num_heads = self.total_num_heads // tp_size
160+
assert self.total_num_heads % attn_tp_size == 0
161+
self.num_heads = self.total_num_heads // attn_tp_size
150162
self.total_num_kv_heads = num_kv_heads
151-
if self.total_num_kv_heads >= tp_size:
163+
if self.total_num_kv_heads >= attn_tp_size:
152164
# Number of KV heads is greater than TP size, so we partition
153165
# the KV heads across multiple tensor parallel GPUs.
154-
assert self.total_num_kv_heads % tp_size == 0
166+
assert self.total_num_kv_heads % attn_tp_size == 0
155167
else:
156168
# Number of KV heads is less than TP size, so we replicate
157169
# the KV heads across multiple tensor parallel GPUs.
158-
assert tp_size % self.total_num_kv_heads == 0
159-
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
170+
assert attn_tp_size % self.total_num_kv_heads == 0
171+
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
160172
self.head_dim = config.head_dim
161173
self.q_size = self.num_heads * self.head_dim
162174
self.kv_size = self.num_kv_heads * self.head_dim
@@ -183,6 +195,8 @@ def __init__(
183195
bias=bias,
184196
quant_config=quant_config,
185197
prefix=add_prefix("qkv_proj", prefix),
198+
tp_rank=attn_tp_rank,
199+
tp_size=attn_tp_size,
186200
)
187201

188202
self.o_proj = RowParallelLinear(
@@ -191,6 +205,9 @@ def __init__(
191205
bias=bias_o_proj,
192206
quant_config=quant_config,
193207
prefix=add_prefix("o_proj", prefix),
208+
tp_rank=attn_tp_rank,
209+
tp_size=attn_tp_size,
210+
reduce_results=False,
194211
)
195212
is_neox_style = True
196213
is_gguf = quant_config and quant_config.get_name() == "gguf"
@@ -274,6 +291,9 @@ def __init__(
274291
rope_theta = config.rope_theta
275292
rope_scaling = config.rope_scaling
276293
max_position_embeddings = config.max_position_embeddings
294+
self.dp_size = get_attention_dp_size()
295+
self.attn_tp_size = get_attention_tp_size()
296+
self.attn_tp_rank = get_attention_tp_rank()
277297

278298
self.self_attn = Llama4Attention(
279299
config=config,
@@ -316,21 +336,58 @@ def forward(
316336
forward_batch: ForwardBatch,
317337
residual: Optional[torch.Tensor],
318338
) -> Tuple[torch.Tensor, torch.Tensor]:
319-
# Self Attention
320-
if residual is None:
339+
if hidden_states.shape[0] == 0:
321340
residual = hidden_states
322-
hidden_states = self.input_layernorm(hidden_states)
323341
else:
324-
hidden_states, residual = self.input_layernorm(hidden_states, residual)
325-
hidden_states = self.self_attn(
326-
positions=positions,
327-
hidden_states=hidden_states,
328-
forward_batch=forward_batch,
329-
)
342+
# Self Attention
343+
if residual is None:
344+
residual = hidden_states
345+
hidden_states = self.input_layernorm(hidden_states)
346+
else:
347+
hidden_states, residual = self.input_layernorm(hidden_states, residual)
348+
hidden_states = self.self_attn(
349+
positions=positions,
350+
hidden_states=hidden_states,
351+
forward_batch=forward_batch,
352+
)
353+
354+
# Gather
355+
if get_tensor_model_parallel_world_size() > 1:
356+
# all gather and all reduce
357+
if self.dp_size != 1:
358+
if self.attn_tp_rank == 0:
359+
hidden_states += residual
360+
hidden_states, local_hidden_states = (
361+
forward_batch.gathered_buffer,
362+
hidden_states,
363+
)
364+
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
365+
dp_scatter(residual, hidden_states, forward_batch)
366+
hidden_states = self.post_attention_layernorm(hidden_states)
367+
else:
368+
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
369+
hidden_states, residual = self.post_attention_layernorm(
370+
hidden_states, residual
371+
)
372+
else:
373+
hidden_states, residual = self.post_attention_layernorm(
374+
hidden_states, residual
375+
)
330376

331377
# Fully Connected
332-
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
333378
hidden_states = self.feed_forward(hidden_states)
379+
380+
# TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
381+
# Scatter
382+
if self.dp_size != 1:
383+
# important: forward batch.gathered_buffer is used both after scatter and after gather.
384+
# be careful about this!
385+
hidden_states, global_hidden_states = (
386+
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
387+
hidden_states,
388+
)
389+
dp_scatter(hidden_states, global_hidden_states, forward_batch)
390+
334391
return hidden_states, residual
335392

336393

@@ -350,6 +407,7 @@ def __init__(
350407
config.hidden_size,
351408
quant_config=quant_config,
352409
prefix=add_prefix("embed_tokens", prefix),
410+
enable_tp=not global_server_args_dict["enable_dp_attention"],
353411
)
354412
self.layers = make_layers(
355413
config.num_hidden_layers,
@@ -385,7 +443,8 @@ def forward(
385443
forward_batch,
386444
residual,
387445
)
388-
hidden_states, _ = self.norm(hidden_states, residual)
446+
if not forward_batch.forward_mode.is_idle():
447+
hidden_states, _ = self.norm(hidden_states, residual)
389448

390449
if len(aux_hidden_states) == 0:
391450
return hidden_states
@@ -394,7 +453,6 @@ def forward(
394453

395454

396455
class Llama4ForCausalLM(LlamaForCausalLM):
397-
398456
packed_modules_mapping = {
399457
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
400458
"gate_up_proj": ["gate_proj", "up_proj"],

0 commit comments

Comments (0)