
Commit 28f7f6f

Register allgather/reducescatter buffers with symm memory
1 parent 0cf3fbe commit 28f7f6f

13 files changed: +181 -68 lines changed


python/sglang/srt/distributed/device_communicators/pynccl_allocator.py

Lines changed: 27 additions & 8 deletions

@@ -5,7 +5,6 @@
 from torch.cuda.memory import CUDAPluggableAllocator

 from sglang.srt.distributed.parallel_state import GroupCoordinator
-from sglang.srt.managers.schedule_batch import global_server_args_dict

 nccl_allocator_source = """
 #include <nccl.h>
@@ -28,13 +27,21 @@
 _allocator = None
 _mem_pool = None
 _registered_base_addrs = set()
+_registered_tensor_addrs = set()
 _graph_pool_id = None


 def is_symmetric_memory_enabled():
+    # Import here to avoid circular import
+    from sglang.srt.managers.schedule_batch import global_server_args_dict
+
     return global_server_args_dict["enable_symm_mem"]


+def is_symmetric_memory_tensor(tensor: torch.Tensor):
+    return tensor.untyped_storage().data_ptr() in _registered_tensor_addrs
+
+
 def set_graph_pool_id(graph_pool_id):
     global _graph_pool_id
     _graph_pool_id = graph_pool_id
@@ -64,8 +71,18 @@ def get_nccl_mem_pool():


 class use_symmetric_memory:
-    def __init__(self, group_coordinator: GroupCoordinator):
-        if not is_symmetric_memory_enabled():
+    def __init__(
+        self,
+        group_coordinator: GroupCoordinator,
+        disabled: bool = False,
+        disable_war: bool = False,
+    ):
+        self.disabled = (
+            disabled
+            or not is_symmetric_memory_enabled()
+            or group_coordinator.world_size == 1
+        )
+        if self.disabled:
             self.group_coordinator = None
             self._mem_pool_ctx = None
             self.is_graph_capture = None
@@ -77,9 +94,10 @@ def __init__(self, group_coordinator: GroupCoordinator):
             self.is_graph_capture = torch.cuda.is_current_stream_capturing()
             self.device = torch.cuda.current_device()
             self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")
+            self.disable_war = disable_war

     def __enter__(self):
-        if not is_symmetric_memory_enabled():
+        if self.disabled:
             return self
         assert (
             self.group_coordinator.pynccl_comm is not None
@@ -102,18 +120,19 @@ def __enter__(self):
         return self

     def tag(self, tensor: torch.Tensor):
-        if not is_symmetric_memory_enabled():
+        if self.disabled:
             return
-        tensor.symmetric_memory = True
+        global _registered_tensor_addrs
+        _registered_tensor_addrs.add(tensor.untyped_storage().data_ptr())

     def __exit__(self, exc_type, exc_val, exc_tb):
-        if not is_symmetric_memory_enabled():
+        if self.disabled:
             return
         global _registered_base_addrs
         self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
         for segment in get_nccl_mem_pool().snapshot():
             if segment["address"] not in _registered_base_addrs:
-                if segment["stream"] == 0 and self.pre_2_8_0:
+                if segment["stream"] == 0 and self.pre_2_8_0 and not self.disable_war:
                     # PyTorch version < 2.8.0 has a multi-thread MemPool bug
                     # See https://github.com/pytorch/pytorch/issues/152861
                     # Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
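
Note: `tag` now records the base address of the tensor's untyped storage in a module-level set instead of attaching a `symmetric_memory` attribute to the tensor object, and `is_symmetric_memory_tensor` checks membership in that set. A minimal standalone sketch of the same pattern (plain torch, illustrative helper names, no NCCL pool involved):

import torch

# Illustrative stand-in for the module-level registry added in pynccl_allocator.py.
_registered_tensor_addrs = set()


def tag(tensor: torch.Tensor) -> None:
    # Register the storage's base address rather than the tensor object, so
    # views over the same allocation are also recognized later.
    _registered_tensor_addrs.add(tensor.untyped_storage().data_ptr())


def is_symmetric_memory_tensor(tensor: torch.Tensor) -> bool:
    return tensor.untyped_storage().data_ptr() in _registered_tensor_addrs


buf = torch.empty(16)
tag(buf)
assert is_symmetric_memory_tensor(buf)
assert is_symmetric_memory_tensor(buf.view(4, 4))  # same storage, still recognized
assert not is_symmetric_memory_tensor(torch.empty(16))

Keying on the storage address lets GroupCoordinator detect registered buffers without carrying an extra attribute through views and reshapes.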

python/sglang/srt/distributed/parallel_state.py

Lines changed: 40 additions & 13 deletions

@@ -274,7 +274,13 @@ def __init__(
         from sglang.srt.distributed.device_communicators.pynccl import (
             PyNcclCommunicator,
         )
+        from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+            is_symmetric_memory_tensor,
+            use_symmetric_memory,
+        )

+        self.is_symmetric_memory_tensor = is_symmetric_memory_tensor
+        self.use_symmetric_memory = use_symmetric_memory
         if is_hip():
             from sglang.srt.distributed.device_communicators.quick_all_reduce import (
                 QuickAllReduce,
@@ -503,11 +509,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)

-        if (
-            self.pynccl_comm is not None
-            and hasattr(input_, "symmetric_memory")
-            and input_.symmetric_memory
-        ):
+        if self.pynccl_comm is not None and self.is_symmetric_memory_tensor(input_):
             with self.pynccl_comm.change_state(
                 enable=True, stream=torch.cuda.current_stream()
             ):
@@ -573,9 +575,23 @@ def reduce_scatter_tensor(
         self,
         output: torch.Tensor,
         input: torch.Tensor,
-    ) -> None:
-        # TODO(ch-wan): support other backends
-        torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
+    ) -> torch.Tensor:
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and (
+            not pynccl_comm.disabled
+            or (
+                self.is_symmetric_memory_tensor(output)
+                and self.is_symmetric_memory_tensor(input)
+            )
+        ):
+            with pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream()
+            ):
+                pynccl_comm.reduce_scatter(output, input)
+        else:
+            torch.distributed.reduce_scatter_tensor(
+                output, input, group=self.device_group
+            )
         return output

     def reduce_scatter(
@@ -622,8 +638,17 @@ def reduce_scatterv(

     def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
         pynccl_comm = self.pynccl_comm
-        if pynccl_comm is not None and not pynccl_comm.disabled:
-            pynccl_comm.all_gather(output, input)
+        if pynccl_comm is not None and (
+            not pynccl_comm.disabled
+            or (
+                self.is_symmetric_memory_tensor(output)
+                and self.is_symmetric_memory_tensor(input)
+            )
+        ):
+            with pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream()
+            ):
+                pynccl_comm.all_gather(output, input)
         else:
             torch.distributed.all_gather_into_tensor(
                 output, input, group=self.device_group
@@ -685,9 +710,11 @@ def all_gather(
         # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
         output_size = (input_size[0] * world_size,) + input_size[1:]
         # Allocate output tensor.
-        output_tensor = torch.empty(
-            output_size, dtype=input_.dtype, device=input_.device
-        )
+        with self.use_symmetric_memory(self) as sm:
+            output_tensor = torch.empty(
+                output_size, dtype=input_.dtype, device=input_.device
+            )
+            sm.tag(output_tensor)

         # All-gather.
         if input_.is_cpu and is_shm_available(
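
With this change, `reduce_scatter_tensor` and `_all_gather_into_tensor` take the pynccl path not only when the communicator is enabled, but also whenever both buffers were allocated and tagged inside the registered pool. A hedged usage sketch (assumes the distributed environment and TP group are already initialized, symmetric memory is enabled, and the first dimension divides evenly across ranks):

import torch
from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)


def reduce_scatter_registered(local_tokens: torch.Tensor) -> torch.Tensor:
    # Sketch only: allocate both buffers from the NCCL-registered pool and tag
    # them so reduce_scatter_tensor picks the pynccl/symmetric-memory branch.
    tp_group = get_tp_group()
    with use_symmetric_memory(tp_group) as sm:
        inp = torch.empty_like(local_tokens)
        out = torch.empty(
            (local_tokens.shape[0] // tp_group.world_size, *local_tokens.shape[1:]),
            dtype=local_tokens.dtype,
            device=local_tokens.device,
        )
        sm.tag(inp)
        sm.tag(out)
    inp.copy_(local_tokens)
    tp_group.reduce_scatter_tensor(out, inp)  # both tensors registered -> pynccl path
    return out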

python/sglang/srt/layers/communicator.py

Lines changed: 11 additions & 1 deletion

@@ -21,8 +21,12 @@

 from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather_into_tensor,
     attn_tp_reduce_scatter_tensor,
@@ -430,7 +434,13 @@ def _gather_hidden_states_and_residual(
     use_layer_norm_before_gather = context.attn_tp_size == 1
     if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
         residual = hidden_states
-        hidden_states = layernorm(hidden_states)
+        with use_symmetric_memory(
+            get_tp_group(),
+            disabled=not forward_batch.dp_padding_mode.is_max_len(),
+        ) as sm:
+            hidden_states = layernorm(hidden_states)
+            sm.tag(hidden_states)
+
     hidden_states, local_hidden_states = (
         get_global_dp_buffer(),
         hidden_states,
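
Here the layernorm output is only routed into the registered pool when DP padding uses the maximum length, so every rank produces a same-sized buffer; otherwise the context manager is constructed in its disabled state and `sm.tag` is a no-op. A hedged sketch of that gating pattern, with `symm_ok` standing in for `forward_batch.dp_padding_mode.is_max_len()`:

import torch
from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)


def layernorm_into_symm_pool(layernorm, hidden_states: torch.Tensor, symm_ok: bool):
    # Sketch only: when symm_ok is False the context is disabled, the output is
    # a regular CUDA allocation, and tag() returns immediately.
    with use_symmetric_memory(get_tp_group(), disabled=not symm_ok) as sm:
        out = layernorm(hidden_states)  # output allocated inside the (maybe) pool
        sm.tag(out)
    return out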

python/sglang/srt/layers/dp_attention.py

Lines changed: 39 additions & 13 deletions

@@ -17,6 +17,9 @@
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)

 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
@@ -73,6 +76,7 @@ class _DpGatheredBufferWrapper:
     _global_dp_buffer_len: int
     _local_dp_buffer_len: int
     _global_num_tokens: Optional[List[int]]
+    _is_max_padding: bool

     @classmethod
     def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -85,27 +89,37 @@ def set_dp_buffer_len(
         cls,
         global_dp_buffer_len: int,
         local_dp_buffer_len: int,
+        is_max_padding: bool,
         global_num_tokens: Optional[List[int]] = None,
     ):
         cls._global_dp_buffer_len = global_dp_buffer_len
         cls._local_dp_buffer_len = local_dp_buffer_len
+        cls._is_max_padding = is_max_padding
         cls._global_num_tokens = global_num_tokens

     @classmethod
     def get_global_dp_buffer(cls) -> torch.Tensor:
-        return torch.empty(
-            (cls._global_dp_buffer_len, cls._hidden_size),
-            dtype=cls._dtype,
-            device=cls._device,
-        )
+        with use_symmetric_memory(get_tp_group()) as sm:
+            buffer = torch.empty(
+                (cls._global_dp_buffer_len, cls._hidden_size),
+                dtype=cls._dtype,
+                device=cls._device,
+            )
+            sm.tag(buffer)
+        return buffer

     @classmethod
     def get_local_dp_buffer(cls) -> torch.Tensor:
-        return torch.empty(
-            (cls._local_dp_buffer_len, cls._hidden_size),
-            dtype=cls._dtype,
-            device=cls._device,
-        )
+        with use_symmetric_memory(
+            get_tp_group(), disabled=not cls._is_max_padding
+        ) as sm:
+            buffer = torch.empty(
+                (cls._local_dp_buffer_len, cls._hidden_size),
+                dtype=cls._dtype,
+                device=cls._device,
+            )
+            sm.tag(buffer)
+        return buffer

     @classmethod
     def get_global_dp_buffer_len(cls) -> int:
@@ -120,13 +134,18 @@ def get_dp_global_num_tokens(cls) -> List[int]:
         return cls._global_num_tokens


+    def is_max_padding(cls) -> bool:
+        return cls._is_max_padding
+
+
 def set_dp_buffer_len(
     global_dp_buffer_len: int,
     local_dp_buffer_len: int,
+    is_max_padding: bool,
     global_num_tokens: Optional[List[int]] = None,
 ):
     _DpGatheredBufferWrapper.set_dp_buffer_len(
-        global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
+        global_dp_buffer_len, local_dp_buffer_len, is_max_padding, global_num_tokens
     )


@@ -150,6 +169,10 @@ def get_dp_global_num_tokens() -> List[int]:
     return _DpGatheredBufferWrapper.get_dp_global_num_tokens()


+def is_max_padding() -> bool:
+    return _DpGatheredBufferWrapper.is_max_padding()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
@@ -408,7 +431,10 @@ def _dp_gather_via_all_gather(
     scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
         get_attention_tp_rank()
     ]
-    get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
+    if get_attention_tp_size() > 1:
+        get_attention_tp_group().reduce_scatter_tensor(
+            scattered_local_tokens, local_tokens
+        )
     get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)


@@ -467,7 +493,7 @@ def dp_scatter(


 def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
-    if get_tensor_model_parallel_world_size() == get_attention_dp_size():
+    if get_attention_tp_size() == 1:
         get_tp_group().reduce_scatter_tensor(output, input)
     else:
         scattered_local_tokens = input.tensor_split(
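
Callers of `set_dp_buffer_len` now pass `is_max_padding` as the third positional argument; `get_local_dp_buffer` only draws from the registered pool when it is True, and the new `is_max_padding()` accessor exposes the flag. A hedged calling sketch (assumes `_DpGatheredBufferWrapper.set_metadata` and the TP group were already initialized by the model runner):

from sglang.srt.layers.dp_attention import get_global_dp_buffer, set_dp_buffer_len


def prepare_dp_buffer(global_len: int, local_len: int, max_len_padding: bool):
    # Sketch only. The third positional argument is the new is_max_padding flag;
    # compute_dp_attention_metadata() in logits_processor.py passes False here.
    set_dp_buffer_len(global_len, local_len, max_len_padding)
    # get_global_dp_buffer() allocates inside use_symmetric_memory unconditionally;
    # get_local_dp_buffer() does so only when the flag is True.
    return get_global_dp_buffer()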

python/sglang/srt/layers/linear.py

Lines changed: 4 additions & 2 deletions

@@ -1301,7 +1301,7 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor
         # It does not support additional parameters.
         param.load_row_parallel_weight(loaded_weight)

-    def forward(self, input_, skip_all_reduce=False):
+    def forward(self, input_, skip_all_reduce=False, disable_symmetric_memory=True):
         if self.input_is_parallel:
             input_parallel = input_
         else:
@@ -1315,7 +1315,9 @@ def forward(self, input_, skip_all_reduce=False):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+        with use_symmetric_memory(
+            parallel_state.get_tp_group(), disabled=disable_symmetric_memory
+        ) as sm:
             output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
             sm.tag(output_parallel)
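
`RowParallelLinear.forward` gains a `disable_symmetric_memory` keyword that defaults to True, so registered-pool output allocation becomes opt-in per call site rather than unconditional. A hedged calling sketch (assumes a constructed layer and an initialized TP group):

import torch
from sglang.srt.layers.linear import RowParallelLinear


def run_row_parallel(layer: RowParallelLinear, x: torch.Tensor):
    # Sketch only. Passing disable_symmetric_memory=False opts this call into
    # allocating output_parallel from the NCCL-registered pool; the default of
    # True keeps the previous plain-allocation behaviour.
    return layer(x, disable_symmetric_memory=False)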

python/sglang/srt/layers/logits_processor.py

Lines changed: 1 addition & 0 deletions

@@ -194,6 +194,7 @@ def compute_dp_attention_metadata(self):
         set_dp_buffer_len(
             self.global_dp_buffer_len,
             self.dp_local_num_tokens,
+            False,
             self.global_num_tokens_for_logprob_cpu,
         )

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 6 additions & 13 deletions

@@ -11,12 +11,8 @@
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
-    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.distributed.device_communicators.pynccl_allocator import (
-    use_symmetric_memory,
-)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe import (
     MoeRunnerConfig,
@@ -812,15 +808,12 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
             raise NotImplementedError()

         # Matrix multiply.
-        with use_symmetric_memory(get_tp_group()) as sm:
-
-            final_hidden_states = self.quant_method.apply(
-                layer=self,
-                x=hidden_states,
-                topk_output=topk_output,
-                moe_runner_config=self.moe_runner_config,
-            )
-            sm.tag(final_hidden_states)
+        final_hidden_states = self.quant_method.apply(
+            layer=self,
+            x=hidden_states,
+            topk_output=topk_output,
+            moe_runner_config=self.moe_runner_config,
+        )

         final_hidden_states = final_hidden_states[
             ..., :origin_hidden_states_dim