From 855c40e1228cc34b25f1c8faffeadf39c6c70a56 Mon Sep 17 00:00:00 2001
From: ltd0924
Date: Thu, 3 Jul 2025 11:26:46 +0800
Subject: [PATCH] [PD disaggregate] optimize splitwise scheduler

---
 fastdeploy/engine/engine.py                 |  2 ++
 fastdeploy/scheduler/splitwise_scheduler.py | 35 ++++++++++++++-------
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 162c890781..f31d8fc597 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -286,6 +286,8 @@ def _zmq_send_generated_tokens(self):
         while self.running:
             try:
                 results = self.scheduler.get_results()
+                if len(results) == 0:
+                    time.sleep(0.001)
                 for request_id, contents in results.items():
                     for result in contents:
                         self.zmq_server.send_multipart(request_id, result)
diff --git a/fastdeploy/scheduler/splitwise_scheduler.py b/fastdeploy/scheduler/splitwise_scheduler.py
index 50cd652b93..be4534974a 100644
--- a/fastdeploy/scheduler/splitwise_scheduler.py
+++ b/fastdeploy/scheduler/splitwise_scheduler.py
@@ -260,12 +260,13 @@ class ResultReader(object):
     ResultReader use an async thread to continue get infer result from redis
     """
 
-    def __init__(self, client, idx, batch=200, ttl=900):
+    def __init__(self, client, idx, batch=200, ttl=900, group=""):
         self.idx = idx
         self.batch = batch
         self.client = client
         self.data = deque()
         self.ttl = ttl
+        self.group = group
         self.reqs = dict()
         self.out_buffer = dict()
 
@@ -380,15 +381,18 @@ def sync_results(self, keys):
         fetch infer results from redis for the give keys
         """
         total = 0
+        if self.group != "":
+            keys = [self.group]
         for key in keys:
+            #logger.info(f"Sync Results from Redis {key}")
             results = self.client.rpop(key, self.batch)
             if results is None or len(results) == 0:
                 continue
-            #logger.info(f"Rpop {self.idx}: {len(results)}")
+            #logger.info(f"Rpop {key} {self.idx}: {len(results)}")
             total += len(results)
             for result in results:
                 try:
-                    #logger.info(f"Scheduler Get Results: {result}")
+                    # logger.info(f"Scheduler Get Results: {result.request_id}")
                     data = orjson.loads(result)
                     result = RequestOutput.from_dict(data)
                     self.data.appendleft(result)
@@ -425,8 +429,9 @@ def start(self):
         start backup threads
         """
         for i in range(self.reader_parallel):
+            group = f"{self.nodeid}-{i}"
             reader = ResultReader(self.client, i, self.reader_batch_size,
-                                  self.ttl)
+                                  self.ttl, group)
             self.readers.append(reader)
 
         self.clear_expired_nodes_thread = threading.Thread(
@@ -481,15 +486,16 @@ def loop_schedule(self):
                 reader = self.readers[reader_idx]
                 reader.add_req(req)
 
+                group = self.readers[reader_idx].group
                 reader_idx = (reader_idx + 1) % len(self.readers)
 
-                self.schedule(req, pnodes, dnodes, mnodes)
+                self.schedule(req, pnodes, dnodes, mnodes, group)
             except IndexError:
                 continue
             except Exception as e:
                 logger.error(f"APIScheduler Schedule req error: {str(e)}")
 
-    def schedule(self, req, pnodes, dnodes, mnodes):
+    def schedule(self, req, pnodes, dnodes, mnodes, group=""):
         """
         schedule an req to according redis node queue
         """
@@ -498,7 +504,9 @@ def schedule(self, req, pnodes, dnodes, mnodes):
         pnode = self.select_pd(req, pnodes, "prefill")
         if pnode.role == "mixed":
             req.disaggregate_info = None
-            req_str = orjson.dumps(req.to_dict())
+            req_dict = req.to_dict()
+            req_dict["group"] = group
+            req_str = orjson.dumps(req_dict)
             pkey = f"ReqQ_{pnode.nodeid}"
             #logger.info(f"Schedule Req {req_str} to Mixed")
             self.client.lpush(pkey, req_str)
@@ -518,7 +526,9 @@ def schedule(self, req, pnodes, dnodes, mnodes):
             disaggregated["transfer_protocol"] = transfer_protocol[0]
         req.disaggregate_info = disaggregated
         pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
-        req_str = orjson.dumps(req.to_dict())
+        req_dict = req.to_dict()
+        req_dict["group"] = group
+        req_str = orjson.dumps(req_dict)
         #logger.info(f"Schedule Req {req_str}")
         self.client.lpush(dkey, req_str)
         self.client.lpush(pkey, req_str)
@@ -634,7 +644,9 @@ def run(self):
                 size = len(self.data)
                 if size == 0:
                     self.cond.wait()
+                #qsize = size
                 size = min(size, self.batch)
+                #logger.info(f"Writer {self.idx} Queue Size: {qsize}, Cur Size: {size}")
                 groups = dict()
                 for i in range(size):
                     key, item = self.data.pop()
@@ -749,12 +761,13 @@ def select_writer(req):
         for req_str in reqs:
             req = orjson.loads(req_str)
+            group = req.get("group", "")
             req = Request.from_dict(req)
             writer_idx = select_writer(req)
             logger.info(
                 f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}"
             )
 
-            req.request_id = f"{req.request_id}#{writer_idx}"
+            req.request_id = f"{req.request_id}#{writer_idx}#{group}"
             if self.role == "prefill" or self.role == "mixed":
                 self.reqs_queue.append(req)
                 self.node.add_req(req.request_id,
@@ -813,10 +826,10 @@ def put_results(self, results):
         for result in results:
             req_ids.add(result.request_id)
 
-            req_id, idx = result.request_id.split("#")
+            req_id, idx, group = result.request_id.split("#")
             result.request_id = req_id
 
-            key = (req_id, int(idx))
+            key = (req_id if group == "" else group, int(idx))
             if key not in groups:
                 groups[key] = list()
             groups[key].append(result)
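-- 

Note: everything below this signature line is an illustrative sketch, not
part of the patch, and helper names are hypothetical. The routing contract
the patch introduces is easiest to see end to end: the infer side tags every
request_id as req_id#writer_idx#group, and a ResultReader constructed with a
non-empty group drains one shared Redis list instead of one list per
request. A minimal sketch, assuming a redis-py client:

import redis  # assumption: redis-py >= 4.2, where rpop() accepts a count


def tag_request_id(req_id: str, writer_idx: int, group: str) -> str:
    # mirrors the infer-side tagging above: embed the writer slot and the
    # API-side reader group into the request id carried by every result
    return f"{req_id}#{writer_idx}#{group}"


def result_key(tagged_id: str) -> tuple:
    # mirrors put_results: grouped results collapse onto the shared
    # per-reader key, ungrouped results keep their per-request key
    req_id, idx, group = tagged_id.split("#")
    return (req_id if group == "" else group, int(idx))


def drain_reader(client: redis.Redis, group: str, keys: list, batch: int = 200) -> list:
    # mirrors sync_results: with a group configured, a single rpop covers
    # every in-flight request owned by this reader, so polling cost scales
    # with the number of readers rather than the number of requests
    if group != "":
        keys = [group]
    out = []
    for key in keys:
        results = client.rpop(key, batch)
        if results:
            out.extend(results)
    return out

The trade-off: polling on the API node now costs one RPOP per reader instead
of one per in-flight request, at the price of a wider request_id format (any
consumer that split on a single "#", as put_results did before this patch,
must now expect three fields).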
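A second sketch, same caveats (names are stand-ins, not the patch's API),
for the API-side half: requests are assigned to readers round-robin, and the
chosen reader's group travels inside the serialized request, which is how
the infer node learns which shared list to answer on:

from collections import deque

import orjson


class MiniReader:
    """Stand-in for ResultReader: holds a per-reader group key."""

    def __init__(self, idx: int, nodeid: str):
        # one group key per reader, unique per API node, matching the
        # f"{self.nodeid}-{i}" naming in APIScheduler.start above
        self.idx = idx
        self.group = f"{nodeid}-{idx}"
        self.reqs = {}

    def add_req(self, req: dict):
        self.reqs[req["request_id"]] = req


def schedule_one(req: dict, readers: list, reader_idx: int, queue: deque) -> int:
    # round-robin the request onto a reader, then stamp that reader's group
    # into the payload before it heads for the node queue, matching the
    # req_dict["group"] = group lines in APIScheduler.schedule
    reader = readers[reader_idx]
    reader.add_req(req)
    req["group"] = reader.group
    queue.append(orjson.dumps(req))
    return (reader_idx + 1) % len(readers)


if __name__ == "__main__":
    readers = [MiniReader(i, "node0") for i in range(2)]
    queue = deque()
    nxt = schedule_one({"request_id": "r1"}, readers, 0, queue)
    assert readers[0].group == "node0-0" and nxt == 1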