Skip to content

Commit 2ab9702

Browse files
authored
[router] add different policies for p node and d node (sgl-project#8395)
1 parent 0bcc195 commit 2ab9702

File tree

10 files changed

+537
-82
lines changed

10 files changed

+537
-82
lines changed

sgl-router/README.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,16 @@ python -m sglang_router.launch_router \
120120
--prefill-selector app=sglang component=prefill \
121121
--decode-selector app=sglang component=decode \
122122
--service-discovery-namespace sglang-system
123+
124+
# With separate routing policies for prefill and decode nodes:
125+
python -m sglang_router.launch_router \
126+
--pd-disaggregation \
127+
--prefill-policy cache_aware \
128+
--decode-policy power_of_two \
129+
--service-discovery \
130+
--prefill-selector app=sglang component=prefill \
131+
--decode-selector app=sglang component=decode \
132+
--service-discovery-namespace sglang-system
123133
```
124134

125135
#### Kubernetes Pod Configuration
@@ -226,7 +236,9 @@ python -m sglang_router.launch_router \
226236
- `--decode`: Initial decode server URL
227237
- `--prefill-selector`: Label selector for prefill pods
228238
- `--decode-selector`: Label selector for decode pods
229-
- `--policy`: Routing policy (`cache_aware`, `random`, `power_of_two`)
239+
- `--policy`: Routing policy (`cache_aware`, `random`, `power_of_two`, `round_robin`); in PD mode it applies to both prefill and decode nodes unless overridden
240+
- `--prefill-policy`: Routing policy for prefill nodes in PD mode (optional; accepts the same values as `--policy` and overrides it for prefill nodes)
241+
- `--decode-policy`: Routing policy for decode nodes in PD mode (optional; accepts the same values as `--policy` and overrides it for decode nodes)
230242

231243
## Development
232244

sgl-router/py_src/sglang_router/launch_router.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class RouterArgs:
4040

4141
# Routing policy
4242
policy: str = "cache_aware"
43+
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
44+
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
4345
worker_startup_timeout_secs: int = 300
4446
worker_startup_check_interval: int = 10
4547
cache_threshold: float = 0.5
@@ -108,7 +110,21 @@ def add_cli_args(
108110
type=str,
109111
default=RouterArgs.policy,
110112
choices=["random", "round_robin", "cache_aware", "power_of_two"],
111-
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
113+
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
114+
)
115+
parser.add_argument(
116+
f"--{prefix}prefill-policy",
117+
type=str,
118+
default=None,
119+
choices=["random", "round_robin", "cache_aware", "power_of_two"],
120+
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
121+
)
122+
parser.add_argument(
123+
f"--{prefix}decode-policy",
124+
type=str,
125+
default=None,
126+
choices=["random", "round_robin", "cache_aware", "power_of_two"],
127+
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
112128
)
113129

114130
# PD-specific arguments
@@ -266,6 +282,8 @@ def from_cli_args(
266282
prefill_urls=prefill_urls,
267283
decode_urls=decode_urls,
268284
policy=getattr(args, f"{prefix}policy"),
285+
prefill_policy=getattr(args, f"{prefix}prefill_policy", None),
286+
decode_policy=getattr(args, f"{prefix}decode_policy", None),
269287
worker_startup_timeout_secs=getattr(
270288
args, f"{prefix}worker_startup_timeout_secs"
271289
),
@@ -389,6 +407,35 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
389407
if not router_args.decode_urls:
390408
raise ValueError("PD disaggregation mode requires --decode")
391409

410+
# Warn about policy usage in PD mode
411+
if (
412+
router_args.prefill_policy
413+
and router_args.decode_policy
414+
and router_args.policy
415+
):
416+
logger.warning(
417+
"Both --prefill-policy and --decode-policy are specified. "
418+
"The main --policy flag will be ignored for PD mode."
419+
)
420+
elif (
421+
router_args.prefill_policy
422+
and not router_args.decode_policy
423+
and router_args.policy
424+
):
425+
logger.info(
426+
f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes "
427+
f"and --policy '{router_args.policy}' for decode nodes."
428+
)
429+
elif (
430+
router_args.decode_policy
431+
and not router_args.prefill_policy
432+
and router_args.policy
433+
):
434+
logger.info(
435+
f"Using --policy '{router_args.policy}' for prefill nodes "
436+
f"and --decode-policy '{router_args.decode_policy}' for decode nodes."
437+
)
438+
392439
# Create router with unified constructor
393440
router = Router(
394441
worker_urls=(
@@ -424,6 +471,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
424471
decode_urls=(
425472
router_args.decode_urls if router_args.pd_disaggregation else None
426473
),
474+
prefill_policy=(
475+
policy_from_str(router_args.prefill_policy)
476+
if router_args.prefill_policy
477+
else None
478+
),
479+
decode_policy=(
480+
policy_from_str(router_args.decode_policy)
481+
if router_args.decode_policy
482+
else None
483+
),
427484
)
428485

429486
router.start()
@@ -455,12 +512,18 @@ def parse_router_args(args: List[str]) -> RouterArgs:
455512
# Regular mode
456513
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
457514
458-
# PD disaggregated mode
515+
# PD disaggregated mode with same policy for both
459516
python -m sglang_router.launch_router --pd-disaggregation \\
460517
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
461518
--decode http://decode1:8001 --decode http://decode2:8001 \\
462519
--policy cache_aware
463520
521+
# PD mode with different policies for prefill and decode
522+
python -m sglang_router.launch_router --pd-disaggregation \\
523+
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
524+
--decode http://decode1:8001 --decode http://decode2:8001 \\
525+
--prefill-policy cache_aware --decode-policy power_of_two
526+
464527
""",
465528
formatter_class=CustomHelpFormatter,
466529
)

sgl-router/py_src/sglang_router/router.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class Router:
5050
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
5151
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
5252
decode_urls: List of URLs for decode servers (PD mode only)
53+
prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
54+
If not specified, uses the main policy. Default: None
55+
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
56+
If not specified, uses the main policy. Default: None
5357
"""
5458

5559
def __init__(
@@ -79,6 +83,8 @@ def __init__(
7983
pd_disaggregation: bool = False,
8084
prefill_urls: Optional[List[tuple]] = None,
8185
decode_urls: Optional[List[str]] = None,
86+
prefill_policy: Optional[PolicyType] = None,
87+
decode_policy: Optional[PolicyType] = None,
8288
):
8389
if selector is None:
8490
selector = {}
@@ -113,6 +119,8 @@ def __init__(
113119
pd_disaggregation=pd_disaggregation,
114120
prefill_urls=prefill_urls,
115121
decode_urls=decode_urls,
122+
prefill_policy=prefill_policy,
123+
decode_policy=decode_policy,
116124
)
117125

118126
def start(self) -> None:

0 commit comments

Comments
 (0)