@@ -40,6 +40,8 @@ class RouterArgs:
40
40
41
41
# Routing policy
42
42
policy : str = "cache_aware"
43
+ prefill_policy : Optional [str ] = None # Specific policy for prefill nodes in PD mode
44
+ decode_policy : Optional [str ] = None # Specific policy for decode nodes in PD mode
43
45
worker_startup_timeout_secs : int = 300
44
46
worker_startup_check_interval : int = 10
45
47
cache_threshold : float = 0.5
@@ -108,7 +110,21 @@ def add_cli_args(
108
110
type = str ,
109
111
default = RouterArgs .policy ,
110
112
choices = ["random" , "round_robin" , "cache_aware" , "power_of_two" ],
111
- help = "Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode" ,
113
+ help = "Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden" ,
114
+ )
115
+ parser .add_argument (
116
+ f"--{ prefix } prefill-policy" ,
117
+ type = str ,
118
+ default = None ,
119
+ choices = ["random" , "round_robin" , "cache_aware" , "power_of_two" ],
120
+ help = "Specific policy for prefill nodes in PD mode. If not specified, uses the main policy" ,
121
+ )
122
+ parser .add_argument (
123
+ f"--{ prefix } decode-policy" ,
124
+ type = str ,
125
+ default = None ,
126
+ choices = ["random" , "round_robin" , "cache_aware" , "power_of_two" ],
127
+ help = "Specific policy for decode nodes in PD mode. If not specified, uses the main policy" ,
112
128
)
113
129
114
130
# PD-specific arguments
@@ -266,6 +282,8 @@ def from_cli_args(
266
282
prefill_urls = prefill_urls ,
267
283
decode_urls = decode_urls ,
268
284
policy = getattr (args , f"{ prefix } policy" ),
285
+ prefill_policy = getattr (args , f"{ prefix } prefill_policy" , None ),
286
+ decode_policy = getattr (args , f"{ prefix } decode_policy" , None ),
269
287
worker_startup_timeout_secs = getattr (
270
288
args , f"{ prefix } worker_startup_timeout_secs"
271
289
),
@@ -389,6 +407,35 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
389
407
if not router_args .decode_urls :
390
408
raise ValueError ("PD disaggregation mode requires --decode" )
391
409
410
+ # Warn about policy usage in PD mode
411
+ if (
412
+ router_args .prefill_policy
413
+ and router_args .decode_policy
414
+ and router_args .policy
415
+ ):
416
+ logger .warning (
417
+ "Both --prefill-policy and --decode-policy are specified. "
418
+ "The main --policy flag will be ignored for PD mode."
419
+ )
420
+ elif (
421
+ router_args .prefill_policy
422
+ and not router_args .decode_policy
423
+ and router_args .policy
424
+ ):
425
+ logger .info (
426
+ f"Using --prefill-policy '{ router_args .prefill_policy } ' for prefill nodes "
427
+ f"and --policy '{ router_args .policy } ' for decode nodes."
428
+ )
429
+ elif (
430
+ router_args .decode_policy
431
+ and not router_args .prefill_policy
432
+ and router_args .policy
433
+ ):
434
+ logger .info (
435
+ f"Using --policy '{ router_args .policy } ' for prefill nodes "
436
+ f"and --decode-policy '{ router_args .decode_policy } ' for decode nodes."
437
+ )
438
+
392
439
# Create router with unified constructor
393
440
router = Router (
394
441
worker_urls = (
@@ -424,6 +471,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
424
471
decode_urls = (
425
472
router_args .decode_urls if router_args .pd_disaggregation else None
426
473
),
474
+ prefill_policy = (
475
+ policy_from_str (router_args .prefill_policy )
476
+ if router_args .prefill_policy
477
+ else None
478
+ ),
479
+ decode_policy = (
480
+ policy_from_str (router_args .decode_policy )
481
+ if router_args .decode_policy
482
+ else None
483
+ ),
427
484
)
428
485
429
486
router .start ()
@@ -455,12 +512,18 @@ def parse_router_args(args: List[str]) -> RouterArgs:
455
512
# Regular mode
456
513
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
457
514
458
- # PD disaggregated mode
515
+ # PD disaggregated mode with same policy for both
459
516
python -m sglang_router.launch_router --pd-disaggregation \\
460
517
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
461
518
--decode http://decode1:8001 --decode http://decode2:8001 \\
462
519
--policy cache_aware
463
520
521
+ # PD mode with different policies for prefill and decode
522
+ python -m sglang_router.launch_router --pd-disaggregation \\
523
+ --prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
524
+ --decode http://decode1:8001 --decode http://decode2:8001 \\
525
+ --prefill-policy cache_aware --decode-policy power_of_two
526
+
464
527
""" ,
465
528
formatter_class = CustomHelpFormatter ,
466
529
)
0 commit comments