
Commit 66b3a8c

Add fp4 allgather support
1 parent 28f7f6f commit 66b3a8c

5 files changed (+65 lines, -25 lines)


python/sglang/srt/distributed/parallel_state.py

Lines changed: 5 additions & 3 deletions
@@ -768,9 +768,11 @@ def _all_gather_single(
         else:
             output_size = (input_size[0] * world_size,) + input_size[1:]
         # Allocate output tensor.
-        output_tensor = torch.empty(
-            output_size, dtype=input_.dtype, device=input_.device
-        )
+        with use_symmetric_memory(self, disabled=sizes is not None) as sm:
+            output_tensor = torch.empty(
+                output_size, dtype=input_.dtype, device=input_.device
+            )
+            sm.tag(output_tensor)
         pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
         return output_tensor
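
The pattern this commit applies throughout is visible in this hunk: the tensor that will feed the NCCL collective is allocated inside a use_symmetric_memory context and tagged, so the communicator can register it as a symmetric-memory buffer; when the optimization does not apply (here, ragged sizes), the context is disabled and the allocation behaves like a plain torch.empty. A minimal sketch of that usage, assuming the tensor-parallel group from get_tp_group() and a hypothetical shape and dtype:

# Sketch only: the allocate-then-tag pattern from the hunk above.
# `shape` and `dtype` are hypothetical placeholders; the tagged tensor is
# subsequently handed to the group's all-gather (see the diff).
import torch

from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)


def alloc_symmetric(shape, dtype=torch.bfloat16, disabled=False):
    tp_group = get_tp_group()
    with use_symmetric_memory(tp_group, disabled=disabled) as sm:
        out = torch.empty(shape, dtype=dtype, device="cuda")
        sm.tag(out)  # mark the allocation so the communicator treats it as symmetric memory
    return out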

python/sglang/srt/layers/moe/topk.py

Lines changed: 23 additions & 9 deletions
@@ -31,12 +31,17 @@
 import torch.nn.functional as F
 
 from sglang.srt.custom_op import CustomOp
+from sglang.srt.distributed import get_tp_group
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb import expert_location_dispatch
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
+from sglang.srt.layers.dp_attention import is_max_padding
 from sglang.srt.layers.moe import (
     get_moe_runner_backend,
     should_use_flashinfer_trtllm_moe,
@@ -265,13 +270,19 @@ def forward_cuda(
             )
         else:
             self.topk_config.torch_native = False
-            return select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                topk_config=self.topk_config,
-                num_token_non_padded=num_token_non_padded,
-                expert_location_dispatch_info=expert_location_dispatch_info,
-            )
+            with use_symmetric_memory(
+                get_tp_group(), disabled=not is_max_padding()
+            ) as sm:
+                topk_output = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    topk_config=self.topk_config,
+                    num_token_non_padded=num_token_non_padded,
+                    expert_location_dispatch_info=expert_location_dispatch_info,
+                )
+                sm.tag(topk_output.topk_weights)
+                sm.tag(topk_output.topk_ids)
+            return topk_output
 
     def forward_cpu(
         self,
@@ -329,8 +340,11 @@ def forward_npu(
 
     def empty_topk_output(self, device: torch.device) -> TopKOutput:
         topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
-        topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
-        topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+        with use_symmetric_memory(get_tp_group(), disabled=not is_max_padding()) as sm:
+            topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
+            topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+            sm.tag(topk_weights)
+            sm.tag(topk_idx)
         router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
         return StandardTopKOutput(topk_weights, topk_idx, router_logits)

python/sglang/srt/layers/quantization/modelopt_quant.py

Lines changed: 31 additions & 11 deletions
@@ -8,7 +8,14 @@
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tp_group
-from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
+from sglang.srt.layers.dp_attention import (
+    get_dp_global_num_tokens,
+    get_local_dp_buffer,
+    is_max_padding,
+)
 from sglang.srt.layers.moe import (
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
@@ -1268,22 +1275,35 @@ def apply(
         from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
 
         # Quantize before comm, swizzle after.
-        if x.shape[0] > 0:
-            x, x_sf = fp4_quantize(
-                x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
-            )
-        else:
-            x_col = x.shape[1]
-            x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device)
-            x_sf = torch.zeros(
-                0, x_col // 16, dtype=torch.uint8, device=x.device
-            )
+        with use_symmetric_memory(
+            get_tp_group(), disabled=not is_max_padding()
+        ) as sm:
+            if x.shape[0] > 0:
+                x, x_sf = fp4_quantize(
+                    x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
+                )
+            else:
+                x_col = x.shape[1]
+                x = torch.zeros(
+                    0, x_col // 2, dtype=torch.uint8, device=x.device
+                )
+                x_sf = torch.zeros(
+                    0, x_col // 16, dtype=torch.uint8, device=x.device
+                )
+            sm.tag(x)
+            sm.tag(x_sf)
         topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
             [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
         )
         x_sf = nvfp4_block_scale_interleave(x_sf)
 
+        with use_symmetric_memory(
+            get_tp_group(), disabled=not is_max_padding()
+        ) as sm:
+            symm_output = torch.empty_like(x)
+            sm.tag(symm_output)
         output = flashinfer_cutlass_fused_moe(
+            output=symm_output,
             input=x,
             token_selected_experts=topk_ids.to(torch.int),
             token_final_scales=topk_weights,
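
Taken together, this hunk is the fp4 allgather path the commit adds: each rank quantizes its local activations to nvfp4 before communicating (a packed uint8 payload plus per-block scale factors), the quantized tensors are all-gathered across the TP group together with the top-k outputs, and the scale factors are swizzled only after the gather. The // 2 and // 16 in the empty-batch branch follow from the nvfp4 layout: two 4-bit values per uint8 byte and one scale factor per 16-element block. A small sketch of that shape bookkeeping, with hidden as a hypothetical hidden size:

# Sketch only: shape bookkeeping mirroring the empty-batch branch above.
# nvfp4 packs two 4-bit values per uint8 and keeps one scale factor per
# 16-element block, hence `hidden // 2` and `hidden // 16`.
import torch


def empty_fp4_buffers(hidden: int, device: torch.device):
    assert hidden % 16 == 0, "nvfp4 scale blocks assume 16-element groups"
    x = torch.zeros(0, hidden // 2, dtype=torch.uint8, device=device)      # packed fp4 payload
    x_sf = torch.zeros(0, hidden // 16, dtype=torch.uint8, device=device)  # per-block scale factors
    return x, x_sf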

python/sglang/srt/model_executor/graph_runner.py

Lines changed: 5 additions & 1 deletion
@@ -638,7 +638,11 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
-            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+            set_dp_buffer_len(
+                global_dp_buffer_len,
+                num_tokens,
+                forward_batch.dp_padding_mode.is_max_len(),
+            )
 
             kwargs = {}
             if (

python/sglang/srt/operations.py

Lines changed: 1 addition & 1 deletion
@@ -96,8 +96,8 @@ def next(self):
         set_dp_buffer_len(
             self._global_dp_buffer_len,
             self._local_dp_buffer_len,
-            self._global_num_tokens,
             self._is_max_padding,
+            self._global_num_tokens,
         )
 
         with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
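
This hunk, together with the graph_runner.py change above, reorders the positional arguments of set_dp_buffer_len so that the max-padding flag comes third and the optional global token counts come last. A hypothetical signature inferred from the two updated call sites (the real definition, presumably in sglang.srt.layers.dp_attention, may differ):

# Hypothetical signature inferred from the two updated call sites in this
# commit; not copied from the repository.
from typing import List, Optional


def set_dp_buffer_len(
    global_dp_buffer_len: int,
    local_dp_buffer_len: int,
    is_max_padding: bool,
    global_num_tokens: Optional[List[int]] = None,
) -> None:
    ...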
