27
27
get_tensor_model_parallel_world_size ,
28
28
tensor_model_parallel_all_reduce ,
29
29
)
30
+ from sglang .srt .layers .dp_attention import (
31
+ dp_gather_partial ,
32
+ dp_scatter ,
33
+ get_attention_dp_size ,
34
+ get_attention_tp_rank ,
35
+ get_attention_tp_size ,
36
+ )
30
37
from sglang .srt .layers .layernorm import RMSNorm
31
38
from sglang .srt .layers .linear import (
32
39
QKVParallelLinear ,
38
45
from sglang .srt .layers .radix_attention import RadixAttention
39
46
from sglang .srt .layers .rotary_embedding import get_rope
40
47
from sglang .srt .layers .vocab_parallel_embedding import VocabParallelEmbedding
48
+ from sglang .srt .managers .schedule_batch import global_server_args_dict
41
49
from sglang .srt .model_executor .forward_batch_info import ForwardBatch
42
50
from sglang .srt .models .llama import LlamaForCausalLM , LlamaMLP
43
51
from sglang .srt .utils import add_prefix , get_compiler_backend , make_layers
@@ -143,20 +151,24 @@ def __init__(
143
151
self .hidden_size = hidden_size
144
152
self .use_rope = int ((layer_id + 1 ) % 4 != 0 )
145
153
self .use_qk_norm = config .use_qk_norm and self .use_rope
146
- tp_size = get_tensor_model_parallel_world_size ()
154
+
155
+ self .dp_size = get_attention_dp_size ()
156
+ attn_tp_rank = get_attention_tp_rank ()
157
+ attn_tp_size = get_attention_tp_size ()
158
+
147
159
self .total_num_heads = num_heads
148
- assert self .total_num_heads % tp_size == 0
149
- self .num_heads = self .total_num_heads // tp_size
160
+ assert self .total_num_heads % attn_tp_size == 0
161
+ self .num_heads = self .total_num_heads // attn_tp_size
150
162
self .total_num_kv_heads = num_kv_heads
151
- if self .total_num_kv_heads >= tp_size :
163
+ if self .total_num_kv_heads >= attn_tp_size :
152
164
# Number of KV heads is greater than TP size, so we partition
153
165
# the KV heads across multiple tensor parallel GPUs.
154
- assert self .total_num_kv_heads % tp_size == 0
166
+ assert self .total_num_kv_heads % attn_tp_size == 0
155
167
else :
156
168
# Number of KV heads is less than TP size, so we replicate
157
169
# the KV heads across multiple tensor parallel GPUs.
158
- assert tp_size % self .total_num_kv_heads == 0
159
- self .num_kv_heads = max (1 , self .total_num_kv_heads // tp_size )
170
+ assert attn_tp_size % self .total_num_kv_heads == 0
171
+ self .num_kv_heads = max (1 , self .total_num_kv_heads // attn_tp_size )
160
172
self .head_dim = config .head_dim
161
173
self .q_size = self .num_heads * self .head_dim
162
174
self .kv_size = self .num_kv_heads * self .head_dim
@@ -183,6 +195,8 @@ def __init__(
183
195
bias = bias ,
184
196
quant_config = quant_config ,
185
197
prefix = add_prefix ("qkv_proj" , prefix ),
198
+ tp_rank = attn_tp_rank ,
199
+ tp_size = attn_tp_size ,
186
200
)
187
201
188
202
self .o_proj = RowParallelLinear (
@@ -191,6 +205,9 @@ def __init__(
191
205
bias = bias_o_proj ,
192
206
quant_config = quant_config ,
193
207
prefix = add_prefix ("o_proj" , prefix ),
208
+ tp_rank = attn_tp_rank ,
209
+ tp_size = attn_tp_size ,
210
+ reduce_results = False ,
194
211
)
195
212
is_neox_style = True
196
213
is_gguf = quant_config and quant_config .get_name () == "gguf"
@@ -274,6 +291,9 @@ def __init__(
274
291
rope_theta = config .rope_theta
275
292
rope_scaling = config .rope_scaling
276
293
max_position_embeddings = config .max_position_embeddings
294
+ self .dp_size = get_attention_dp_size ()
295
+ self .attn_tp_size = get_attention_tp_size ()
296
+ self .attn_tp_rank = get_attention_tp_rank ()
277
297
278
298
self .self_attn = Llama4Attention (
279
299
config = config ,
@@ -316,21 +336,58 @@ def forward(
316
336
forward_batch : ForwardBatch ,
317
337
residual : Optional [torch .Tensor ],
318
338
) -> Tuple [torch .Tensor , torch .Tensor ]:
319
- # Self Attention
320
- if residual is None :
339
+ if hidden_states .shape [0 ] == 0 :
321
340
residual = hidden_states
322
- hidden_states = self .input_layernorm (hidden_states )
323
341
else :
324
- hidden_states , residual = self .input_layernorm (hidden_states , residual )
325
- hidden_states = self .self_attn (
326
- positions = positions ,
327
- hidden_states = hidden_states ,
328
- forward_batch = forward_batch ,
329
- )
342
+ # Self Attention
343
+ if residual is None :
344
+ residual = hidden_states
345
+ hidden_states = self .input_layernorm (hidden_states )
346
+ else :
347
+ hidden_states , residual = self .input_layernorm (hidden_states , residual )
348
+ hidden_states = self .self_attn (
349
+ positions = positions ,
350
+ hidden_states = hidden_states ,
351
+ forward_batch = forward_batch ,
352
+ )
353
+
354
+ # Gather
355
+ if get_tensor_model_parallel_world_size () > 1 :
356
+ # all gather and all reduce
357
+ if self .dp_size != 1 :
358
+ if self .attn_tp_rank == 0 :
359
+ hidden_states += residual
360
+ hidden_states , local_hidden_states = (
361
+ forward_batch .gathered_buffer ,
362
+ hidden_states ,
363
+ )
364
+ dp_gather_partial (hidden_states , local_hidden_states , forward_batch )
365
+ dp_scatter (residual , hidden_states , forward_batch )
366
+ hidden_states = self .post_attention_layernorm (hidden_states )
367
+ else :
368
+ hidden_states = tensor_model_parallel_all_reduce (hidden_states )
369
+ hidden_states , residual = self .post_attention_layernorm (
370
+ hidden_states , residual
371
+ )
372
+ else :
373
+ hidden_states , residual = self .post_attention_layernorm (
374
+ hidden_states , residual
375
+ )
330
376
331
377
# Fully Connected
332
- hidden_states , residual = self .post_attention_layernorm (hidden_states , residual )
333
378
hidden_states = self .feed_forward (hidden_states )
379
+
380
+ # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
381
+ # Scatter
382
+ if self .dp_size != 1 :
383
+ # important: forward batch.gathered_buffer is used both after scatter and after gather.
384
+ # be careful about this!
385
+ hidden_states , global_hidden_states = (
386
+ forward_batch .gathered_buffer [: forward_batch .input_ids .shape [0 ]],
387
+ hidden_states ,
388
+ )
389
+ dp_scatter (hidden_states , global_hidden_states , forward_batch )
390
+
334
391
return hidden_states , residual
335
392
336
393
@@ -350,6 +407,7 @@ def __init__(
350
407
config .hidden_size ,
351
408
quant_config = quant_config ,
352
409
prefix = add_prefix ("embed_tokens" , prefix ),
410
+ enable_tp = not global_server_args_dict ["enable_dp_attention" ],
353
411
)
354
412
self .layers = make_layers (
355
413
config .num_hidden_layers ,
@@ -385,7 +443,8 @@ def forward(
385
443
forward_batch ,
386
444
residual ,
387
445
)
388
- hidden_states , _ = self .norm (hidden_states , residual )
446
+ if not forward_batch .forward_mode .is_idle ():
447
+ hidden_states , _ = self .norm (hidden_states , residual )
389
448
390
449
if len (aux_hidden_states ) == 0 :
391
450
return hidden_states
@@ -394,7 +453,6 @@ def forward(
394
453
395
454
396
455
class Llama4ForCausalLM (LlamaForCausalLM ):
397
-
398
456
packed_modules_mapping = {
399
457
"qkv_proj" : ["q_proj" , "k_proj" , "v_proj" ],
400
458
"gate_up_proj" : ["gate_proj" , "up_proj" ],
0 commit comments